Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit Dual to Real primals, replace all Number with Real #95

Merged
merged 7 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

## Type conversions (non-dual)
for TT in (GradientTracer, ConnectivityTracer, HessianTracer)
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:TT,N<:Number} = T
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:TT,N<:Number} = T
Base.promote_rule(::Type{T}, ::Type{N}) where {T<:TT,N<:Real} = T
Base.promote_rule(::Type{N}, ::Type{T}) where {T<:TT,N<:Real} = T

Base.big(::Type{T}) where {T<:TT} = T
Base.widen(::Type{T}) where {T<:TT} = T

Base.convert(::Type{T}, x::Number) where {T<:TT} = empty(T)
Base.convert(::Type{T}, x::Real) where {T<:TT} = empty(T)
Base.convert(::Type{T}, t::T) where {T<:TT} = t
Base.convert(::Type{<:Number}, t::T) where {T<:TT} = t
Base.convert(::Type{<:Real}, t::T) where {T<:TT} = t
gdalle marked this conversation as resolved.
Show resolved Hide resolved

## Constants
Base.zero(::Type{T}) where {T<:TT} = empty(T)
Expand Down Expand Up @@ -42,21 +42,21 @@ function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P
PP = Base.promote_type(P1, P2) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Number}
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Real}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Number}
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Real}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end

Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T}
Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T}

Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, x::Real) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
Base.convert(::Type{N}, d::D) where {N<:Number,P,T,D<:Dual{P,T}} = Dual(convert(T, primal(d)), tracer(d))
# Converting a `Dual` to a plain `Real` type `N` keeps the result a `Dual`:
# the primal is converted to the requested type `N`, the tracer is carried over unchanged.
# The primal must be converted to `N`, not to the tracer type `T`: converting a `Real`
# to a tracer type yields `empty(T)`, which would discard the primal value.
function Base.convert(::Type{N}, d::D) where {N<:Real,P,T,D<:Dual{P,T}}
    return Dual(convert(N, primal(d)), tracer(d))
end

function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
return Dual(convert(P1, primal(d)), tracer(d))
Expand Down
10 changes: 5 additions & 5 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ function overload_connectivity_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.ConnectivityTracer, ::Number)
function $M.$op(tx::$SCT.ConnectivityTracer, ::Real)
return $SCT.connectivity_tracer_1_to_1(
tx, $SCT.is_influence_arg1_zero_global($M.$op)
)
end
function $M.$op(
dx::D, y::Number
dx::D, y::Real
) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
Expand All @@ -87,13 +87,13 @@ function overload_connectivity_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(::Number, ty::$SCT.ConnectivityTracer)
function $M.$op(::Real, ty::$SCT.ConnectivityTracer)
return $SCT.connectivity_tracer_1_to_1(
ty, $SCT.is_influence_arg2_zero_global($M.$op)
)
end
function $M.$op(
x::Number, dy::D
x::Real, dy::D
) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
Expand Down Expand Up @@ -142,7 +142,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::ConnectivityTracer, ::S) = t
function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
return Dual(primal(dx)^y, tracer(dx))
Expand Down
8 changes: 4 additions & 4 deletions src/overload_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ end

for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=))
@eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy))
@eval Base.$fn(dx::D, y::Number) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Number, dy::D) where {D<:Dual} = $fn(x, primal(dy))
@eval Base.$fn(dx::D, y::Real) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Real, dy::D) where {D<:Dual} = $fn(x, primal(dy))

