Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding rules and methods for hessian of sum of scalar CellFields and product of two scalar CellFields #1053

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/Fields/ApplyOptimizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,58 @@ for op in (:*,:⋅,:⊙,:⊗)
end
end

# Hessian rules
for op in (:+,:-)
  @eval begin

    # Hessian of a lazy array built with Operation($op): apply ∇∇ to every
    # argument array and recombine the results with the same operation.
    function lazy_map(
      ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}})

      hessians = map(a.args) do arg
        lazy_map(∇∇, arg)
      end
      return lazy_map(Operation($op), hessians...)
    end

    # Same rule for the Broadcasting-wrapped variants of ∇∇ and Operation($op).
    function lazy_map(
      ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}})

      hessians = map(a.args) do arg
        lazy_map(Broadcasting(∇∇), arg)
      end
      return lazy_map(Broadcasting(Operation($op)), hessians...)
    end

  end
end

for op in (:*,)
  @eval begin

    # Hessian of a lazy array built with Operation($op) on exactly two factors.
    # The factors, their gradients and their Hessians are all passed to a kernel
    # that forwards to `product_rule_hessian`.
    function lazy_map(
      ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}})

      factors = a.args
      @notimplementedif length(factors) != 2
      grads = map(factors) do factor
        lazy_map(gradient, factor)
      end
      hessians = map(factors) do factor
        lazy_map(∇∇, factor)
      end
      function hessian_kernel(F1,F2,G1,G2,H1,H2)
        return product_rule_hessian($op,F1,F2,G1,G2,H1,H2)
      end
      return lazy_map(Operation(hessian_kernel), factors..., grads..., hessians...)
    end

    # Same rule for the Broadcasting-wrapped variants of ∇, ∇∇ and Operation($op).
    function lazy_map(
      ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}})

      factors = a.args
      @notimplementedif length(factors) != 2
      grads = map(factors) do factor
        lazy_map(Broadcasting(∇), factor)
      end
      hessians = map(factors) do factor
        lazy_map(Broadcasting(∇∇), factor)
      end
      function hessian_kernel(F1,F2,G1,G2,H1,H2)
        return product_rule_hessian($op,F1,F2,G1,G2,H1,H2)
      end
      return lazy_map(
        Broadcasting(Operation(hessian_kernel)), factors..., grads..., hessians...)
    end

  end
end


function lazy_map(
::Broadcasting{typeof(gradient)}, a::LazyArray{<:Fill{Broadcasting{typeof(∘)}}})

Expand Down
37 changes: 37 additions & 0 deletions src/Fields/FieldsInterfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,43 @@ for op in (:*,:⋅,:⊙,:⊗)
end
end

# Hessian (∇∇) of sum and difference

for op in (:+,:-)
  @eval begin
    # Hessian of an OperationField for $op: take ∇∇ of each stored field and
    # combine the resulting Hessians with the same operator.
    function ∇∇(a::OperationField{typeof($op)})
      hessians = map(∇∇, a.fields)
      return $op(hessians...)
    end
  end
end

# Hessian (∇∇) of product

# Fallback method of `product_rule_hessian`. All arguments are untyped, so this
# is the method selected when no more specific method applies. It builds a
# message naming the operation `fun` and the types of `f1` and `f2`, and passes
# it to `@notimplemented`.
# NOTE(review): `@notimplemented` is defined outside this excerpt; presumably it
# raises an error carrying `msg` — confirm.
function product_rule_hessian(fun,f1,f2,∇f1,∇f2,∇∇f1,∇∇f2)
msg = "Product rule not implemented for product $fun between types $(typeof(f1)) and $(typeof(f2))"
@notimplemented msg
end

# Hessian of the product f1*f2 of two real-valued quantities:
#   ∇∇(f1*f2) = ∇∇f1*f2 + ∇∇f2*f1 + ∇f1⊗∇f2 + ∇f2⊗∇f1
# The terms are accumulated strictly left to right.
function product_rule_hessian(::typeof(*),f1::Real,f2::Real,∇f1,∇f2,∇∇f1,∇∇f2)
  scaled_h1 = ∇∇f1*f2
  scaled_h2 = ∇∇f2*f1
  cross_12 = ∇f1⊗∇f2
  cross_21 = ∇f2⊗∇f1
  return ((scaled_h1 + scaled_h2) + cross_12) + cross_21
end

for op in (:*,)
  @eval begin
    # Hessian of an OperationField for $op with exactly two fields. The result is
    # a new operation over the two fields, their gradients and their Hessians,
    # evaluated through `product_rule_hessian`.
    function ∇∇(a::OperationField{typeof($op)})
      factors = a.fields
      @notimplementedif length(factors) != 2
      grads = map(gradient, factors)
      hessians = map(∇∇, factors)
      function hessian_kernel(F1,F2,G1,G2,H1,H2)
        return product_rule_hessian($op,F1,F2,G1,G2,H1,H2)
      end
      return Operation(hessian_kernel)(factors..., grads..., hessians...)
    end
  end
end

# Chain rule
function gradient(f::OperationField{<:Field})
a = f.op
Expand Down
28 changes: 28 additions & 0 deletions test/FieldsTests/FieldInterfacesTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,32 @@ test_field(vf,x,v(x),grad=∇(v)(x))
test_field(vt,x,zero.(v(x)),grad=zero.(∇(v)(x)),gradgrad=zero.(∇∇(v)(x)))
test_field(vf,x,v(x),grad=∇(v)(x),gradgrad=∇∇(v)(x))

# Test the Hessian rules for the sum and the product of two fields

# Scalar-valued test functions of a point x; wrapped as fields below.
afun(x) = x[1]^3 + x[2]^4
bfun(x) = sin(x[1])*cos(x[2])
cfun(x) = exp(x⋅x)

a = GenericField(afun)
b = GenericField(bfun)
c = GenericField(cfun)

# Field under test, f = a*b + c, built explicitly from Operation objects.
f = Operation(+)(Operation(*)(a,b), c)
# Gradient of f written out by hand with the product rule.
∇f = ∇(a)*b + ∇(b)*a + ∇(c)
# Reference value, gradient and Hessian at p, computed from the plain functions.
# NOTE(review): p, x and z are not defined in this excerpt; presumably they are
# evaluation points defined earlier in this test file — confirm.
cp = afun(p) * bfun(p) + cfun(p)
∇cp = ∇(afun)(p)*bfun(p) + ∇(bfun)(p)*afun(p) + ∇(cfun)(p)
∇∇cp = ∇∇(afun)(p) * bfun(p) + afun(p) * ∇∇(bfun)(p) + ∇(afun)(p)⊗∇(bfun)(p) + ∇(bfun)(p)⊗∇(afun)(p) + ∇∇(cfun)(p)
test_field(f,p,cp)
test_field(f,p,cp, grad=∇cp, gradgrad=∇∇cp)

# Same checks on x and z; here the expected values are obtained by
# broadcasting f, ∇(f) and ∇∇(f) themselves over the points.
test_field(f,x,f.(x))
test_field(f,x,f.(x),grad=∇(f).(x),gradgrad=∇∇(f).(x))
test_field(f,z,f.(z))
test_field(f,z,f.(z),grad=∇(f).(z),gradgrad=∇∇(f).(z))

# Cross-check: take ∇ of the hand-written ∇f and verify that it matches the
# result of the ∇∇(f) rule.
test_field(∇f, p, ∇cp, grad=∇∇cp)
test_field(∇f, x, ∇(f).(x), grad=∇∇(f).(x))
test_field(∇f, z, ∇(f).(z), grad=∇∇(f).(z))

end # module