-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Symbolic AD of ScalarNonlinearFunction #2533
Comments
I started hacking something for this: module SymbolicAD
import MacroTools
import MathOptInterface as MOI
# d/dx of a constant is zero. `false` is a "strong zero" in Julia
# arithmetic, so it vanishes when multiplied into other symbolic terms.
derivative(::Real, ::MOI.VariableIndex) = false

# d/dx of a variable is 1 if it is `x` and 0 otherwise. The comparison
# already yields the required Bool; `ifelse(f == x, true, false)` was a
# redundant wrapper around it.
derivative(f::MOI.VariableIndex, x::MOI.VariableIndex) = f == x
"""
    derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex) where {T}

Return `∂f/∂x`: the sum of the coefficients of every affine term of `f`
whose variable is `x`, as a constant of type `T`.
"""
function derivative(
    f::MOI.ScalarAffineFunction{T},
    x::MOI.VariableIndex,
) where {T}
    # A variable may appear in several terms, so accumulate every match.
    return sum(
        (term.coefficient for term in f.terms if term.variable == x);
        init = zero(T),
    )
end
"""
    derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex) where {T}

Return `∂f/∂x` as a `MOI.ScalarAffineFunction{T}`.

MOI stores a quadratic function as `0.5 * x' Q x + a' x + b` with diagonal
coefficients doubled, so a stored diagonal term with coefficient `c`
represents `(c / 2) * x^2` and contributes `c * x` to the derivative, while
an off-diagonal term `c * x * y` contributes `c * y`.
"""
function derivative(
    f::MOI.ScalarQuadraticFunction{T},
    x::MOI.VariableIndex,
) where {T}
    # Affine terms in `x` contribute their coefficients to the constant.
    offset = zero(T)
    for a_term in f.affine_terms
        if a_term.variable == x
            offset += a_term.coefficient
        end
    end
    terms = MOI.ScalarAffineTerm{T}[]
    for q_term in f.quadratic_terms
        c = q_term.coefficient
        v1, v2 = q_term.variable_1, q_term.variable_2
        if v1 == v2 == x
            # Diagonal: represents (c / 2) * x^2, derivative is c * x.
            push!(terms, MOI.ScalarAffineTerm(c, x))
        elseif v1 == x
            push!(terms, MOI.ScalarAffineTerm(c, v2))
        elseif v2 == x
            push!(terms, MOI.ScalarAffineTerm(c, v1))
        end
    end
    return MOI.ScalarAffineFunction(terms, offset)
