Skip to content

Commit

Permalink
Enhance Piracy detection by considering more methods (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens authored Aug 18, 2023
1 parent a764d4a commit ff9761b
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 75 deletions.
97 changes: 53 additions & 44 deletions src/piracy.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,58 @@
module Piracy

using Test: @test, @test_broken
using ..Aqua: walkmodules

const DEFAULT_PKGS = (Base.PkgId(Base), Base.PkgId(Core))

function all_methods!(
mod::Module,
done_callables::Base.IdSet{Any}, # cached to prevent duplicates
result::Vector{Method},
filter_default::Bool,
)::Vector{Method}
for name in names(mod; all = true, imported = true)
# names can list undefined symbols which cannot be eval'd
isdefined(mod, name) || continue

# Skip closures
startswith(String(name), "#") && continue
val = getfield(mod, name)

if !in(val, done_callables)
# In old versions of Julia, Vararg errors when methods is called on it
val === Vararg && continue
for method in methods(val)
# Default filtering removes all methods defined in DEFAULT_PKGs,
# since these may pirate each other.
if !(filter_default && in(Base.PkgId(method.module), DEFAULT_PKGS))
push!(result, method)
end
end
push!(done_callables, val)
if VERSION >= v"1.6-"
using Test: is_in_mods
else
function is_in_mods(m::Module, recursive::Bool, mods)
while true
m in mods && return true
recursive || return false
p = parentmodule(m)
p === m && return false
m = p
end
end
result
end

function all_methods(mod::Module; filter_default::Bool = true)
result = Method[]
done_callables = Base.IdSet()
walkmodules(mod) do mod
all_methods!(mod, done_callables, result, filter_default)
# based on Test/Test.jl#detect_ambiguities
# https://github.com/JuliaLang/julia/blob/v1.9.1/stdlib/Test/src/Test.jl#L1838-L1896
function all_methods(mods::Module...; skip_deprecated::Bool = true)
meths = Method[]
mods = collect(mods)::Vector{Module}

function examine(mt::Core.MethodTable)
examine(Base.MethodList(mt))
end
return result
function examine(ml::Base.MethodList)
for m in ml
is_in_mods(m.module, true, mods) || continue
push!(meths, m)
end
end

work = Base.loaded_modules_array()
filter!(mod -> mod === parentmodule(mod), work) # some items in loaded_modules_array are not top modules (really just Base)
while !isempty(work)
mod = pop!(work)
for name in names(mod; all = true)
(skip_deprecated && Base.isdeprecated(mod, name)) && continue
isdefined(mod, name) || continue
f = Base.unwrap_unionall(getfield(mod, name))
if isa(f, Module) && f !== mod && parentmodule(f) === mod && nameof(f) === name
push!(work, f)
elseif isa(f, DataType) &&
isdefined(f.name, :mt) &&
parentmodule(f) === mod &&
nameof(f) === name &&
f.name.mt !== Symbol.name.mt &&
f.name.mt !== DataType.name.mt
examine(f.name.mt)
end
end
end
examine(Symbol.name.mt)
examine(DataType.name.mt)
return meths
end

##################################
Expand Down Expand Up @@ -141,7 +152,7 @@ function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as

# fallback to general code
return !(T in treat_as_own) &&
!(T <: Function && T.instance in treat_as_own) &&
!(T <: Function && isdefined(T, :instance) && T.instance in treat_as_own) &&
is_foreign(T, pkg; treat_as_own = treat_as_own)
end

Expand All @@ -162,12 +173,9 @@ function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[])
)
end

hunt(mod::Module; from::Module = mod, kwargs...) =
hunt(Base.PkgId(mod); from = from, kwargs...)

function hunt(pkg::Base.PkgId; from::Module, kwargs...)
filter(all_methods(from)) do method
Base.PkgId(method.module) === pkg && is_pirate(method; kwargs...)
function hunt(mod::Module; skip_deprecated::Bool = true, kwargs...)
filter(all_methods(mod; skip_deprecated = skip_deprecated)) do method
method.module === mod && is_pirate(method; kwargs...)
end
end

