From ff9761b23ac2390718225930aca34acff8fea7e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Fri, 18 Aug 2023 14:08:26 +0200 Subject: [PATCH] Enhance Piracy detection by considering more methods (#156) --- src/piracy.jl | 97 ++++++++++--------- .../src/PiracyForeignProject.jl | 4 + test/test_piracy.jl | 83 ++++++++++------ 3 files changed, 109 insertions(+), 75 deletions(-) diff --git a/src/piracy.jl b/src/piracy.jl index a552b623..e29018bf 100644 --- a/src/piracy.jl +++ b/src/piracy.jl @@ -1,47 +1,58 @@ module Piracy -using Test: @test, @test_broken -using ..Aqua: walkmodules - -const DEFAULT_PKGS = (Base.PkgId(Base), Base.PkgId(Core)) - -function all_methods!( - mod::Module, - done_callables::Base.IdSet{Any}, # cached to prevent duplicates - result::Vector{Method}, - filter_default::Bool, -)::Vector{Method} - for name in names(mod; all = true, imported = true) - # names can list undefined symbols which cannot be eval'd - isdefined(mod, name) || continue - - # Skip closures - startswith(String(name), "#") && continue - val = getfield(mod, name) - - if !in(val, done_callables) - # In old versions of Julia, Vararg errors when methods is called on it - val === Vararg && continue - for method in methods(val) - # Default filtering removes all methods defined in DEFAULT_PKGs, - # since these may pirate each other. - if !(filter_default && in(Base.PkgId(method.module), DEFAULT_PKGS)) - push!(result, method) - end - end - push!(done_callables, val) +if VERSION >= v"1.6-" + using Test: is_in_mods +else + function is_in_mods(m::Module, recursive::Bool, mods) + while true + m in mods && return true + recursive || return false + p = parentmodule(m) + p === m && return false + m = p end end - result end -function all_methods(mod::Module; filter_default::Bool = true) - result = Method[] - done_callables = Base.IdSet() - walkmodules(mod) do mod - all_methods!(mod, done_callables, result, filter_default) +# based on Test/Test.jl#detect_ambiguities +# https://github.com/JuliaLang/julia/blob/v1.9.1/stdlib/Test/src/Test.jl#L1838-L1896 +function all_methods(mods::Module...; skip_deprecated::Bool = true) + meths = Method[] + mods = collect(mods)::Vector{Module} + + function examine(mt::Core.MethodTable) + examine(Base.MethodList(mt)) end - return result + function examine(ml::Base.MethodList) + for m in ml + is_in_mods(m.module, true, mods) || continue + push!(meths, m) + end + end + + work = Base.loaded_modules_array() + filter!(mod -> mod === parentmodule(mod), work) # some items in loaded_modules_array are not top modules (really just Base) + while !isempty(work) + mod = pop!(work) + for name in names(mod; all = true) + (skip_deprecated && Base.isdeprecated(mod, name)) && continue + isdefined(mod, name) || continue + f = Base.unwrap_unionall(getfield(mod, name)) + if isa(f, Module) && f !== mod && parentmodule(f) === mod && nameof(f) === name + push!(work, f) + elseif isa(f, DataType) && + isdefined(f.name, :mt) && + parentmodule(f) === mod && + nameof(f) === name && + f.name.mt !== Symbol.name.mt && + f.name.mt !== DataType.name.mt + examine(f.name.mt) + end + end + end + examine(Symbol.name.mt) + examine(DataType.name.mt) + return meths end ################################## @@ -141,7 +152,7 @@ function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as # fallback to general code return !(T in treat_as_own) && - !(T <: Function && T.instance in treat_as_own) && + !(T <: Function && isdefined(T, :instance) && T.instance in treat_as_own) && is_foreign(T, pkg; treat_as_own = treat_as_own) end @@ -162,12 +173,9 @@ function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[]) ) end -hunt(mod::Module; from::Module = mod, kwargs...) = - hunt(Base.PkgId(mod); from = from, kwargs...) - -function hunt(pkg::Base.PkgId; from::Module, kwargs...) - filter(all_methods(from)) do method - Base.PkgId(method.module) === pkg && is_pirate(method; kwargs...) +function hunt(mod::Module; skip_deprecated::Bool = true, kwargs...) + filter(all_methods(mod; skip_deprecated = skip_deprecated)) do method + method.module === mod && is_pirate(method; kwargs...) end end @@ -182,6 +190,7 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A # Keyword Arguments - `broken::Bool = false`: If true, it uses `@test_broken` instead of `@test`. +- `skip_deprecated::Bool = true`: If true, it does not check deprecated methods. - `treat_as_own = Union{Function, Type}[]`: The types in this container are considered to be "owned" by the module `m`. This is useful for testing packages that deliberately commit some type piracy, e.g. modules diff --git a/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl b/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl index 03dea9e6..fc2c5c5e 100644 --- a/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl +++ b/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl @@ -3,4 +3,8 @@ module PiracyForeignProject struct ForeignType end struct ForeignParameterizedType{T} end +struct ForeignNonSingletonType + x::Int +end + end diff --git a/test/test_piracy.jl b/test/test_piracy.jl index f586c433..78de01d3 100644 --- a/test/test_piracy.jl +++ b/test/test_piracy.jl @@ -2,7 +2,7 @@ push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject")) baremodule PiracyModule -using PiracyForeignProject: ForeignType, ForeignParameterizedType +using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType using Base: Base, @@ -44,6 +44,8 @@ export MyUnion Base.findfirst(::Set{Vector{Char}}, ::Int) = 1 Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1 Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1 +(::ForeignType)(x::Int8) = x + 1 +(::ForeignNonSingletonType)(x::Int8) = x + 1 # Piracy, but not for `ForeignType in treat_as_own` Base.findmax(::ForeignType, x::Int) = x + 1 @@ -55,29 +57,27 @@ Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1 Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1 Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1 -# Assign them names in this module so they can be found by all_methods -a = Base.findfirst -b = Base.findlast -c = Base.findmax -d = Base.findmin end # PiracyModule using Aqua: Piracy -using PiracyForeignProject: ForeignType, ForeignParameterizedType +using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType # Get all methods - test length meths = filter(Piracy.all_methods(PiracyModule)) do m m.module == PiracyModule end -# 2 Foo constructors -# 2 from f -# 1 from MyUnion -# 6 from findlast -# 3 from findfirst -# 3 from findmax -# 3 from findmin -@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3 +@test length(meths) == + 2 + # Foo constructors + 1 + # Bar constructor + 2 + # f + 1 + # MyUnion + 6 + # findlast + 3 + # findfirst + 1 + # ForeignType callable + 1 + # ForeignNonSingletonType callable + 3 + # findmax + 3 # findmin # Test what is foreign BasePkg = Base.PkgId(Base) @@ -90,24 +90,36 @@ ThisPkg = Base.PkgId(PiracyModule) @test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = []) # Test what is pirate -pirates = filter(m -> Piracy.is_pirate(m), meths) -@test length(pirates) == 3 + 3 + 3 +pirates = Piracy.hunt(PiracyModule) +@test length(pirates) == + 3 + # findfirst + 3 + # findmax + 3 + # findmin + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmax, :findmin] + m.name in [:findfirst, :findmax, :findmin, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignType]) -pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths) -@test length(pirates) == 3 + 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignType]) +@test length(pirates) == + 3 + # findfirst + 3 + # findmin + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmin] + m.name in [:findfirst, :findmin, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignParameterizedType]) -pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths) -@test length(pirates) == 3 + 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignParameterizedType]) +@test length(pirates) == + 3 + # findfirst + 3 + # findmax + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmax] + m.name in [:findfirst, :findmax, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType]) @@ -115,24 +127,33 @@ pirates = filter( m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]), meths, ) -@test length(pirates) == 3 +@test length(pirates) == + 3 + # findfirst + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst] + m.name in [:findfirst, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax]) -pirates = - filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths) -@test length(pirates) == 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [Base.findfirst, Base.findmax]) +@test length(pirates) == + 3 + # findmin + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findmin] + m.name in [:findmin, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (excluding a cover of everything) pirates = filter( m -> Piracy.is_pirate( m; - treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst], + treat_as_own = [ + ForeignType, + ForeignParameterizedType, + ForeignNonSingletonType, + Base.findfirst, + ], ), meths, )