end
"""
    derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)

Return an expression representing the symbolic partial derivative of `f`
with respect to `x`.

Throws `MOI.UnsupportedNonlinearOperator` if `f.head` has no symbolic
differentiation rule.
"""
function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
    if length(f.args) == 1
        u = only(f.args)
        if f.head == :+
            # d/dx +u = du/dx
            return derivative(u, x)
        elseif f.head == :-
            # d/dx -u = -(du/dx)
            return MOI.ScalarNonlinearFunction(:-, Any[derivative(u, x)])
        elseif f.head == :abs
            # d/dx |u| = sign(u) * du/dx, with sign(u) encoded as an
            # `ifelse` so the result is a valid MOI expression.
            scale = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[scale, derivative(u, x)])
        elseif f.head == :sign
            # sign is piecewise constant: derivative is zero a.e.
            return false
        end
        for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
            if key == f.head
                # The chain rule: d(f(u(x)))/dx = f'(u) * u'(x). `df` is a
                # Julia `Expr` in the placeholder `:x`; substitute `u` for
                # `:x` and convert each call node into an MOI expression.
                df_du = MacroTools.postwalk(df) do node
                    if node === :x
                        return u
                    elseif Meta.isexpr(node, :call)
                        op, args = node.args[1], node.args[2:end]
                        return MOI.ScalarNonlinearFunction(op, args)
                    end
                    return node
                end
                du_dx = derivative(u, x)
                return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
            end
        end
    end
    if f.head == :+
        # d/dx(+(args...)) = +(d/dx args)
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:+, args)
    elseif f.head == :-
        # d/dx(-(args...)) = -(d/dx args)
        # Note that - is not unary here because that would be caught above.
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:-, args)
    elseif f.head == :*
        # Product rule: d/dx(*(args...)) = sum_i(arg_i' * prod_{j != i} arg_j)
        sum_terms = Any[]
        for i in 1:length(f.args)
            g = MOI.ScalarNonlinearFunction(:*, copy(f.args))
            g.args[i] = derivative(f.args[i], x)
            push!(sum_terms, g)
        end
        return MOI.ScalarNonlinearFunction(:+, sum_terms)
    elseif f.head == :^
        @assert length(f.args) == 2
        u, p = f.args
        du_dx = derivative(u, x)
        dp_dx = derivative(p, x)
        if _iszero(dp_dx)
            # p is constant with respect to x: d/dx u^p = p * u^(p-1) * u'
            df_du = MOI.ScalarNonlinearFunction(
                :*,
                Any[p, MOI.ScalarNonlinearFunction(:^, Any[u, p-1])],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
        end
        # General case u(x)^p(x): since u^p = exp(p * log(u)),
        #   d/dx u^p = u^p * (p' * log(u) + p * u' / u),
        # valid for u > 0, which is required anyway for u^p to be
        # well-defined with a non-integer exponent.
        log_term = MOI.ScalarNonlinearFunction(
            :*,
            Any[dp_dx, MOI.ScalarNonlinearFunction(:log, Any[u])],
        )
        ratio_term = MOI.ScalarNonlinearFunction(
            :/,
            Any[MOI.ScalarNonlinearFunction(:*, Any[p, du_dx]), u],
        )
        return MOI.ScalarNonlinearFunction(
            :*,
            Any[f, MOI.ScalarNonlinearFunction(:+, Any[log_term, ratio_term])],
        )
    elseif f.head == :/
        # Quotient rule: d/dx(u / v) = ((du/dx)*v - u*(dv/dx)) / v^2
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        return MOI.ScalarNonlinearFunction(
            :/,
            Any[
                MOI.ScalarNonlinearFunction(
                    :-,
                    Any[
                        MOI.ScalarNonlinearFunction(:*, Any[du_dx, v]),
                        MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
                    ],
                ),
                MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
            ],
        )
    elseif f.head == :ifelse
        @assert length(f.args) == 3
        # Differentiate both branches; the condition is treated as
        # independent of x.
        return MOI.ScalarNonlinearFunction(
            :ifelse,
            Any[f.args[1], derivative(f.args[2], x), derivative(f.args[3], x)],
        )
    elseif f.head == :atan
        # Two-argument arctangent. (The one-argument form is handled by the
        # univariate table above.)
        #   d/dx atan(u, v) = (u' * v - u * v') / (u^2 + v^2)
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        numerator = MOI.ScalarNonlinearFunction(
            :-,
            Any[
                MOI.ScalarNonlinearFunction(:*, Any[du_dx, v]),
                MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
            ],
        )
        denominator = MOI.ScalarNonlinearFunction(
            :+,
            Any[
                MOI.ScalarNonlinearFunction(:^, Any[u, 2]),
                MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
            ],
        )
        return MOI.ScalarNonlinearFunction(:/, Any[numerator, denominator])
    elseif f.head == :min
        # d/dx min(args...): nested ifelse picking the derivative of the
        # first argument that attains the minimum.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(<=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head == :max
        # d/dx max(args...): nested ifelse picking the derivative of the
        # first argument that attains the maximum.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(>=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head in (:(>=), :(<=), :(<), :(>), :(==))
        # Comparisons are piecewise constant: derivative is zero a.e.
        return false
    end
    err = MOI.UnsupportedNonlinearOperator(
        f.head,
        "the operator does not support symbolic differentiation",
    )
    return throw(err)
end
# Fallback: anything without a specialized method is already simplified.
simplify(f) = f

"""
    simplify(f::MOI.ScalarAffineFunction{T}) where {T}

Canonicalize `f`, collapsing it to its constant when no affine terms remain.
"""
function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    return isempty(g.terms) ? g.constant : g
end
"""
    simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}

Canonicalize `f`; if no quadratic terms survive, demote it to an affine
function and simplify that in turn.
"""
function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    if !isempty(g.quadratic_terms)
        return g
    end
    return simplify(MOI.ScalarAffineFunction(g.affine_terms, g.constant))
end
# Constant folding: if every argument of `f` is a numeric literal, evaluate
# the operator eagerly and return the resulting constant.
function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
    # `isdefined` is used instead of `hasproperty`: for a `Module`,
    # `hasproperty` only reports exported names, whereas `isdefined` sees
    # every binding in `Base`, so strictly more heads can be folded.
    if all(_isnum, f.args) && isdefined(Base, f.head)
        return getfield(Base, f.head)(f.args...)
    end
    return f
end

# Non-nonlinear values are already constants (or cannot be folded).
_eval_if_constant(f) = f
"""
    simplify(f::MOI.ScalarNonlinearFunction)

Recursively simplify the arguments of `f` in place, apply the
head-specific rewrite selected by dispatch on `Val(f.head)`, and finally
constant-fold the result if all arguments are numeric.
"""
function simplify(f::MOI.ScalarNonlinearFunction)
    map!(simplify, f.args, f.args)
    return _eval_if_constant(simplify(Val(f.head), f))
end

# Fallback for operators without a head-specific rewrite rule.
simplify(::Val, f::MOI.ScalarNonlinearFunction) = f
# Numeric predicates used by the simplifier. Widened from
# `Union{Bool,Integer,Float64}` to `Real` for consistency with the
# `arg isa Real` checks in `simplify(::Val{:*}, ...)` and
# `simplify(::Val{:+}, ...)`, so that e.g. Float32 and Rational constants
# are also recognized. Non-numbers (MOI functions, etc.) fall through to
# the `::Any` methods and are never treated as constants.
_iszero(x::Real) = iszero(x)
_iszero(::Any) = false
_isone(x::Real) = isone(x)
_isone(::Any) = false
_isnum(::Real) = true
_isnum(::Any) = false
# `_isexpr(f, head[, n])` is true iff `f` is a ScalarNonlinearFunction
# whose head is `head` (and, when `n` is given, whose argument count is
# `n`). Everything else is not an expression.
_isexpr(::Any, ::Symbol, n::Int = 0) = false
_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head
function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
    return f.head == head && length(f.args) == n
end
"""
    simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)

Simplify a product: flatten nested `*`, drop unit factors, fold all
numeric factors into the first numeric slot, and collapse the whole
product when any factor is zero.
"""
function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
    factors = Any[]
    const_index = 0  # position of the folded numeric factor; 0 if none yet
    for term in f.args
        if _isexpr(term, :*)
            # *(a, *(b, c)) => *(a, b, c)
            append!(factors, term.args)
        elseif _iszero(term)
            # A zero factor annihilates the whole product.
            return false
        elseif _isone(term)
            # Unit factors contribute nothing; skip.
        elseif term isa Real
            if const_index == 0
                push!(factors, term)
                const_index = length(factors)
            else
                factors[const_index] *= term
            end
        else
            push!(factors, term)
        end
    end
    if isempty(factors)
        # The empty product is the multiplicative identity.
        return true
    elseif length(factors) == 1
        return only(factors)
    end
    return MOI.ScalarNonlinearFunction(:*, factors)
end
"""
    simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)

Simplify a sum: unwrap unary `+`, rewrite `a + (-b)` as `a - b`, flatten
nested `+`, drop zero addends, and fold all numeric addends into the
first numeric slot.
"""
function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        # +(a) => a
        return only(f.args)
    elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
        # a + (-b) => a - b
        return MOI.ScalarNonlinearFunction(
            :-,
            Any[f.args[1], f.args[2].args[1]],
        )
    end
    addends = Any[]
    const_index = 0  # position of the folded numeric addend; 0 if none yet
    for term in f.args
        if _isexpr(term, :+)
            # +(a, +(b, c)) => +(a, b, c)
            append!(addends, term.args)
        elseif _iszero(term)
            # Zero addends contribute nothing; skip.
        elseif term isa Real
            if const_index == 0
                push!(addends, term)
                const_index = length(addends)
            else
                addends[const_index] += term
            end
        else
            push!(addends, term)
        end
    end
    if isempty(addends)
        # The empty sum is the additive identity.
        return false
    elseif length(addends) == 1
        return only(addends)
    end
    return MOI.ScalarNonlinearFunction(:+, addends)
end
"""
    simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)

Apply sign identities: double negation, subtraction of or from zero,
self-cancellation, and subtraction of a negated term.
"""
function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        if _isexpr(f.args[1], :-, 1)
            # -(-(x)) => x
            return f.args[1].args[1]
        end
        return f
    elseif length(f.args) != 2
        return f
    end
    lhs, rhs = f.args
    if _iszero(lhs)
        # 0 - x => -x
        return MOI.ScalarNonlinearFunction(:-, Any[rhs])
    elseif _iszero(rhs)
        # x - 0 => x
        return lhs
    elseif lhs == rhs
        # x - x => 0
        return false
    elseif _isexpr(rhs, :-, 1)
        # x - (-y) => x + y
        return MOI.ScalarNonlinearFunction(:+, Any[lhs, rhs.args[1]])
    end
    return f
end
"""
    simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)

Apply power identities when the base or the exponent is a constant.

!!! note
    `0^x => 0` and `1^x => 1` are applied without regard to the value of
    `x`, so corner cases like `0^0 == 1` and `0^(-1) == Inf` are not
    preserved; this matches the usual conventions of symbolic AD.
"""
function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
    base, exponent = f.args[1], f.args[2]
    if _iszero(exponent)
        return true   # x^0 => 1
    elseif _isone(exponent)
        return base   # x^1 => x
    elseif _iszero(base)
        return false  # 0^x => 0
    elseif _isone(base)
        return true   # 1^x => 1
    end
    return f
end
"""
    variables(f::MOI.AbstractScalarFunction)

Return the `MOI.VariableIndex`es appearing in `f`, without duplicates,
in order of first appearance.
"""
function variables(f::MOI.AbstractScalarFunction)
    indices = MOI.VariableIndex[]
    variables!(indices, f)
    return indices
end

# Constants contain no variables.
variables(::Real) = MOI.VariableIndex[]
variables!(ret, ::Real) = nothing

# A variable contributes itself, once.
function variables!(ret, f::MOI.VariableIndex)
    f in ret || push!(ret, f)
    return
end
# An affine term contributes its variable, once.
function variables!(ret, f::MOI.ScalarAffineTerm)
    f.variable in ret || push!(ret, f.variable)
    return
end

# An affine function contributes the variables of each of its terms.
function variables!(ret, f::MOI.ScalarAffineFunction)
    foreach(term -> variables!(ret, term), f.terms)
    return
end
# A quadratic term contributes both of its variables, once each.
function variables!(ret, f::MOI.ScalarQuadraticTerm)
    f.variable_1 in ret || push!(ret, f.variable_1)
    f.variable_2 in ret || push!(ret, f.variable_2)
    return
end

# A quadratic function contributes its affine terms first, then its
# quadratic terms; this fixes the order of first appearance.
function variables!(ret, f::MOI.ScalarQuadraticFunction)
    foreach(term -> variables!(ret, term), f.affine_terms)
    foreach(term -> variables!(ret, term), f.quadratic_terms)
    return
end
# A nonlinear function contributes the variables of each of its arguments.
function variables!(ret, f::MOI.ScalarNonlinearFunction)
    foreach(arg -> variables!(ret, arg), f.args)
    return
end
# A constant has an empty gradient.
gradient(::Real) = Dict{MOI.VariableIndex,Any}()

"""
    gradient(f::MOI.AbstractScalarFunction)

Return a dictionary mapping each variable appearing in `f` to the
simplified symbolic partial derivative of `f` with respect to it.
"""
function gradient(f::MOI.AbstractScalarFunction)
    ret = Dict{MOI.VariableIndex,Any}()
    for x in variables(f)
        ret[x] = simplify(derivative(f, x))
    end
    return ret
end
end
using JuMP, Test
"""
    test_derivative()

Check `SymbolicAD.derivative` against hand-computed derivatives. Each
`f => fp` pair asserts that `simplify(derivative(f, x))` of the JuMP
expression `f` matches the expected expression `fp`; the table is grouped
by the `derivative` method or operator head it exercises.
"""
function test_derivative()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # derivative(::Real, ::MOI.VariableIndex)
        1.0=>0.0,
        1.23=>0.0,
        # derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)
        x=>1.0,
        y=>0.0,
        # derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex)
        1.0*x=>1.0,
        1.0*x+2.0=>1.0,
        2.0*x+2.0=>2.0,
        2.0*x+y+2.0=>2.0,
        2.0*x+y+z+2.0=>2.0,
        # derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex)
        QuadExpr(1.0 * x)=>1.0,
        QuadExpr(1.0 * x + 0.0 * y)=>1.0,
        x*y=>1.0*y,
        y*x=>1.0*y,
        x^2=>2.0*x,
        x^2+3x+4=>2.0*x+3.0,
        (x-1.0)^2=>2.0*(x-1),
        (3*x+1.0)^2=>6.0*(3x+1),
        # Univariate
        # f.head == :+
        @force_nonlinear(+x)=>1,
        @force_nonlinear(+sin(x))=>cos(x),
        # f.head == :-
        @force_nonlinear(-sin(x))=>-cos(x),
        # f.head == :abs
        @force_nonlinear(
            abs(sin(x))
        )=>op_ifelse(op_greater_than_or_equal_to(sin(x), 0), 1, -1)*cos(x),
        # f.head == :sign
        sign(x)=>false,
        # SYMBOLIC_UNIVARIATE_EXPRESSIONS
        sin(x)=>cos(x),
        cos(x)=>-sin(x),
        log(x)=>1/x,
        log(2x)=>1/(2x)*2.0,
        # f.head == :+
        sin(x)+cos(x)=>cos(x)-sin(x),
        # f.head == :-
        sin(x)-cos(x)=>cos(x)+sin(x),
        # f.head == :*
        @force_nonlinear(*(x, y, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, x, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, z, x))=>@force_nonlinear(*(y, z)),
        # :^
        sin(x)^2=>@force_nonlinear(*(2.0, sin(x), cos(x))),
        sin(x)^1=>cos(x),
        # :/
        @force_nonlinear(/(x, 2))=>0.5,
        @force_nonlinear(
            x^2 / (x + 1)
        )=>@force_nonlinear((*(2, x, x + 1) - x^2) / (x + 1)^2),
        # :ifelse
        op_ifelse(z, x^2, x)=>op_ifelse(z, 2x, 1),
        # :atan
        # :min
        min(x, x^2)=>op_ifelse(op_less_than_or_equal_to(x, min(x, x^2)), 1, 2x),
        # :max
        max(
            x,
            x^2,
        )=>op_ifelse(op_greater_than_or_equal_to(x, max(x, x^2)), 1, 2x),
        # comparisons
        op_greater_than_or_equal_to(x, y)=>false,
        op_equal_to(x, y)=>false,
    ]
        g = SymbolicAD.derivative(moi_function(f), index(x))
        h = SymbolicAD.simplify(g)
        # Print the offending pair before the assertion fails, to make
        # failures easier to diagnose from the test log.
        if !(h ≈ moi_function(fp))
            @show h
            @show f
        end
        @test h ≈ moi_function(fp)
    end
    return
