Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2023
1 parent bdb5afb commit 4b9f776
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ result, ∂v, ∂A
if !(A <: Const)
@assert ReturnShadow
end
Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
end

"""
Expand Down Expand Up @@ -584,7 +584,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated

world = codegen_world_age(eltype(FA), tt)

Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
end

@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI}
Expand All @@ -610,7 +610,7 @@ end

primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
TapeType = EnzymeRules.tape_type(nondef[1])
return TapeType
end
Expand Down Expand Up @@ -684,7 +684,7 @@ result, ∂v, ∂A
world = codegen_world_age(eltype(FA), primal_tt)

# TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
TapeType = EnzymeRules.tape_type(nondef[1])
A2 = Compiler.return_type(typeof(nondef[1]))

Expand Down Expand Up @@ -995,15 +995,15 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
World = Val(nothing)
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)

if num * chunk == n_out_val
last_size = chunk
primal2, adjoint2 = primal, adjoint
else
last_size = n_out_val - (num-1)*chunk
tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}}
primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal2, adjoint2 = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
end

tmp = ntuple(num) do i
Expand Down Expand Up @@ -1034,7 +1034,7 @@ end
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)
rows = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)
Expand Down

0 comments on commit 4b9f776

Please sign in to comment.