From 680ed234f190bc69d95b0c6319794954d32ba4a4 Mon Sep 17 00:00:00 2001 From: mhar100 <56600917+mhar100@users.noreply.github.com> Date: Tue, 14 Jun 2022 14:09:31 -0400 Subject: [PATCH] added type for differentiate in chain_rules.jl 1. We should make sure that differentiate is being called on a polynomial p. (It is possible there are other meanings of differentiate for other types in the code.) 2. The function "pullback" is very generic, so I added some clarity. --- src/chain_rules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chain_rules.jl b/src/chain_rules.jl index fdeb75d8..2e3b9919 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -18,10 +18,10 @@ end function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x) return differentiate(p, x), differentiate(Δp, x) end -function pullback(Δdpdx, x) +function pullback_differentiate_polynomial(Δdpdx, x) return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent() end -function ChainRulesCore.rrule(::typeof(differentiate), p, x) +function ChainRulesCore.rrule(::typeof(differentiate), p::APL, x) dpdx = differentiate(p, x) - return dpdx, Base.Fix2(pullback, x) + return dpdx, Base.Fix2(pullback_differentiate_polynomial, x) end