Skip to content

Commit

Permalink
EnzymeTestUtils: use make_zero instead of zero_tangent (#2006)
Browse files Browse the repository at this point in the history
* fix

* fix

* EnzymeTestUtils: use make_zero instead of zero_tangent

* bump

* Add make zero

* fix

* fixup

* fix
  • Loading branch information
wsmoses authored Oct 23, 2024
1 parent 53c3198 commit 820c005
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 40 deletions.
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.2.0"
version = "0.2.1"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
10 changes: 5 additions & 5 deletions lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ function _fd_forward(fdm, f, rettype, y, activities)
# vectorize inputs and outputs of function
f_vec = first to_vec Base.splat(f_sig_args) from_vec_in
if rettype <: Union{Duplicated,DuplicatedNoNeed}
all(ignores) && return zero_tangent(y)
all(ignores) && return Enzyme.make_zero(y)
sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores])
ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec,
(sig_arg_val_vec, sig_arg_dval_vec))
return from_vec_out(ret_deval_vec)
elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed}
all(ignores) && return (var"1"=zero_tangent(y),)
all(ignores) && return (var"1"=Enzyme.make_zero(y),)
ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals...
sig_args_dvals_vec, _ = to_vec(sig_args_dvals)
ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec,
Expand Down Expand Up @@ -67,13 +67,13 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return)
xs = map(x -> x.val, activities)
ignores = map(a -> a isa Const, activities)
f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores)
all(ignores) && return map(zero_tangent, xs)
all(ignores) && return map(Enzyme.make_zero, xs)
ignores = collect(ignores)
is_batch = _any_batch_duplicated(map(typeof, activities)...)
batch_size = is_batch ? _batch_size(map(typeof, activities)...) : 1
x̄s = map(collect(activities)) do a
if a isa Union{Const,Active}
dval = ntuple(_ -> zero_tangent(a.val), batch_size)
dval = ntuple(_ -> Enzyme.make_zero(a.val), batch_size)
return is_batch ? dval : dval[1]
else
return a.dval
Expand Down Expand Up @@ -178,7 +178,7 @@ function _wrap_reverse_function(active_return, f, xs, ignores)
# zero, if the input and output alias.
if active_return
for k in keys(zeros)
zeros[k] = zero_tangent(k)
zeros[k] = Enzyme.make_zero(k)
end
end

Expand Down
7 changes: 0 additions & 7 deletions lib/EnzymeTestUtils/src/generate_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ function rand_tangent(rng, x)
return from_vec(v_new)
end

# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share
# the same memory in the output.
function zero_tangent(x)
v, from_vec = to_vec(x)
return from_vec(zero(v))
end

auto_activity(arg) = auto_activity(Random.default_rng(), arg)
function auto_activity(rng, arg::Tuple)
if length(arg) == 2 && arg[2] isa Type && arg[2] <: Annotation
Expand Down
4 changes: 2 additions & 2 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ function test_reverse(
y = fcopy(args_copy...; deepcopy(fkwargs)...)
# generate tangent for output
if !_any_batch_duplicated(ret_activity, map(typeof, activities)...)
= ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)
= ret_activity <: Const ? Enzyme.make_zero(y) : rand_tangent(rng, y)
else
batch_size = _batch_size(ret_activity, map(typeof, activities)...)
ks = ntuple(Symbol string, batch_size)
= ntuple(batch_size) do _
return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y)
return ret_activity <: Const ? Enzyme.make_zero(y) : rand_tangent(y)
end
end
# call finitedifferences, avoid mutating original arguments
Expand Down
46 changes: 46 additions & 0 deletions lib/EnzymeTestUtils/src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,52 @@ function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array}
end
return x_vec, Array_from_vec
end

@static if VERSION < v"1.11-"
else
# basic containers: loop over defined elements, recursively converting them to vectors
function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:GenericMemory}
has_seen = haskey(seen_vecs, x)
is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing)
if has_seen || is_const
x_vec = Float32[]
else
x_vecs = Vector{<:AbstractFloat}[]
from_vecs = []
subvec_inds = UnitRange{Int}[]
l = 0
for i in eachindex(x)
isassigned(x, i) || continue
xi_vec, xi_from_vec = to_vec(x[i], seen_vecs)
push!(x_vecs, xi_vec)
push!(from_vecs, xi_from_vec)
push!(subvec_inds, (l + 1):(l + length(xi_vec)))
l += length(xi_vec)
end
x_vec = reduce(vcat, x_vecs; init=Float32[])
seen_vecs[x] = x_vec
end
function Memory_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
if xor(has_seen, haskey(seen_xs, x))
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
end
has_seen && return reshape(seen_xs[x], size(x))
is_const && return x
x_new = typeof(x)(undef, size(x))
k = 1
for i in eachindex(x)
isassigned(x, i) || continue
xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs)
x_new[i] = xi
k += 1
end
seen_xs[x] = x_new
return x_new
end
return x_vec, Memory_from_vec
end
end

function to_vec(x::Tuple, seen_vecs::AliasDict)
x_vec, from_vec = to_vec(collect(x), seen_vecs)
function Tuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict)
Expand Down
26 changes: 1 addition & 25 deletions lib/EnzymeTestUtils/test/generate_tangent.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using EnzymeTestUtils
using EnzymeTestUtils: rand_tangent, zero_tangent
using EnzymeTestUtils: rand_tangent
using Enzyme
using Quaternions

Expand Down Expand Up @@ -42,30 +42,6 @@ using Quaternions
@test y.a != x.a
end