Expand All @@ -182,6 +190,7 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
# Keyword Arguments
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
`@test`.
- `skip_deprecated::Bool = true`: If true, it does not check deprecated methods.
- `treat_as_own = Union{Function, Type}[]`: The types in this container
are considered to be "owned" by the module `m`. This is useful for
testing packages that deliberately commit some type piracy, e.g. modules
Expand Down
4 changes: 4 additions & 0 deletions test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ module PiracyForeignProject
struct ForeignType end
struct ForeignParameterizedType{T} end

struct ForeignNonSingletonType
x::Int
end

end
83 changes: 52 additions & 31 deletions test/test_piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject"))

baremodule PiracyModule

using PiracyForeignProject: ForeignType, ForeignParameterizedType
using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType

using Base:
Base,
Expand Down Expand Up @@ -44,6 +44,8 @@ export MyUnion
Base.findfirst(::Set{Vector{Char}}, ::Int) = 1
Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1
Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1
(::ForeignType)(x::Int8) = x + 1
(::ForeignNonSingletonType)(x::Int8) = x + 1

# Piracy, but not for `ForeignType in treat_as_own`
Base.findmax(::ForeignType, x::Int) = x + 1
Expand All @@ -55,29 +57,27 @@ Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1

# Assign them names in this module so they can be found by all_methods
a = Base.findfirst
b = Base.findlast
c = Base.findmax
d = Base.findmin
end # PiracyModule

using Aqua: Piracy
using PiracyForeignProject: ForeignType, ForeignParameterizedType
using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType

# Get all methods - test length
meths = filter(Piracy.all_methods(PiracyModule)) do m
m.module == PiracyModule
end

# 2 Foo constructors
# 2 from f
# 1 from MyUnion
# 6 from findlast
# 3 from findfirst
# 3 from findmax
# 3 from findmin
@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3
@test length(meths) ==
2 + # Foo constructors
1 + # Bar constructor
2 + # f
1 + # MyUnion
6 + # findlast
3 + # findfirst
1 + # ForeignType callable
1 + # ForeignNonSingletonType callable
3 + # findmax
3 # findmin

# Test what is foreign
BasePkg = Base.PkgId(Base)
Expand All @@ -90,49 +90,70 @@ ThisPkg = Base.PkgId(PiracyModule)
@test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = [])

# Test what is pirate
pirates = filter(m -> Piracy.is_pirate(m), meths)
@test length(pirates) == 3 + 3 + 3
pirates = Piracy.hunt(PiracyModule)
@test length(pirates) ==
3 + # findfirst
3 + # findmax
3 + # findmin
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmax, :findmin]
m.name in [:findfirst, :findmax, :findmin, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
@test length(pirates) == 3 + 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignType])
@test length(pirates) ==
3 + # findfirst
3 + # findmin
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmin]
m.name in [:findfirst, :findmin, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
@test length(pirates) == 3 + 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignParameterizedType])
@test length(pirates) ==
3 + # findfirst
3 + # findmax
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmax]
m.name in [:findfirst, :findmax, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
meths,
)
@test length(pirates) == 3
@test length(pirates) ==
3 + # findfirst
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst]
m.name in [:findfirst, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax])
pirates =
filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths)
@test length(pirates) == 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [Base.findfirst, Base.findmax])
@test length(pirates) ==
3 + # findmin
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findmin]
m.name in [:findmin, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (excluding a cover of everything)
pirates = filter(
m -> Piracy.is_pirate(
m;
treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst],
treat_as_own = [
ForeignType,
ForeignParameterizedType,
ForeignNonSingletonType,
Base.findfirst,
],
),
meths,
)
Expand Down

0 comments on commit ff9761b

Please sign in to comment.