From 0da5429d749c3e5f8d56a282ab5a01a5f2f141b7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 6 Nov 2024 18:03:07 +0100 Subject: [PATCH] fix: correct interpretation for `compile` with `AutoReverseDiff` (#613) * fix: correct interpretation for `compile` with `AutoReverseDiff` * Fix static arrays * Wrong arg order --- DifferentiationInterface/Project.toml | 2 +- .../docs/src/explanation/backends.md | 8 +- .../DifferentiationInterfaceReverseDiffExt.jl | 2 +- .../onearg.jl | 279 +++++++++++++----- .../twoarg.jl | 96 ++++-- .../test/Back/ReverseDiff/test.jl | 2 +- 6 files changed, 281 insertions(+), 108 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ebac92fc..b2c9b89e 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.20" +version = "0.6.21" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 1d853e99..53ea81ff 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -144,14 +144,18 @@ Most operators fall back on `AutoForwardDiff`. ### ReverseDiff -Wherever possible, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution. -This tape is computed from the arguments `x` and `contexts...` provided at preparation time. +With `AutoReverseDiff(compile=false)`, preparation preallocates a [config](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractConfig-API). + +With `AutoReverseDiff(compile=true)`, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution. +This tape is computed from the input `x` provided at preparation time. It is control-flow dependent, so only one branch is recorded at each `if` statement. !!! danger If your function has value-specific control flow (like `if x[1] > 0` or `if c == 1`), you may get silently wrong results whenever it takes new branches that were not taken during preparation. You must make sure to run preparation with an input and contexts whose values trigger the correct control flow for future executions. +Whenever contexts are provided, tape recording is deactivated in all cases, because otherwise the context values would be hardcoded into a tape. + ### Symbolics For all operators, preparation generates an [executable function](https://docs.sciml.ai/Symbolics/stable/manual/build_function/) from the symbolic expression of the differentiated function. diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index a6f2a196..d57116bb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -20,6 +20,7 @@ using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult using LinearAlgebra: dot, mul! 
using ReverseDiff: + ReverseDiff, CompiledGradient, CompiledHessian, CompiledJacobian, @@ -29,7 +30,6 @@ using ReverseDiff: HessianTape, JacobianConfig, JacobianTape, - compile, gradient, gradient!, hessian, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index a316d1ca..2536abe8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -29,8 +29,8 @@ end function DI.value_and_pullback!( f, - ::NoPullbackPrep, tx::NTuple, + ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple, @@ -69,24 +69,31 @@ end ### Without contexts -struct ReverseDiffGradientPrep{T} <: GradientPrep +@kwdef struct ReverseDiffGradientPrep{C,T} <: GradientPrep + config::C tape::T end -function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile} - tape = GradientTape(f, x) - if Compile - tape = compile(tape) +function DI.prepare_gradient(f, ::AutoReverseDiff{compile}, x) where {compile} + if compile + tape = ReverseDiff.compile(GradientTape(f, x)) + return ReverseDiffGradientPrep(; config=nothing, tape=tape) + else + config = GradientConfig(x) + return ReverseDiffGradientPrep(; config=config, tape=nothing) end - return ReverseDiffGradientPrep(tape) end function DI.value_and_gradient!( - f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x -) + f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x +) where {compile} y = f(x) # TODO: ReverseDiff#251 result = DiffResult(y, (grad,)) - result = gradient!(result, prep.tape, x) + if compile + result = gradient!(result, prep.tape, x) + else + result = gradient!(result, f, x, prep.config) + end y = DR.value(result) grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) return y, grad @@ -99,220 +106,338 @@ function DI.value_and_gradient( return DI.value_and_gradient!(f, grad, prep, backend, x) end -function DI.gradient!(_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x) - return gradient!(grad, prep.tape, x) +function DI.gradient!( + f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return gradient!(grad, prep.tape, x) + else + return gradient!(grad, f, x, prep.config) + end end -function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x) - return gradient!(prep.tape, x) +function DI.gradient( + f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return gradient!(prep.tape, x) + else + return gradient(f, x, prep.config) + end end ### With contexts function DI.prepare_gradient(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C} - return NoGradientPrep() + config = GradientConfig(x) + return ReverseDiffGradientPrep(; config=config, tape=nothing) end function DI.value_and_gradient!( - f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + grad, + prep::ReverseDiffGradientPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) 
y = fc(x) # TODO: ReverseDiff#251 result = DiffResult(y, (grad,)) - result = gradient!(result, fc, x) + result = gradient!(result, fc, x, prep.config) y = DR.value(result) grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) return y, grad end function DI.value_and_gradient( - f, prep::NoGradientPrep, backend::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} grad = similar(x) return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) end function DI.gradient!( - f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + grad, + prep::ReverseDiffGradientPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) - return gradient!(grad, fc, x) + return gradient!(grad, fc, x, prep.config) end function DI.gradient( - f, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} ) where {C} fc = with_contexts(f, contexts...) - return gradient(fc, x) + return gradient(fc, x, prep.config) end ## Jacobian ### Without contexts -struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep +@kwdef struct ReverseDiffOneArgJacobianPrep{C,T} <: JacobianPrep + config::C tape::T end -function DI.prepare_jacobian(f, ::AutoReverseDiff{Compile}, x) where {Compile} - tape = JacobianTape(f, x) - if Compile - tape = compile(tape) +function DI.prepare_jacobian(f, ::AutoReverseDiff{compile}, x) where {compile} + if compile + tape = ReverseDiff.compile(JacobianTape(f, x)) + return ReverseDiffOneArgJacobianPrep(; config=nothing, tape=tape) + else + config = JacobianConfig(x) + return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) end - return ReverseDiffOneArgJacobianPrep(tape) end function DI.value_and_jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x -) + f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} y = f(x) result = DiffResult(y, (jac,)) - result = jacobian!(result, prep.tape, x) + if compile + result = jacobian!(result, prep.tape, x) + else + result = jacobian!(result, f, x, prep.config) + end y = DR.value(result) jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result)) return y, jac end -function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) - return f(x), jacobian!(prep.tape, x) +function DI.value_and_jacobian( + f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return f(x), jacobian!(prep.tape, x) + else + return f(x), jacobian(f, x, prep.config) + end end -function DI.jacobian!(_f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) - return jacobian!(jac, prep.tape, x) +function DI.jacobian!( + f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return jacobian!(jac, prep.tape, x) + else + return jacobian!(jac, f, x, prep.config) + end end -function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) - return jacobian!(prep.tape, x) +function DI.jacobian( + f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return jacobian!(prep.tape, x) + else + return jacobian(f, x, prep.config) + end end ### With contexts function DI.prepare_jacobian(f, ::AutoReverseDiff, x, 
contexts::Vararg{Context,C}) where {C} - return NoJacobianPrep() + config = JacobianConfig(x) + return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) end function DI.value_and_jacobian!( - f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + jac, + prep::ReverseDiffOneArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) y = fc(x) result = DiffResult(y, (jac,)) - result = jacobian!(result, fc, x) + result = jacobian!(result, fc, x, prep.config) y = DR.value(result) jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result)) return y, jac end function DI.value_and_jacobian( - f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + prep::ReverseDiffOneArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) - return fc(x), jacobian(fc, x) + return fc(x), jacobian(fc, x, prep.config) end function DI.jacobian!( - f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + jac, + prep::ReverseDiffOneArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) - return jacobian!(jac, fc, x) + return jacobian!(jac, fc, x, prep.config) end function DI.jacobian( - f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + prep::ReverseDiffOneArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) - return jacobian(fc, x) + return jacobian(fc, x, prep.config) end ## Hessian ### Without contexts -struct ReverseDiffHessianGradientPrep{GT,HT} <: HessianPrep +@kwdef struct ReverseDiffHessianPrep{GC,HC,GT,HT} <: HessianPrep + gradient_config::GC + hessian_config::HC gradient_tape::GT hessian_tape::HT end -function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile} - gradient_tape = GradientTape(f, x) - hessian_tape = HessianTape(f, x) - if Compile - gradient_tape = compile(gradient_tape) - hessian_tape = compile(hessian_tape) +function DI.prepare_hessian(f, ::AutoReverseDiff{compile}, x) where {compile} + if compile + gradient_tape = ReverseDiff.compile(GradientTape(f, x)) + hessian_tape = ReverseDiff.compile(HessianTape(f, x)) + return ReverseDiffHessianPrep(; + gradient_config=nothing, + hessian_config=nothing, + gradient_tape=gradient_tape, + hessian_tape=hessian_tape, + ) + else + gradient_config = GradientConfig(x) + hessian_config = HessianConfig(x) + return ReverseDiffHessianPrep(; + gradient_config=gradient_config, + hessian_config=hessian_config, + gradient_tape=nothing, + hessian_tape=nothing, + ) end - return ReverseDiffHessianGradientPrep(gradient_tape, hessian_tape) end -function DI.hessian!(_f, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x) - return hessian!(hess, prep.hessian_tape, x) +function DI.hessian!( + f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return hessian!(hess, prep.hessian_tape, x) + else + return hessian!(hess, f, x, prep.hessian_config) + end end -function DI.hessian(_f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x) - return hessian!(prep.hessian_tape, x) +function DI.hessian( + f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + return hessian!(prep.hessian_tape, x) + else + return hessian(f, x, prep.hessian_config) + end end function 
DI.value_gradient_and_hessian!( - f, grad, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x -) + f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x +) where {compile} y = f(x) # TODO: ReverseDiff#251 result = DiffResult(y, (grad, hess)) - result = hessian!(result, prep.hessian_tape, x) - y = DR.value(result) + if compile + result = hessian!(result, prep.hessian_tape, x) + grad = gradient!(grad, prep.gradient_tape, x) # TODO: ReverseDiff#251 + else + result = hessian!(result, f, x) # TODO: add prep.hessian_config + grad = gradient!(grad, f, x, prep.gradient_config) # TODO: ReverseDiff#251 + end # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) - grad = gradient!(grad, prep.gradient_tape, x) # TODO: ReverseDiff#251 hess === DR.hessian(result) || copyto!(hess, DR.hessian(result)) return y, grad, hess end function DI.value_gradient_and_hessian( - f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x -) + f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x +) where {compile} y = f(x) # TODO: remove once ReverseDiff#251 is fixed result = DiffResult(y, (similar(x), similar(x, length(x), length(x)))) - result = hessian!(result, prep.hessian_tape, x) - return (DR.value(result), DR.gradient(result), DR.hessian(result)) + if compile + result = hessian!(result, prep.hessian_tape, x) + else + result = hessian!(result, f, x) # todo: add prep.hessian_config + end + return (y, DR.gradient(result), DR.hessian(result)) end ### With contexts function DI.prepare_hessian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C} - return NoHessianPrep() + gradient_config = GradientConfig(x) + hessian_config = HessianConfig(x) + return ReverseDiffHessianPrep(; + gradient_config=gradient_config, + hessian_config=hessian_config, + gradient_tape=nothing, + hessian_tape=nothing, + ) end function DI.hessian!( - f, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} ) where {C} fc = with_contexts(f, contexts...) - return hessian!(hess, fc, x) + return hessian!(hess, fc, x, prep.hessian_config) end function DI.hessian( - f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} ) where {C} fc = with_contexts(f, contexts...) - return hessian(fc, x) + return hessian(fc, x, prep.hessian_config) end function DI.value_gradient_and_hessian!( - f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, + grad, + hess, + prep::ReverseDiffHessianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc = with_contexts(f, contexts...) y = fc(x) # TODO: ReverseDiff#251 result = DiffResult(y, (grad, hess)) - result = hessian!(result, fc, x) + result = hessian!(result, fc, x) # TODO: add prep.hessian_config y = DR.value(result) # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) - grad = gradient!(grad, fc, x) # TODO: ReverseDiff#251 + grad = gradient!(grad, fc, x, prep.gradient_config) # TODO: ReverseDiff#251 hess === DR.hessian(result) || copyto!(hess, DR.hessian(result)) return y, grad, hess end function DI.value_gradient_and_hessian( - f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} ) where {C} fc = with_contexts(f, contexts...) 
y = fc(x) # TODO: ReverseDiff#251 result = HessianResult(x) - result = hessian!(result, fc, x) + result = hessian!(result, fc, x) # TODO: add prep.hessian_config return (DR.value(result), DR.gradient(result), DR.hessian(result)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 4b0673be..4d4c0aa0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -121,44 +121,65 @@ end ### Without contexts -struct ReverseDiffTwoArgJacobianPrep{T} <: JacobianPrep +@kwdef struct ReverseDiffTwoArgJacobianPrep{C,T} <: JacobianPrep + config::C tape::T end -function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{Compile}, x) where {Compile} - tape = JacobianTape(f!, y, x) - if Compile - tape = compile(tape) +function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{compile}, x) where {compile} + if compile + tape = ReverseDiff.compile(JacobianTape(f!, y, x)) + return ReverseDiffTwoArgJacobianPrep(; config=nothing, tape=tape) + else + config = JacobianConfig(y, x) + return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) end - return ReverseDiffTwoArgJacobianPrep(tape) end function DI.value_and_jacobian( - _f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff, x -) + f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) - result = jacobian!(result, prep.tape, x) + if compile + result = jacobian!(result, prep.tape, x) + else + result = jacobian!(result, f!, y, x, prep.config) + end return DiffResults.value(result), DiffResults.derivative(result) end function DI.value_and_jacobian!( - _f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff, x -) + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} result = MutableDiffResult(y, (jac,)) - result = jacobian!(result, prep.tape, x) + if compile + result = jacobian!(result, prep.tape, x) + else + result = jacobian!(result, f!, y, x, prep.config) + end return DiffResults.value(result), DiffResults.derivative(result) end -function DI.jacobian(_f!, _y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff, x) - jac = jacobian!(prep.tape, x) +function DI.jacobian( + f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + jac = jacobian!(prep.tape, x) + else + jac = jacobian(f!, y, x, prep.config) + end return jac end function DI.jacobian!( - _f!, _y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff, x -) - jac = jacobian!(jac, prep.tape, x) + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x +) where {compile} + if compile + jac = jacobian!(jac, prep.tape, x) + else + jac = jacobian!(jac, f!, y, x, prep.config) + end return jac end @@ -167,40 +188,63 @@ end function DI.prepare_jacobian( f!, y, ::AutoReverseDiff, x, contexts::Vararg{Context,C} ) where {C} - return NoJacobianPrep() + config = JacobianConfig(y, x) + return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) end function DI.value_and_jacobian( - f!, y, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f!, + y, + prep::ReverseDiffTwoArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc! = with_contexts(f!, contexts...) 
jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) - result = jacobian!(result, fc!, y, x) + result = jacobian!(result, fc!, y, x, prep.config) return DiffResults.value(result), DiffResults.derivative(result) end function DI.value_and_jacobian!( - f!, y, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f!, + y, + jac, + prep::ReverseDiffTwoArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc! = with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) - result = jacobian!(result, fc!, y, x) + result = jacobian!(result, fc!, y, x, prep.config) return DiffResults.value(result), DiffResults.derivative(result) end function DI.jacobian( - f!, y, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f!, + y, + prep::ReverseDiffTwoArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc! = with_contexts(f!, contexts...) - jac = jacobian(fc!, y, x) + jac = jacobian(fc!, y, x, prep.config) return jac end function DI.jacobian!( - f!, y, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C} + f!, + y, + jac, + prep::ReverseDiffTwoArgJacobianPrep, + ::AutoReverseDiff, + x, + contexts::Vararg{Context,C}, ) where {C} fc! = with_contexts(f!, contexts...) - jac = jacobian!(jac, fc!, y, x) + jac = jacobian!(jac, fc!, y, x, prep.config) return jac end diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index c6140cda..dc06a49a 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -25,7 +25,7 @@ test_differentiation( logging=LOGGING, ); -test_differentiation(backends[1], static_scenarios(); logging=LOGGING); +test_differentiation(backends, static_scenarios(); logging=LOGGING); ## Sparse
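
A minimal usage sketch of the corrected `compile` semantics introduced by this patch (assuming DifferentiationInterface >= 0.6.21 with ReverseDiff installed; the function `f` and the input values are hypothetical, chosen only to expose the control-flow caveat documented in backends.md above):

using DifferentiationInterface
using ADTypes: AutoReverseDiff
import ReverseDiff  # loads the extension modified in this patch

f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)  # value-dependent control flow

# compile=true: preparation records and compiles a tape,
# freezing the branch taken for the preparation input [1.0, 2.0]
backend = AutoReverseDiff(; compile=true)
prep = prepare_gradient(f, backend, [1.0, 2.0])
gradient(f, prep, backend, [3.0, 4.0])    # correct: same branch as during preparation
gradient(f, prep, backend, [-3.0, 4.0])   # silently wrong: the `else` branch was never recorded

# compile=false: preparation only preallocates a GradientConfig,
# so the function is re-recorded at each call and no branch is frozen
backend_nc = AutoReverseDiff(; compile=false)
prep_nc = prepare_gradient(f, backend_nc, [1.0, 2.0])
gradient(f, prep_nc, backend_nc, [-3.0, 4.0])  # correct, at the cost of re-recording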