@testset "zero_tangent" begin
@test zero_tangent(1) == 1
@test zero_tangent(true) == true
@test zero_tangent(false) == false
@test zero_tangent(:foo) === :foo
@test zero_tangent("bar") === "bar"
@testset for T in (
Float32, Float64, ComplexF32, ComplexF64, QuaternionF32, QuaternionF64
)
x = randn(T)
@test zero_tangent(x) === zero(T)
y = randn(T, 5)
@test zero_tangent(y) == zero(y)
@test zero_tangent(y) isa typeof(y)
end
x = TestStruct(TestStruct(:foo, TestStruct(1, 3.0f0 + 1im)), [4.0, 5.0])
y = zero_tangent(x)
@test y.x.x == :foo
@test y.x.a.x == 1
@test y.x.a.a === zero(ComplexF32)
@test y.a isa Vector{Float64}
@test y.a == zero(x.a)
end

@testset "auto_activity" begin
@test EnzymeTestUtils.auto_activity((1.0, Const)) === Const(1.0)
@test EnzymeTestUtils.auto_activity((1.0, Active)) === Active(1.0)
Expand Down
134 changes: 134 additions & 0 deletions src/make_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ end
return Base.zero(x)
end


@static if VERSION < v"1.11-"
else
@inline function EnzymeCore.make_zero(
x::GenericMemory{kind, FT},
)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind}
return Base.zero(x)
end
@inline function EnzymeCore.make_zero(
x::GenericMemory{kind, Complex{FT}},
)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind}
return Base.zero(x)
end
end


@inline function EnzymeCore.make_zero(
::Type{Array{FT,N}},
seen::IdDict,
Expand Down Expand Up @@ -43,6 +59,36 @@ end
return newa
end

@static if VERSION < v"1.11-"
else
@inline function EnzymeCore.make_zero(
::Type{GenericMemory{kind, FT}},
seen::IdDict,
prev::GenericMemory{kind, FT},
::Val{copy_if_inactive} = Val(false),
)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind}
if haskey(seen, prev)
return seen[prev]
end
newa = Base.zero(prev)
seen[prev] = newa
return newa
end
@inline function EnzymeCore.make_zero(
::Type{GenericMemory{kind, Complex{FT}}},
seen::IdDict,
prev::GenericMemory{kind, Complex{FT}},
::Val{copy_if_inactive} = Val(false),
)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind}
if haskey(seen, prev)
return seen[prev]
end
newa = Base.zero(prev)
seen[prev] = newa
return newa
end
end

@inline function EnzymeCore.make_zero(
::Type{RT},
seen::IdDict,
Expand Down Expand Up @@ -86,6 +132,34 @@ end
return newa
end

@static if VERSION < v"1.11-"
else
@inline function EnzymeCore.make_zero(
::Type{RT},
seen::IdDict,
prev::RT,
::Val{copy_if_inactive} = Val(false),
)::RT where {copy_if_inactive,RT<:GenericMemory}
if haskey(seen, prev)
return seen[prev]
end
if guaranteed_const_nongen(RT, nothing)
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
newa = RT(undef, size(prev))
seen[prev] = newa
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
innerty = Core.Typeof(pv)
@inbounds newa[I] =
EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive))
end
end
return newa
end
end

@inline function EnzymeCore.make_zero(
::Type{RT},
seen::IdDict,
Expand Down Expand Up @@ -267,6 +341,25 @@ end
nothing
end

@static if VERSION < v"1.11-"
else
@inline function EnzymeCore.make_zero!(
prev::GenericMemory{kind, T},
seen::ST,
)::Nothing where {T<:AbstractFloat,kind,ST}
fill!(prev, zero(T))
nothing
end

@inline function EnzymeCore.make_zero!(
prev::Array{GenericMemory, Complex{T}},
seen::ST,
)::Nothing where {T<:AbstractFloat,kind,ST}
fill!(prev, zero(Complex{T}))
nothing
end
end

@inline function EnzymeCore.make_zero!(
prev::Base.RefValue{T},
)::Nothing where {T<:AbstractFloat}
Expand Down Expand Up @@ -318,6 +411,47 @@ end
nothing
end

@static if VERSION < v"1.11-"
else
@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(
prev::GenericMemory{kind, Complex{T}},
)::Nothing where {T<:AbstractFloat, kind}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST}
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
return
end
push!(seen, prev)

for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
SBT = Core.Typeof(pv)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
@inbounds prev[I] = make_zero_immutable!(pv, seen)
nothing
else
EnzymeCore.make_zero!(pv, seen)
nothing
end
end
end
nothing
end
end


@inline function EnzymeCore.make_zero!(
prev::Base.RefValue{T},
seen::ST,
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,13 @@ end
world = codegen_world_age(typeof(mul2), Tuple{Vector{Float64}})
forward, pullback = Enzyme.Compiler.thunk(Val(world), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false))
res = forward(Const(mul2), d)

@static if VERSION < v"1.11-"
@test typeof(res[1]) == Tuple{Float64, Float64}
else
@test typeof(res[1]) == NamedTuple{(Symbol("1"),Symbol("2"),Symbol("3"),Symbol("4"),Symbol("5"),Symbol("6")), Tuple{Any, Core.LLVMPtr{UInt8, 0}, Any, Core.LLVMPtr{Any, 0}, Float64, Float64}}
end

pullback(Const(mul2), d, 1.0, res[1])
@test d.dval[1] 5.0
@test d.dval[2] 3.0
Expand Down

0 comments on commit 820c005

Please sign in to comment.