# Error on non-dual tracers
@eval function Base.$fn(tx::T, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(tx::T, y::Number) where {T<:AbstractTracer}
@eval function Base.$fn(tx::T, y::Real) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(x::Number, ty::T) where {T<:AbstractTracer}
@eval function Base.$fn(x::Real, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, ty))
end
end
10 changes: 5 additions & 5 deletions src/overload_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ function overload_gradient_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.GradientTracer, ::Number)
function $M.$op(tx::$SCT.GradientTracer, ::Real)
return $SCT.gradient_tracer_1_to_1(
tx, $SCT.is_firstder_arg1_zero_global($M.$op)
)
end
function $M.$op(dx::D, y::Number) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$op(dx::D, y::Real) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
t_out = $SCT.gradient_tracer_1_to_1(
Expand All @@ -83,12 +83,12 @@ function overload_gradient_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(::Number, ty::$SCT.GradientTracer)
function $M.$op(::Real, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(
ty, $SCT.is_firstder_arg2_zero_global($M.$op)
)
end
function $M.$op(x::Number, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$op(x::Real, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
t_out = $SCT.gradient_tracer_1_to_1(
Expand Down Expand Up @@ -136,7 +136,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
gdalle marked this conversation as resolved.
Show resolved Hide resolved
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::GradientTracer, ::S) = t
Base.:^(::S, t::GradientTracer) = t

Expand Down
10 changes: 5 additions & 5 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ function overload_hessian_2_to_1(M, op)
return $SCT.Dual(p_out, t_out)
end

function $M.$op(tx::$SCT.HessianTracer, y::Number)
function $M.$op(tx::$SCT.HessianTracer, y::Real)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
return $SCT.hessian_tracer_1_to_1(
tx,
$SCT.is_firstder_arg1_zero_global($M.$op),
$SCT.is_seconder_arg1_zero_global($M.$op),
)
end
function $M.$op(x::Number, ty::$SCT.HessianTracer)
function $M.$op(x::Real, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(
ty,
$SCT.is_firstder_arg2_zero_global($M.$op),
$SCT.is_seconder_arg2_zero_global($M.$op),
)
end

function $M.$op(dx::D, y::Number) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$op(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$op(x, y)
t_out = $SCT.hessian_tracer_1_to_1(
Expand All @@ -129,7 +129,7 @@ function overload_hessian_2_to_1(M, op)
)
return $SCT.Dual(p_out, t_out)
end
function $M.$op(x::Number, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$op(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$op(x, y)
t_out = $SCT.hessian_tracer_1_to_1(
Expand Down Expand Up @@ -187,7 +187,7 @@ end
## Special cases

## Exponent (requires extra types)
for S in (Real, Integer, Rational, Complex, Irrational{:ℯ})
gdalle marked this conversation as resolved.
Show resolved Hide resolved
for S in (Integer, Rational, Irrational{:ℯ})
function Base.:^(tx::T, y::S) where {T<:HessianTracer}
return T(gradient(tx), hessian(tx) ∪ (gradient(tx) × gradient(tx)))
end
Expand Down
14 changes: 7 additions & 7 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTra
"""
trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1)

function trace_input(::Type{T}, x::Number, i::Integer) where {T<:AbstractTracer}
function trace_input(::Type{T}, x::Real, i::Integer) where {T<:AbstractTracer}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:AbstractTracer}
Expand All @@ -40,11 +40,11 @@ function trace_function(::Type{T}, f!, y, x) where {T<:AbstractTracer}
return xt, yt
end

to_array(x::Number) = [x]
to_array(x::Real) = [x]
to_array(x::AbstractArray) = x

# Utilities
_tracer_or_number(x::Number) = x
_tracer_or_number(x::Real) = x
_tracer_or_number(d::Dual) = tracer(d)

#====================#
Expand Down Expand Up @@ -147,7 +147,7 @@ function local_connectivity_pattern(f!, y, x, ::Type{C}=DEFAULT_VECTOR_TYPE) whe
end

function connectivity_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
xt::AbstractArray{T}, yt::AbstractArray{<:Real}
) where {T<:ConnectivityTracer}
n, m = length(xt), length(yt)
I = Int[] # row indices
Expand All @@ -166,7 +166,7 @@ function connectivity_pattern_to_mat(
end

function connectivity_pattern_to_mat(
xt::AbstractArray{D}, yt::AbstractArray{<:Number}
xt::AbstractArray{D}, yt::AbstractArray{<:Real}
) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
return connectivity_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt))
end
Expand Down Expand Up @@ -258,7 +258,7 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {
end

function jacobian_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
xt::AbstractArray{T}, yt::AbstractArray{<:Real}
) where {T<:GradientTracer}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
n, m = length(xt), length(yt)
I = Int[] # row indices
Expand All @@ -277,7 +277,7 @@ function jacobian_pattern_to_mat(
end

function jacobian_pattern_to_mat(
xt::AbstractArray{D}, yt::AbstractArray{<:Number}
xt::AbstractArray{D}, yt::AbstractArray{<:Real}
) where {P,T<:GradientTracer,D<:Dual{P,T}}
return jacobian_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt))
end
Expand Down
20 changes: 10 additions & 10 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`.
# When this happens, we create a new empty tracer with no input pattern.
function ConnectivityTracer{C}(::Number) where {C<:AbstractSet{<:Integer}}
function ConnectivityTracer{C}(::Real) where {C<:AbstractSet{<:Integer}}
return empty(ConnectivityTracer{C})
end

Expand Down Expand Up @@ -107,7 +107,7 @@ function empty(::Type{GradientTracer{G}}) where {G}
return GradientTracer{G}(empty(G))
end

function GradientTracer{G}(::Number) where {G<:AbstractSet{<:Integer}}
function GradientTracer{G}(::Real) where {G<:AbstractSet{<:Integer}}
return empty(GradientTracer{G})
end

Expand Down Expand Up @@ -172,7 +172,7 @@ function empty(::Type{HessianTracer{G,H}}) where {G,H}
end

function HessianTracer{G,H}(
::Number
::Real
) where {G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}}}
return empty(HessianTracer{G,H})
end
Expand All @@ -191,12 +191,12 @@ HessianTracer(t::HessianTracer) = t
"""
$(TYPEDEF)

Dual number type keeping track of the results of a primal computation as well as a tracer.
Dual `Real` number type keeping track of the results of a primal computation as well as a tracer.

## Fields
$(TYPEDFIELDS)
"""
struct Dual{P<:Number,T<:Union{ConnectivityTracer,GradientTracer,HessianTracer}} <:
struct Dual{P<:Real,T<:Union{ConnectivityTracer,GradientTracer,HessianTracer}} <:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all we need.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought so too. But then I ran the tests and found that leaving the ::Number dispatches led to some method ambiguities, typically between

<=(::Real, ::Real)
<=(::SCT.Dual, ::Number)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would even argue that it would be more informative to have MethodErrors due to type restrictions occur here rather than in Base.promote_rule.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the discussion below: we shouldn't promise complex number support anyway. Tracers are real numbers, and if people want to go beyond that, they can use Complex{Tracer}.

AbstractTracer
primal::P
tracer::T
Expand All @@ -211,7 +211,7 @@ gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer)
gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(d.tracer)
hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer)

function Dual{P,T}(x::Number) where {P<:Number,T<:AbstractTracer}
function Dual{P,T}(x::Real) where {P<:Real,T<:AbstractTracer}
return Dual(convert(P, x), empty(T))
end
gdalle marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -224,16 +224,16 @@ end

Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
"""
function create_tracer(::Type{Dual{P,T}}, primal::Number, index::Integer) where {P,T}
function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
end

function create_tracer(::Type{GradientTracer{G}}, ::Number, index::Integer) where {G}
function create_tracer(::Type{GradientTracer{G}}, ::Real, index::Integer) where {G}
return GradientTracer{G}(sparse_vector(G, index))
end
function create_tracer(::Type{ConnectivityTracer{C}}, ::Number, index::Integer) where {C}
function create_tracer(::Type{ConnectivityTracer{C}}, ::Real, index::Integer) where {C}
return ConnectivityTracer{C}(sparse_vector(C, index))
end
function create_tracer(::Type{HessianTracer{G,H}}, ::Number, index::Integer) where {G,H}
function create_tracer(::Type{HessianTracer{G,H}}, ::Real, index::Integer) where {G,H}
return HessianTracer{G,H}(sparse_vector(G, index), empty(H))
end
gdalle marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ const SECOND_ORDER_SET_TYPES = (

# Code coverage
@test hessian_sparsity(typemax, 1, method) ≈ [0;;]
@test hessian_sparsity(x -> x^(2im), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> (2im)^x, 1, method) ≈ [1;;]
@test hessian_sparsity(x -> real(x^(2im)), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> real((2im)^x), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> x^(2//3), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> (2//3)^x, 1, method) ≈ [1;;]
@test hessian_sparsity(x -> x^ℯ, 1, method) ≈ [1;;]
Expand Down Expand Up @@ -205,8 +205,8 @@ end

# Code coverage
@test hessian_sparsity(typemax, 1, method) ≈ [0;;]
@test hessian_sparsity(x -> x^(2im), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> (2im)^x, 1, method) ≈ [1;;]
@test hessian_sparsity(x -> real(x^(2im)), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> real((2im)^x), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> x^(2//3), 1, method) ≈ [1;;]
@test hessian_sparsity(x -> (2//3)^x, 1, method) ≈ [1;;]
@test hessian_sparsity(x -> x^ℯ, 1, method) ≈ [1;;]
Expand Down
Loading