Skip to content

Commit

Permalink
fix: non-scalar tangents
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 5, 2024
1 parent e73a150 commit ed2391e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DataInterpolations"
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "6.4.3"
version = "6.4.4"

[deps]
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
Expand Down
23 changes: 17 additions & 6 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,13 @@ function u_tangent(A::LinearInterpolation, t, Δ)
out = zero(A.u)
idx = get_idx(A, t, A.iguesser)
t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])
out[idx] = Δ * (one(eltype(out)) - t_factor)
out[idx + 1] = Δ * t_factor
if eltype(out) <: Number
out[idx] = Δ * (one(eltype(out)) - t_factor)
out[idx + 1] = Δ * t_factor
else
@. out[idx] = Δ * (true - t_factor)
@. out[idx + 1] = Δ * t_factor
end
out
end

Expand All @@ -78,9 +83,15 @@ function u_tangent(A::QuadraticInterpolation, t, Δ)
Δt₀ = t₁ - t₀
Δt₁ = t₂ - t₁
Δt₂ = t₂ - t₀
out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂)
out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁)
out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁)
if eltype(out) <: Number
out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂)
out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁)
out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁)
else
@. out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂)
@. out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁)
@. out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁)
end
out
end

Expand All @@ -100,7 +111,7 @@ function ChainRulesCore.rrule(::typeof(_interpolate),
t::Number)
deriv = derivative(A, t)
function interpolate_pullback(Δ)
(NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), deriv * Δ)
(NoTangent(), Tangent{typeof(A)}(; u = u_tangent(A, t, Δ)), sum(deriv .* Δ))
end
return _interpolate(A, t), interpolate_pullback
end
Expand Down

0 comments on commit ed2391e

Please sign in to comment.