end
"""
    test_gradient()

Check `SymbolicAD.gradient`. Each `f => fp` pair asserts that the
gradient of the JuMP expression `f` matches the expected dictionary `fp`
mapping each variable to its partial derivative.
"""
function test_gradient()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>Dict(),
        # ::AffExpr
        x=>Dict(x => 1),
        x+y=>Dict(x => 1, y => 1),
        2x+y=>Dict(x => 2, y => 1),
        2x+3y+1=>Dict(x => 2, y => 3),
        # ::QuadExpr
        2x^2+3y+z=>Dict(x => 4x, y => 3, z => 1),
        # ::NonlinearExpr
        sin(x)=>Dict(x => cos(x)),
        sin(x + y)=>Dict(x => cos(x + y), y => cos(x + y)),
        sin(x + 2y)=>Dict(x => cos(x + 2y), y => cos(x + 2y) * 2),
    ]
        g = SymbolicAD.gradient(moi_function(f))
        # Convert the expected JuMP-level dictionary to MOI types so the
        # entries can be compared with ≈.
        h = Dict{MOI.VariableIndex,Any}(
            index(k) => moi_function(v) for (k, v) in fp
        )
        @test length(g) == length(h)
        for k in keys(g)
            @test g[k] ≈ h[k]
        end
    end
    return
