From ed2391eb5d44edc220465a1636e05e62a186c76b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 09:10:23 -0400 Subject: [PATCH] fix: non-scalar tangents --- Project.toml | 2 +- ext/DataInterpolationsChainRulesCoreExt.jl | 23 ++++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index deb03b09..c53e2c5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DataInterpolations" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "6.4.3" +version = "6.4.4" [deps] FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 9e9659d1..a811784d 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -64,8 +64,13 @@ function u_tangent(A::LinearInterpolation, t, Δ) out = zero(A.u) idx = get_idx(A, t, A.iguesser) t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx]) - out[idx] = Δ * (one(eltype(out)) - t_factor) - out[idx + 1] = Δ * t_factor + if eltype(out) <: Number + out[idx] = Δ * (one(eltype(out)) - t_factor) + out[idx + 1] = Δ * t_factor + else + @. out[idx] = Δ * (true - t_factor) + @. out[idx + 1] = Δ * t_factor + end out end @@ -78,9 +83,15 @@ function u_tangent(A::QuadraticInterpolation, t, Δ) Δt₀ = t₁ - t₀ Δt₁ = t₂ - t₁ Δt₂ = t₂ - t₀ - out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) - out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) - out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + if eltype(out) <: Number + out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + else + @. out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂) + @. out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁) + @. out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁) + end out end @@ -100,7 +111,7 @@ function ChainRulesCore.rrule(::typeof(_interpolate), t::Number) deriv = derivative(A, t) function interpolate_pullback(Δ) - (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ) + (NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), sum(deriv .* Δ)) end return _interpolate(A, t), interpolate_pullback end