Skip to content

Commit

Permalink
Add broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 11, 2024
1 parent 2555949 commit 2b0b1ce
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 24 deletions.
51 changes: 44 additions & 7 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Representation of Taylor polynomials in array mode.
# Fields
- `value::A`: zeroth order coefficient
- `partials::NTuple{P, T}`: i-th element of this stores the i-th derivative
- `partials::NTuple{P, A}`: i-th element of this stores the i-th derivative
"""
struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
AbstractArray{TaylorScalar{T, P}, N}
Expand All @@ -24,15 +24,28 @@ struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
end

function TaylorArray{P}(value::A) where {A <: AbstractArray, P}
TaylorArray(value, ntuple(i -> zeros(eltype(value), size(value)), Val(P)))
TaylorArray(value, ntuple(i -> broadcast(zero, value), Val(P)))
end

function TaylorArray{P}(value::A, seed::A) where {A <: AbstractArray, P}
TaylorArray(
value, ntuple(i -> i == 1 ? seed : zeros(eltype(value), size(value)), Val(P)))
value, ntuple(i -> i == 1 ? seed : broadcast(zero, seed), Val(P)))
end

# Indexing
# Necessary AbstractArray interface methods for TaylorArray to work
# https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array

## 1. Invariant
for op in Symbol[:size, :strides, :eachindex, :IndexStyle]
@eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x))
end

## 2. Indexing
function Base.similar(a::TaylorArray, ::Type{<:TaylorScalar{T}}, dims::Dims) where {T}
new_value = similar(value(a), T, dims)
new_partials = map(p -> similar(p, T, dims), partials(a))
return TaylorArray(new_value, new_partials)
end

Base.@propagate_inbounds function Base.getindex(a::TaylorArray, i::Int...)
new_value = value(a)[i...]
Expand All @@ -55,7 +68,31 @@ Base.@propagate_inbounds function Base.setindex!(
return a
end

# Invariant
for op in Symbol[:size, :eachindex, :IndexStyle]
@eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x))
## 3. Broadcasting
struct TaylorArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
TaylorArrayStyle(::Val{N}) where {N} = TaylorArrayStyle{N}()
TaylorArrayStyle{M}(::Val{N}) where {N, M} = TaylorArrayStyle{N}()

Base.BroadcastStyle(::Type{<:TaylorArray{T, N}}) where {T, N} = TaylorArrayStyle{N}()
# This is added to make Zygote custom broadcasting work
# However, we might implement custom broadcasting semantics for TaylorArray in the future
# function Base.BroadcastStyle(::Type{<:Array{
# <:Tuple{TaylorScalar{T, P}, Any}, N}}) where {T, N, P}
# TaylorArrayStyle{N}()
# end

function Base.similar(
bc::Broadcast.Broadcasted{<:TaylorArrayStyle}, ::Type{ElType}) where {ElType}
A = find_taylor(bc)
similar(A, ElType, axes(bc))
end

find_taylor(bc::Broadcast.Broadcasted) = find_taylor(bc.args)
find_taylor(args::Tuple) = find_taylor(find_taylor(args[1]), Base.tail(args))
find_taylor(x) = x
find_taylor(::Tuple{}) = nothing
find_taylor(a::TaylorArray, rest) = a
function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where {T, P, N}
TaylorArray{P}(zeros(T, size(a)))
end
find_taylor(::Any, rest) = find_taylor(rest)
24 changes: 22 additions & 2 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
using Base.Broadcast: broadcasted

function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T}
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
return TaylorScalar(v, p), taylor_scalar_pullback
constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
return TaylorScalar(v, p), constructor_pullback
end

function rrule(::Type{TaylorArray}, v::T, p::NTuple{N, T}) where {N, T}
constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
return TaylorArray(v, p), constructor_pullback
end

function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
Expand All @@ -23,6 +28,11 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
return partials(t), value_pullback
end

function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P}
value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
return partials(t), value_pullback
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T}
function extract_derivative_pullback(d̄)
Expand All @@ -32,6 +42,16 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
return extract_derivative(t, i), extract_derivative_pullback
end

function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)
function getindex_pullback(t̄)
= similar(a)
ā .= zero(eltype(a))
ā[i...] =
NoTangent(), ā, map(Returns(NoTangent()), i)
end
return getindex(a, i...), getindex_pullback
end

function rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
project_A = ProjectTo(A)
Expand Down
15 changes: 5 additions & 10 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function derivatives end

# Convenience wrapper for adding unit seed to the input

@inline derivative(f, x, p::Int64) = derivative(f, x, one(eltype(x)), p)
@inline derivative(f, x, p::Int64) = derivative(f, x, broadcast(one, x), p)

# Convenience wrappers for converting ps to value types
# and forward work to core APIs
Expand All @@ -42,13 +42,8 @@ function derivatives end
# Core APIs

# Added to help Zygote infer types
@inline function make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P}
TaylorScalar{P}(x, convert(T, l))
end

@inline function make_seed(x::AbstractArray{T}, l, p::Val{P}) where {T <: Real, P}
broadcast(make_seed, x, l, p)
end
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(make_seed, x, l, Val{P}())

# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
Expand All @@ -66,8 +61,8 @@ end
@inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p))

# In-place function
@inline function derivatives(f!, y::AbstractArray{T}, x, l, p::Val{P}) where {T, P}
buffer = similar(y, TaylorScalar{T, P})
@inline function derivatives(f!, y, x, l, p::Val{P}) where {P}
buffer = similar(y, TaylorScalar{eltype(y), P})
f!(buffer, make_seed(x, l, p))
map!(value, y, buffer)
return buffer
Expand Down
2 changes: 1 addition & 1 deletion test/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@test derivative(g1, [1.0, 2.0], [1.0, 0.0], 1) 2.0

h1(x) = sum(x, dims = 1)
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2.0 2.0]
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], 1) [2.0 2.0]
end

@testset "I-function, O-derivative" begin
Expand Down
4 changes: 2 additions & 2 deletions test/downstream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ backend = AutoZygote()
# Matrix functions
some_matrix = [0.7 0.1; 0.4 0.2]
f(x) = sum(exp.(x), dims = 1)
dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
dfdx1(x) = derivative(f, x, [1.0 1.0; 0.0 0.0], 1)
dfdx2(x) = derivative(f, x, [0.0 0.0; 1.0 1.0], 1)
res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
grad = DI.gradient(res, backend, some_matrix)
@test grad [1 0; 0 2] * exp.(some_matrix)
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using TaylorDiff
using Test

# include("primitive.jl")
# include("derivative.jl")
include("primitive.jl")
include("derivative.jl")
include("downstream.jl")
# include("lux.jl")

0 comments on commit 2b0b1ce

Please sign in to comment.