end
"""
    test_simplify()

Check `SymbolicAD.simplify`. Each `f => fp` pair asserts that simplifying
the JuMP expression `f` produces the expected expression `fp`; the table
is grouped by the `simplify` method it exercises.
"""
function test_simplify()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # simplify(f)
        x=>x,
        # simplify(f::MOI.ScalarAffineFunction{T})
        AffExpr(2.0)=>2.0,
        # simplify(f::MOI.ScalarQuadraticFunction{T})
        QuadExpr(x + 1)=>x+1,
        # simplify(f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(sin(*(3, x^0)))=>sin(3),
        sin(log(x))=>sin(log(x)),
        op_ifelse(z, x, 0)=>op_ifelse(z, x, 0),
        # simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(*(x, *(y, z)))=>@force_nonlinear(*(x, y, z)),
        @force_nonlinear(
            *(x, *(y, z, *(x, 2)))
        )=>@force_nonlinear(*(x, y, z, x, 2)),
        @force_nonlinear(*(x, 3, 2))=>@force_nonlinear(*(x, 6)),
        @force_nonlinear(*(3, x, 2))=>@force_nonlinear(*(6, x)),
        @force_nonlinear(*(x, 1))=>x,
        @force_nonlinear(*(-(x, x), 1))=>0,
        # simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(+(x, +(y, z)))=>@force_nonlinear(+(x, y, z)),
        +(sin(x), -cos(x))=>sin(x)-cos(x),
        @force_nonlinear(+(x, 1, 2))=>@force_nonlinear(+(x, 3)),
        @force_nonlinear(+(1, x, 2))=>@force_nonlinear(+(3, x)),
        @force_nonlinear(+(x, 0))=>x,
        @force_nonlinear(+(-(x, x), 0))=>0,
        # simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(-(-(x)))=>x,
        @force_nonlinear(-(x, 0))=>x,
        @force_nonlinear(-(0, x))=>@force_nonlinear(-x),
        @force_nonlinear(-(x, x))=>0,
        @force_nonlinear(-(x, -y))=>@force_nonlinear(x + y),
        @force_nonlinear(-(x, y))=>@force_nonlinear(x - y),
        # simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(^(x, 0))=>1,
        @force_nonlinear(^(x, 1))=>x,
        @force_nonlinear(^(0, x))=>0,
        @force_nonlinear(^(1, x))=>1,
        x^y=>x^y,
    ]
        g = SymbolicAD.simplify(moi_function(f))
        # Print the offending pair before the assertion fails, to make
        # failures easier to diagnose from the test log.
        if !(g ≈ moi_function(fp))
            @show f
            @show g
        end
        @test g ≈ moi_function(fp)
    end
    return
