From bdb5afbe0fc82f5a415344b166da366ecf3c2b28 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 29 Nov 2023 03:03:43 +0000 Subject: [PATCH 1/2] WIP invoke fix --- src/Enzyme.jl | 6 ++--- src/compiler.jl | 36 +++++++++++++++-------------- src/rules/activityrules.jl | 4 ++-- src/rules/jitrules.jl | 46 ++++++++++++++++++++++---------------- 4 files changed, 51 insertions(+), 41 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 470f8ecde9..c91e153bbf 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -194,7 +194,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f.val, tt) if !allocatedinline(rt) || rt isa Union - forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) + forward, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) res = forward(f, args′...) tape = res[1] if ReturnPrimal @@ -206,7 +206,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc throw(ErrorException("Duplicated Returns not yet handled")) end - thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + thunk = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) if A <: Active tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f.val, tt) @@ -319,7 +319,7 @@ f(x) = x*x tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} world = codegen_world_age(Core.Typeof(f.val), tt) - thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + thunk = Enzyme.Compiler.thunk(nothing, Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) thunk(f, args′...) end diff --git a/src/compiler.jl b/src/compiler.jl index 223cfa84b7..25f042f265 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5298,10 +5298,8 @@ end @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed -@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} +@inline function thunkbase(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} JuliaContext() do ctx - mi = fspec(eltype(FA), TT, World) - target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) @@ -5352,30 +5350,34 @@ end TapeType = compile_result.TapeType AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal), TapeType} AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, TapeType} - return quote - Base.@_inline_meta - augmented = $AugT($(compile_result.primal)) - adjoint = $AdjT($(compile_result.adjoint)) - (augmented, adjoint) - end + augmented = AugT((compile_result.primal)) + adjoint = AdjT((compile_result.adjoint)) + return (augmented, adjoint) elseif Mode == API.DEM_ReverseModeCombined CAdjT = CombinedAdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} - return quote - Base.@_inline_meta - $CAdjT($(compile_result.adjoint)) - end + return CAdjT(compile_result.adjoint) elseif Mode == API.DEM_ForwardMode FMT = ForwardModeThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, Val{width}, Val(ReturnPrimal)} - return quote - Base.@_inline_meta - $FMT($(compile_result.adjoint)) - end + return FMT(compile_result.adjoint) else @assert false end end end +@inline function thunk(mi::Core.MethodInstance, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} + return thunkbase(mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) +end + +@inline @generated function thunk(::Nothing, ::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} + mi = fspec(eltype(FA), TT, World) + res = thunkbase(mi, Val(World), FA, A, TT, Val(Mode), Val(width), Val(ModifiedBetween), Val(ReturnPrimal), Val(ShadowInit), ABI) + return quote + Base.@_inline_meta + return $(res) + end +end + import GPUCompiler: deferred_codegen_jobs @generated function deferred_codegen(::Val{World}, ::Type{FA}, ::Val{TT}, ::Val{A},::Val{Mode}, diff --git a/src/rules/activityrules.jl b/src/rules/activityrules.jl index ee93f532b9..060e1a8c5b 100644 --- a/src/rules/activityrules.jl +++ b/src/rules/activityrules.jl @@ -24,7 +24,7 @@ function julia_activity_rule(f::LLVM.Function) # Unsupported calling conv # also wouldn't have any type info for this [would for earlier args though] - if mi.specTypes.parameters[end] === Vararg{Any} + if Base.isvarargtype(mi.specTypes.parameters[end]) return end @@ -71,4 +71,4 @@ function julia_activity_rule(f::LLVM.Function) push!(return_attributes(f), StringAttribute("enzyme_inactive")) end end -end \ No newline at end of file +end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 4073638d7f..98fa8db9a9 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -111,7 +111,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) world = codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward = thunk(mi, Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -135,13 +135,13 @@ function func_runtime_generic_fwd(N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote - function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_fwd(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {MI, ActivityTup, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} +@generated function runtime_generic_fwd(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {MI, ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -172,7 +172,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) world = codegen_world_age(FT, Tuple{$(ElTypes...)}) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, + forward, adjoint = thunk(mi, Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) @@ -213,13 +213,13 @@ function func_runtime_generic_augfwd(N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {MI, ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {MI, ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) @@ -278,7 +278,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) end world = codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + forward, adjoint = thunk(mi, Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) if tape.shadow_return !== nothing args = (args..., $shadowret) @@ -296,13 +296,13 @@ function func_runtime_generic_rev(N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) quote - function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + function runtime_generic_rev(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, $(allargs...)) where {MI, ActivityTup, MB, TapeType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} +@generated function runtime_generic_rev(mi::MI, activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, shadow_ptr, f::F, df::DF, allargs...) where {MI, ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) @@ -315,7 +315,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_generic_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) +function generic_setup(orig, func, mi::Union{LLVM.Value, Nothing}, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -436,6 +436,11 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, @assert length(collect(LLVM.uses(etup0))) == 1 end pushfirst!(vals, etup) + if mi === nothing + pushfirst!(vals, unsafe_to_llvm(nothing)) + else + pushfirst!(vals, mi) + end @static if VERSION < v"1.7.0-" || true else @@ -479,7 +484,7 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset, B, false) + sret = generic_setup(orig, runtime_generic_fwd, nothing, AnyArray(1+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) if shadowR != C_NULL if width == 1 @@ -523,7 +528,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) end width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) + sret = generic_setup(orig, runtime_generic_augfwd, nothing, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) if shadowR != C_NULL @@ -566,7 +571,7 @@ function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) + generic_setup(orig, runtime_generic_rev, nothing, Nothing, gutils, #=start=#offset, B, true; tape) end return nothing end @@ -594,7 +599,7 @@ function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) width = get_width(gutils) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) + sret = generic_setup(orig, runtime_generic_fwd, nothing, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) if shadowR != C_NULL if width == 1 @@ -633,7 +638,7 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t width = get_width(gutils) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) # sret = generic_setup(orig, runtime_apply_latest_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, ctx, B, false) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) + sret = generic_setup(orig, runtime_generic_augfwd, nothing, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) if shadowR != C_NULL if width == 1 @@ -664,7 +669,7 @@ end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + generic_setup(orig, runtime_generic_rev, nothing, Nothing, gutils, #=start=#offset+1, B, true; tape) end return nothing @@ -825,7 +830,8 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) + mi = operands(orig)[offset] + sret = generic_setup(orig, runtime_generic_fwd, mi, AnyArray(1+Int(width)), gutils, #=start=#offset+1, B, false) AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) if shadowR != C_NULL @@ -865,7 +871,8 @@ function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) conv = LLVM.callconv(orig) width = get_width(gutils) - sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) + mi = operands(orig)[offset] + sret = generic_setup(orig, runtime_generic_augfwd, mi, AnyArray(2+Int(width)), gutils, #=start=#offset+1, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) if shadowR != C_NULL @@ -898,7 +905,8 @@ end function common_invoke_rev(offset, B, orig, gutils, tape) if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + mi = operands(orig)[offset] + generic_setup(orig, runtime_generic_rev, mi, Nothing, gutils, #=start=#offset+1, B, true; tape) end return nothing From 4b9f776ba622a756c683d2ad400f5819064dd098 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 15 Dec 2023 18:30:20 -0500 Subject: [PATCH 2/2] fixup --- src/Enzyme.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c91e153bbf..6be72f513c 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -521,7 +521,7 @@ result, ∂v, ∂A if !(A <: Const) @assert ReturnShadow end - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) end """ @@ -584,7 +584,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated world = codegen_world_age(eltype(FA), tt) - Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) + Enzyme.Compiler.thunk(nothing, Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) end @inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} @@ -610,7 +610,7 @@ end primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) TapeType = EnzymeRules.tape_type(nondef[1]) return TapeType end @@ -684,7 +684,7 @@ result, ∂v, ∂A world = codegen_world_age(eltype(FA), primal_tt) # TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + nondef = Enzyme.Compiler.thunk(nothing, Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) TapeType = EnzymeRules.tape_type(nondef[1]) A2 = Compiler.return_type(typeof(nondef[1])) @@ -995,7 +995,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} World = Val(nothing) - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) if num * chunk == n_out_val last_size = chunk @@ -1003,7 +1003,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) else last_size = n_out_val - (num-1)*chunk tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal2, adjoint2 = Enzyme.Compiler.thunk(nothing, Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) end tmp = ntuple(num) do i @@ -1034,7 +1034,7 @@ end rt = Core.Compiler.return_type(f, tt) ModifiedBetween = Val((false, false)) FA = Const{Core.Typeof(f)} - primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) + primal, adjoint = Enzyme.Compiler.thunk(nothing, Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI) rows = ntuple(n_outs) do i Base.@_inline_meta dx = zero(x)