Skip to content

Commit

Permalink
Limit Dual to Real primals, replace all Number with Real (#95)
Browse files Browse the repository at this point in the history
* Limit Dual to Real primals

* Remove more Complex-ity (pun intended)
  • Loading branch information
gdalle authored May 23, 2024
1 parent ea98aba commit c9040ba
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 48 deletions.
16 changes: 8 additions & 8 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

## Type conversions (non-dual)
for TT in (GradientTracer, ConnectivityTracer, HessianTracer)
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:TT,N<:Number} = T
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:TT,N<:Number} = T
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:TT,N<:Real} = T
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:TT,N<:Real} = T

Base.big(::Type{T}) where {T<:TT} = T
Base.widen(::Type{T}) where {T<:TT} = T

Base.convert(::Type{T}, x::Number) where {T<:TT} = empty(T)
Base.convert(::Type{T}, x::Real) where {T<:TT} = empty(T)
Base.convert(::Type{T}, t::T) where {T<:TT} = t
Base.convert(::Type{<:Number}, t::T) where {T<:TT} = t
Base.convert(::Type{<:Real}, t::T) where {T<:TT} = t

## Constants
Base.zero(::Type{T}) where {T<:TT} = empty(T)
Expand Down Expand Up @@ -42,21 +42,21 @@ function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P
PP = Base.promote_type(P1, P2) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Number}
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Real}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Number}
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Real}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end

Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T}
Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T}

Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, x::Real) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
Base.convert(::Type{N}, d::D) where {N<:Number,P,T,D<:Dual{P,T}} = Dual(convert(T, primal(d)), tracer(d))
Base.convert(::Type{N}, d::D) where {N<:Real,P,T,D<:Dual{P,T}} = Dual(convert(T, primal(d)), tracer(d))

function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
return Dual(convert(P1, primal(d)), tracer(d))
Expand Down
10 changes: 5 additions & 5 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ function overload_connectivity_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.ConnectivityTracer, ::Number)
function $M.$op(tx::$SCT.ConnectivityTracer, ::Real)
return $SCT.connectivity_tracer_1_to_1(
tx, $SCT.is_influence_arg1_zero_global($M.$op)
)
end
function $M.$op(
dx::D, y::Number
dx::D, y::Real
) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
Expand All @@ -87,13 +87,13 @@ function overload_connectivity_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(::Number, ty::$SCT.ConnectivityTracer)
function $M.$op(::Real, ty::$SCT.ConnectivityTracer)
return $SCT.connectivity_tracer_1_to_1(
ty, $SCT.is_influence_arg2_zero_global($M.$op)
)
end
function $M.$op(
x::Number, dy::D
x::Real, dy::D
) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
Expand Down Expand Up @@ -142,7 +142,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::ConnectivityTracer, ::S) = t
function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
return Dual(primal(dx)^y, tracer(dx))
Expand Down
8 changes: 4 additions & 4 deletions src/overload_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ end

for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=))
@eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy))
@eval Base.$fn(dx::D, y::Number) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Number, dy::D) where {D<:Dual} = $fn(x, primal(dy))
@eval Base.$fn(dx::D, y::Real) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Real, dy::D) where {D<:Dual} = $fn(x, primal(dy))