end
"""
    test_variable()

Check `SymbolicAD.variables`. Each `f => fp` pair asserts that the
variables of the JuMP expression `f` are exactly `fp`, including the
expected order of first appearance.
"""
function test_variable()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>[],
        # ::VariableRef,
        x=>[x],
        # ::AffExpr
        AffExpr(2.0)=>[],
        x+1=>[x],
        2x+1=>[x],
        2x+y+1=>[x, y],
        y+1+z=>[y, z],
        # ::QuadExpr
        zero(QuadExpr)=>[],
        QuadExpr(x + 1)=>[x],
        QuadExpr(x + 1 + y)=>[x, y],
        x^2=>[x],
        x^2+x=>[x],
        x^2+y=>[y, x],
        x*y=>[x, y],
        y*x=>[y, x],
        # ::NonlinearExpr
        sin(x)=>[x],
        sin(x + y)=>[x, y],
        sin(x)*cos(y)=>[x, y],
    ]
        @test SymbolicAD.variables(moi_function(f)) == index.(fp)
    end
    return
end
@testset "SymbolicAD" begin
    # Run each unit-test function in its own named testset, in order.
    for (label, runner) in [
        "derivative" => test_derivative,
        "simplify" => test_simplify,
        "variable" => test_variable,
        "gradient" => test_gradient,
    ]
        @testset "$label" begin
            runner()
        end
    end
end
nothing
I think the trick for integrating this into MathOptSymbolicAD is to have an efficient interpreter that reuses expression values across the primal and derivative evaluations. The symbolic expression trees are always going to be fundamentally limited.
Thinking on this, I should probably merge this first into MathOptSymbolicAD.jl, get it working, and then we can add MathOptSymbolicAD as a dependency.
This has come up quite a few times, so I think we need this.
I don't know what the right API is. Perhaps:
The use case for this would be:
It's okay for this to have all the usual issues with symbolic AD.
The text was updated successfully, but these errors were encountered: