From 2b0b1ce000abefa282439864bcb23583091f0f1c Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Fri, 4 Oct 2024 15:18:55 -0400 Subject: [PATCH] Add broadcasting --- src/array.jl | 51 +++++++++++++++++++++++++++++++++++++++------- src/chainrules.jl | 24 ++++++++++++++++++++-- src/derivative.jl | 15 +++++--------- test/derivative.jl | 2 +- test/downstream.jl | 4 ++-- test/runtests.jl | 4 ++-- 6 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/array.jl b/src/array.jl index af1d225..1ace577 100644 --- a/src/array.jl +++ b/src/array.jl @@ -8,7 +8,7 @@ Representation of Taylor polynomials in array mode. # Fields - `value::A`: zeroth order coefficient -- `partials::NTuple{P, T}`: i-th element of this stores the i-th derivative +- `partials::NTuple{P, A}`: i-th element of this stores the i-th derivative """ struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <: AbstractArray{TaylorScalar{T, P}, N} @@ -24,15 +24,28 @@ struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <: end function TaylorArray{P}(value::A) where {A <: AbstractArray, P} - TaylorArray(value, ntuple(i -> zeros(eltype(value), size(value)), Val(P))) + TaylorArray(value, ntuple(i -> broadcast(zero, value), Val(P))) end function TaylorArray{P}(value::A, seed::A) where {A <: AbstractArray, P} TaylorArray( - value, ntuple(i -> i == 1 ? seed : zeros(eltype(value), size(value)), Val(P))) + value, ntuple(i -> i == 1 ? seed : broadcast(zero, seed), Val(P))) end -# Indexing +# Necessary AbstractArray interface methods for TaylorArray to work +# https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array + +## 1. Invariant +for op in Symbol[:size, :strides, :eachindex, :IndexStyle] + @eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x)) +end + +## 2. 
Indexing +function Base.similar(a::TaylorArray, ::Type{<:TaylorScalar{T}}, dims::Dims) where {T} + new_value = similar(value(a), T, dims) + new_partials = map(p -> similar(p, T, dims), partials(a)) + return TaylorArray(new_value, new_partials) +end Base.@propagate_inbounds function Base.getindex(a::TaylorArray, i::Int...) new_value = value(a)[i...] @@ -55,7 +68,31 @@ Base.@propagate_inbounds function Base.setindex!( return a end -# Invariant -for op in Symbol[:size, :eachindex, :IndexStyle] - @eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x)) +## 3. Broadcasting +struct TaylorArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end +TaylorArrayStyle(::Val{N}) where {N} = TaylorArrayStyle{N}() +TaylorArrayStyle{M}(::Val{N}) where {N, M} = TaylorArrayStyle{N}() + +Base.BroadcastStyle(::Type{<:TaylorArray{T, N}}) where {T, N} = TaylorArrayStyle{N}() +# This is added to make Zygote custom broadcasting work +# However, we might implement custom broadcasting semantics for TaylorArray in the future +# function Base.BroadcastStyle(::Type{<:Array{ +# <:Tuple{TaylorScalar{T, P}, Any}, N}}) where {T, N, P} +# TaylorArrayStyle{N}() +# end + +function Base.similar( + bc::Broadcast.Broadcasted{<:TaylorArrayStyle}, ::Type{ElType}) where {ElType} + A = find_taylor(bc) + similar(A, ElType, axes(bc)) +end + +find_taylor(bc::Broadcast.Broadcasted) = find_taylor(bc.args) +find_taylor(args::Tuple) = find_taylor(find_taylor(args[1]), Base.tail(args)) +find_taylor(x) = x +find_taylor(::Tuple{}) = nothing +find_taylor(a::TaylorArray, rest) = a +function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where {T, P, N} + TaylorArray{P}(zeros(T, size(a))) end +find_taylor(::Any, rest) = find_taylor(rest) diff --git a/src/chainrules.jl b/src/chainrules.jl index ac66842..2cb36e2 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -2,8 +2,13 @@ import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out using Base.Broadcast: broadcasted function 
rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T} - taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄) - return TaylorScalar(v, p), taylor_scalar_pullback + constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄) + return TaylorScalar(v, p), constructor_pullback +end + +function rrule(::Type{TaylorArray}, v::T, p::NTuple{N, T}) where {N, T} + constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄) + return TaylorArray(v, p), constructor_pullback end function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} @@ -23,6 +28,11 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} return partials(t), value_pullback end +function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P} + value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄) + return partials(t), value_pullback +end + function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, i::Integer) where {N, T} function extract_derivative_pullback(d̄) @@ -32,6 +42,16 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, return extract_derivative(t, i), extract_derivative_pullback end +function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...) + function getindex_pullback(t̄) + ā = similar(a) + ā .= zero(eltype(a)) + ā[i...] 
= t̄ + NoTangent(), ā, map(Returns(NoTangent()), i)... # one tangent per index + end + return getindex(a, i...), getindex_pullback +end + function rrule(::typeof(*), A::AbstractMatrix{S}, t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real} project_A = ProjectTo(A) diff --git a/src/derivative.jl b/src/derivative.jl index dcee92d..3d85119 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -27,7 +27,7 @@ function derivatives end # Convenience wrapper for adding unit seed to the input -@inline derivative(f, x, p::Int64) = derivative(f, x, one(eltype(x)), p) +@inline derivative(f, x, p::Int64) = derivative(f, x, broadcast(one, x), p) # Convenience wrappers for converting ps to value types # and forward work to core APIs @@ -42,13 +42,8 @@ function derivatives end # Core APIs # Added to help Zygote infer types -@inline function make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} - TaylorScalar{P}(x, convert(T, l)) -end - -@inline function make_seed(x::AbstractArray{T}, l, p::Val{P}) where {T <: Real, P} - broadcast(make_seed, x, l, p) -end +@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l) +@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(make_seed, x, l, Val{P}()) # `derivative` API: computes the `P - 1`-th derivative of `f` at `x` @inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative( @@ -66,8 +61,8 @@ end @inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p)) # In-place function -@inline function derivatives(f!, y::AbstractArray{T}, x, l, p::Val{P}) where {T, P} - buffer = similar(y, TaylorScalar{T, P}) +@inline function derivatives(f!, y, x, l, p::Val{P}) where {P} + buffer = similar(y, TaylorScalar{eltype(y), P}) f!(buffer, make_seed(x, l, p)) map!(value, y, buffer) return buffer diff --git a/test/derivative.jl b/test/derivative.jl index 3c00ff6..e5f1705 100644 --- a/test/derivative.jl +++ b/test/derivative.jl @@ -10,7 +10,7 @@ @test derivative(g1, [1.0, 2.0], 
[1.0, 0.0], 1) ≈ 2.0 h1(x) = sum(x, dims = 1) - @test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2.0 2.0] + @test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], 1) ≈ [2.0 2.0] end @testset "I-function, O-derivative" begin diff --git a/test/downstream.jl b/test/downstream.jl index 17af58d..48184bd 100644 --- a/test/downstream.jl +++ b/test/downstream.jl @@ -27,8 +27,8 @@ backend = AutoZygote() # Matrix functions some_matrix = [0.7 0.1; 0.4 0.2] f(x) = sum(exp.(x), dims = 1) - dfdx1(x) = derivative(f, x, [1.0, 0.0], 1) - dfdx2(x) = derivative(f, x, [0.0, 1.0], 1) + dfdx1(x) = derivative(f, x, [1.0 1.0; 0.0 0.0], 1) + dfdx2(x) = derivative(f, x, [0.0 0.0; 1.0 1.0], 1) res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x)) grad = DI.gradient(res, backend, some_matrix) @test grad ≈ [1 0; 0 2] * exp.(some_matrix) diff --git a/test/runtests.jl b/test/runtests.jl index 83609b4..2088824 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using TaylorDiff using Test -# include("primitive.jl") -# include("derivative.jl") +include("primitive.jl") +include("derivative.jl") include("downstream.jl") # include("lux.jl")