# Error on non-dual tracers
@eval function Base.$fn(tx::T, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(tx::T, y::Number) where {T<:AbstractTracer}
@eval function Base.$fn(tx::T, y::Real) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(x::Number, ty::T) where {T<:AbstractTracer}
@eval function Base.$fn(x::Real, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, ty))
end
end
10 changes: 5 additions & 5 deletions src/overload_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ function overload_gradient_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.GradientTracer, ::Number)
function $M.$op(tx::$SCT.GradientTracer, ::Real)
return $SCT.gradient_tracer_1_to_1(
tx, $SCT.is_firstder_arg1_zero_global($M.$op)
)
end
function $M.$op(dx::D, y::Number) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$op(dx::D, y::Real) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
t_out = $SCT.gradient_tracer_1_to_1(
Expand All @@ -83,12 +83,12 @@ function overload_gradient_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(::Number, ty::$SCT.GradientTracer)
function $M.$op(::Real, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(
ty, $SCT.is_firstder_arg2_zero_global($M.$op)
)
end
function $M.$op(x::Number, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$op(x::Real, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
t_out = $SCT.gradient_tracer_1_to_1(
Expand Down Expand Up @@ -136,7 +136,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::GradientTracer, ::S) = t
Base.:^(::S, t::GradientTracer) = t

Expand Down
10 changes: 5 additions & 5 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ function overload_hessian_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.HessianTracer, y::Number)
function $M.$op(tx::$SCT.HessianTracer, y::Real)
return $SCT.hessian_tracer_1_to_1(
tx,
$SCT.is_firstder_arg1_zero_global($M.$op),
$SCT.is_seconder_arg1_zero_global($M.$op),
)
end
function $M.$op(x::Number, ty::$SCT.HessianTracer)
function $M.$op(x::Real, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(
ty,
$SCT.is_firstder_arg2_zero_global($M.$op),
$SCT.is_seconder_arg2_zero_global($M.$op),
)
end

function $M.$op(dx::D, y::Number) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$op(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
t_out = $SCT.hessian_tracer_1_to_1(
Expand All @@ -129,7 +129,7 @@ function overload_hessian_2_to_1(M, op)
)
return $SCT.Dual(p_out, t_out)
end
function $M.$op(x::Number, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$op(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
t_out = $SCT.hessian_tracer_1_to_1(
Expand Down Expand Up @@ -187,7 +187,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
for S in (Integer, Rational, Irrational{:ℯ})
function Base.:^(tx::T, y::S) where {T<:HessianTracer}
return T(gradient(tx), hessian(tx) (gradient(tx) × gradient(tx)))
end
Expand Down
14 changes: 7 additions & 7 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTra
"""
trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1)

function trace_input(::Type{T}, x::Number, i::Integer) where {T<:AbstractTracer}
function trace_input(::Type{T}, x::Real, i::Integer) where {T<:AbstractTracer}
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:AbstractTracer}
Expand All @@ -40,11 +40,11 @@ function trace_function(::Type{T}, f!, y, x) where {T<:AbstractTracer}
return xt, yt
end

to_array(x::Number) = [x]
to_array(x::Real) = [x]
to_array(x::AbstractArray) = x

# Utilities
_tracer_or_number(x::Number) = x
_tracer_or_number(x::Real) = x
_tracer_or_number(d::Dual) = tracer(d)

#====================#
Expand Down Expand Up @@ -147,7 +147,7 @@ function local_connectivity_pattern(f!, y, x, ::Type{C}=DEFAULT_VECTOR_TYPE) whe
end

function connectivity_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
xt::AbstractArray{T}, yt::AbstractArray{<:Real}
) where {T<:ConnectivityTracer}
n, m = length(xt), length(yt)
I = Int[] # row indices
Expand All @@ -166,7 +166,7 @@ function connectivity_pattern_to_mat(
end

function connectivity_pattern_to_mat(
xt::AbstractArray{D}, yt::AbstractArray{<:Number}
xt::AbstractArray{D}, yt::AbstractArray{<:Real}
) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
return connectivity_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt))
end
Expand Down Expand Up @@ -258,7 +258,7 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {
end

function jacobian_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
xt::AbstractArray{T}, yt::AbstractArray{<:Real}
) where {T<:GradientTracer}
n, m = length(xt), length(yt)
I = Int[] # row indices
Expand All @@ -277,7 +277,7 @@ function jacobian_pattern_to_mat(
end

function jacobian_pattern_to_mat(
xt::AbstractArray{D}, yt::AbstractArray{<:Number}
xt::AbstractArray{D}, yt::AbstractArray{<:Real}
) where {P,T<:GradientTracer,D<:Dual{P,T}}
return jacobian_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt))
end
Expand Down
20 changes: 10 additions & 10 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`.
# When this happens, we create a new empty tracer with no input pattern.
function ConnectivityTracer{C}(::Number) where {C<:AbstractSet{<:Integer}}
function ConnectivityTracer{C}(::Real) where {C<:AbstractSet{<:Integer}}
return empty(ConnectivityTracer{C})
end

Expand Down Expand Up @@ -107,7 +107,7 @@ function empty(::Type{GradientTracer{G}}) where {G}
return GradientTracer{G}(empty(G))
end

function GradientTracer{G}(::Number) where {G<:AbstractSet{<:Integer}}
function GradientTracer{G}(::Real) where {G<:AbstractSet{<:Integer}}
return empty(GradientTracer{G})
end

Expand Down Expand Up @@ -172,7 +172,7 @@ function empty(::Type{HessianTracer{G,H}}) where {G,H}
end

function HessianTracer{G,H}(
::Number
::Real
) where {G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}}}
return empty(HessianTracer{G,H})
end
Expand All @@ -191,12 +191,12 @@ HessianTracer(t::HessianTracer) = t
"""
$(TYPEDEF)
Dual number type keeping track of the results of a primal computation as well as a tracer.
Dual `Real` number type keeping track of the results of a primal computation as well as a tracer.
## Fields
$(TYPEDFIELDS)
"""
struct Dual{P<:Number,T<:Union{ConnectivityTracer,GradientTracer,HessianTracer}} <:
struct Dual{P<:Real,T<:Union{ConnectivityTracer,GradientTracer,HessianTracer}} <:
AbstractTracer
primal::P
tracer::T
Expand All @@ -222,7 +222,7 @@ gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer)
gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(d.tracer)
hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer)

function Dual{P,T}(x::Number) where {P<:Number,T<:AbstractTracer}
function Dual{P,T}(x::Real) where {P<:Real,T<:AbstractTracer}
return Dual(convert(P, x), empty(T))
end

Expand All @@ -235,16 +235,16 @@ end
Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
"""
function create_tracer(::Type{Dual{P,T}}, primal::Number, index::Integer) where {P,T}
function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
end

function create_tracer(::Type{GradientTracer{G}}, ::Number, index::Integer) where {G}
function create_tracer(::Type{GradientTracer{G}}, ::Real, index::Integer) where {G}
return GradientTracer{G}(sparse_vector(G, index))
end
function create_tracer(::Type{ConnectivityTracer{C}}, ::Number, index::Integer) where {C}
function create_tracer(::Type{ConnectivityTracer{C}}, ::Real, index::Integer) where {C}
return ConnectivityTracer{C}(sparse_vector(C, index))
end
function create_tracer(::Type{HessianTracer{G,H}}, ::Number, index::Integer) where {G,H}
function create_tracer(::Type{HessianTracer{G,H}}, ::Real, index::Integer) where {G,H}
return HessianTracer{G,H}(sparse_vector(G, index), empty(H))
end
4 changes: 0 additions & 4 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ const SECOND_ORDER_SET_TYPES = (

# Code coverage
@test hessian_sparsity(typemax, 1, method) [0;;]
@test hessian_sparsity(x -> x^(2im), 1, method) [1;;]
@test hessian_sparsity(x -> (2im)^x, 1, method) [1;;]
@test hessian_sparsity(x -> x^(2//3), 1, method) [1;;]
@test hessian_sparsity(x -> (2//3)^x, 1, method) [1;;]
@test hessian_sparsity(x -> x^ℯ, 1, method) [1;;]
Expand Down Expand Up @@ -206,8 +204,6 @@ end

# Code coverage
@test hessian_sparsity(typemax, 1, method) [0;;]
@test hessian_sparsity(x -> x^(2im), 1, method) [1;;]
@test hessian_sparsity(x -> (2im)^x, 1, method) [1;;]
@test hessian_sparsity(x -> x^(2//3), 1, method) [1;;]
@test hessian_sparsity(x -> (2//3)^x, 1, method) [1;;]
@test hessian_sparsity(x -> x^ℯ, 1, method) [1;;]
Expand Down

0 comments on commit c9040ba

Please sign in to comment.