diff --git a/src/Fields/ApplyOptimizations.jl b/src/Fields/ApplyOptimizations.jl index 76dd091fd..842a68621 100644 --- a/src/Fields/ApplyOptimizations.jl +++ b/src/Fields/ApplyOptimizations.jl @@ -194,6 +194,58 @@ for op in (:*,:⋅,:⊙,:⊗) end end +# Hessian rules +for op in (:+,:-) + @eval begin + + function lazy_map( + ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}}) + + f = a.args + g = map(i->lazy_map(∇∇,i),f) + lazy_map(Operation($op),g...) + end + + function lazy_map( + ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}}) + + f = a.args + g = map(i->lazy_map(Broadcasting(∇∇),i),f) + lazy_map(Broadcasting(Operation($op)),g...) + end + + end +end + +for op in (:*,) + @eval begin + + function lazy_map( + ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}}) + + f = a.args + @notimplementedif length(f) != 2 + g = map(i->lazy_map(gradient,i),f) + h = map(i->lazy_map(∇∇,i),f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + lazy_map(Operation(prod_rule_hess),f...,g...,h...) + end + + function lazy_map( + ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}}) + + f = a.args + @notimplementedif length(f) != 2 + g = map(i->lazy_map(Broadcasting(∇),i),f) + h = map(i->lazy_map(Broadcasting(∇∇),i),f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + lazy_map(Broadcasting(Operation(prod_rule_hess)),f...,g...,h...) + end + + end +end + + function lazy_map( ::Broadcasting{typeof(gradient)}, a::LazyArray{<:Fill{Broadcasting{typeof(∘)}}}) diff --git a/src/Fields/FieldsInterfaces.jl b/src/Fields/FieldsInterfaces.jl index bdad6a111..d248fc56a 100644 --- a/src/Fields/FieldsInterfaces.jl +++ b/src/Fields/FieldsInterfaces.jl @@ -427,6 +427,43 @@ for op in (:*,:⋅,:⊙,:⊗) end end +# Hessian (∇∇) of sum + +for op in (:+,:-) + @eval begin + function ∇∇(a::OperationField{typeof($op)}) + f = a.fields + g = map( ∇∇, f) + $op(g...) + end + end +end + +# Hessian (∇∇) of product + +function product_rule_hessian(fun,f1,f2,∇f1,∇f2,∇∇f1,∇∇f2) + msg = "Product rule not implemented for product $fun between types $(typeof(f1)) and $(typeof(f2))" + @notimplemented msg +end + +function product_rule_hessian(::typeof(*),f1::Real,f2::Real,∇f1,∇f2,∇∇f1,∇∇f2) + ∇∇f1*f2 + ∇∇f2*f1 + ∇f1⊗∇f2 + ∇f2⊗∇f1 +end + +for op in (:*,) + @eval begin + function ∇∇(a::OperationField{typeof($op)}) + f = a.fields + @notimplementedif length(f) != 2 + f1, f2 = f + g1, g2 = map(gradient, f) + h1, h2 = map(∇∇, f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + Operation(prod_rule_hess)(f1,f2,g1,g2,h1,h2) + end + end +end + # Chain rule function gradient(f::OperationField{<:Field}) a = f.op diff --git a/test/FieldsTests/FieldInterfacesTests.jl b/test/FieldsTests/FieldInterfacesTests.jl index 21cfd99f4..b24c10a8e 100644 --- a/test/FieldsTests/FieldInterfacesTests.jl +++ b/test/FieldsTests/FieldInterfacesTests.jl @@ -376,4 +376,32 @@ test_field(vf,x,v(x),grad=∇(v)(x)) test_field(vt,x,zero.(v(x)),grad=zero.(∇(v)(x)),gradgrad=zero.(∇∇(v)(x))) test_field(vf,x,v(x),grad=∇(v)(x),gradgrad=∇∇(v)(x)) +# testing hessian rule for sum and product of two fields + +afun(x) = x[1]^3 + x[2]^4 +bfun(x) = sin(x[1])*cos(x[2]) +cfun(x) = exp(x⋅x) + +a = GenericField(afun) +b = GenericField(bfun) +c = GenericField(cfun) + +f = Operation(+)(Operation(*)(a,b), c) +∇f = ∇(a)*b + ∇(b)*a + ∇(c) +cp = afun(p) * bfun(p) + cfun(p) +∇cp = ∇(afun)(p)*bfun(p) + ∇(bfun)(p)*afun(p) + ∇(cfun)(p) +∇∇cp = ∇∇(afun)(p) * bfun(p) + afun(p) * ∇∇(bfun)(p) + ∇(afun)(p)⊗∇(bfun)(p) + ∇(bfun)(p)⊗∇(afun)(p) + ∇∇(cfun)(p) +test_field(f,p,cp) +test_field(f,p,cp, grad=∇cp, gradgrad=∇∇cp) + +test_field(f,x,f.(x)) +test_field(f,x,f.(x),grad=∇(f).(x),gradgrad=∇∇(f).(x)) +test_field(f,z,f.(z)) +test_field(f,z,f.(z),grad=∇(f).(z),gradgrad=∇∇(f).(z)) + +# this one checks by taking ∇ of ∇f to see if matches with rule for ∇∇(f) +test_field(∇f, p, ∇cp, grad=∇∇cp) +test_field(∇f, x, ∇(f).(x), grad=∇∇(f).(x)) +test_field(∇f, z, ∇(f).(z), grad=∇∇(f).(z)) + end # module