From 2b003fdd53260d65d5dc983dda89fea2172c1aa8 Mon Sep 17 00:00:00 2001 From: Kishore Nori Date: Sun, 17 Nov 2024 21:50:31 +1100 Subject: [PATCH 1/5] rules for hessian of sum of fields and sum of product two fields --- src/Fields/FieldsInterfaces.jl | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/Fields/FieldsInterfaces.jl b/src/Fields/FieldsInterfaces.jl index bdad6a111..aa6dc70f9 100644 --- a/src/Fields/FieldsInterfaces.jl +++ b/src/Fields/FieldsInterfaces.jl @@ -427,6 +427,43 @@ for op in (:*,:⋅,:⊙,:⊗) end end +# Hessian (∇∇) of sum + +for op in (:+,:-) + @eval begin + function ∇∇(a::OperationField{typeof($op)}) + f = a.fields + g = map( ∇∇, f) + $op(g...) + end + end +end + +# Hessian (∇∇) of product + +function product_rule_hessian(fun,f1,f2,∇f1,∇f2,∇∇f1,∇∇f2) + msg = "Product rule not implemented for product $fun between types $(typeof(f1)) and $(typeof(f2))" + @notimplemented msg +end + +function product_rule_hessian(::typeof(*),f1::Real,f2::Real,∇f1,∇f2,∇∇f1,∇∇f2) + ∇∇f1*f2 + ∇∇f2*f1 + 2*∇f1⊗∇f2 +end + +for op in (:*,) + @eval begin + function ∇∇(a::OperationField{typeof($op)}) + f = a.fields + @notimplementedif length(f) != 2 + f1, f2 = f + g1, g2 = map(gradient, f) + h1, h2 = map(∇∇, f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + Operation(prod_rule_hess)(f1,f2,g1,g2,h1,h2) + end + end +end + # Chain rule function gradient(f::OperationField{<:Field}) a = f.op From c037699a055669cf9903015f8a9ca3fdaf3025e1 Mon Sep 17 00:00:00 2001 From: Kishore Nori Date: Sun, 17 Nov 2024 21:51:26 +1100 Subject: [PATCH 2/5] adding hessian rule for broadcast operations with sum and product of two fields --- src/Fields/ApplyOptimizations.jl | 52 ++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/Fields/ApplyOptimizations.jl b/src/Fields/ApplyOptimizations.jl index 76dd091fd..842a68621 100644 --- a/src/Fields/ApplyOptimizations.jl +++ 
b/src/Fields/ApplyOptimizations.jl @@ -194,6 +194,58 @@ for op in (:*,:⋅,:⊙,:⊗) end end +# Hessian rules +for op in (:+,:-) + @eval begin + + function lazy_map( + ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}}) + + f = a.args + g = map(i->lazy_map(∇∇,i),f) + lazy_map(Operation($op),g...) + end + + function lazy_map( + ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}}) + + f = a.args + g = map(i->lazy_map(Broadcasting(∇∇),i),f) + lazy_map(Broadcasting(Operation($op)),g...) + end + + end +end + +for op in (:*,) + @eval begin + + function lazy_map( + ::typeof(∇∇), a::LazyArray{<:Fill{Operation{typeof($op)}}}) + + f = a.args + @notimplementedif length(f) != 2 + g = map(i->lazy_map(gradient,i),f) + h = map(i->lazy_map(∇∇,i),f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + lazy_map(Operation(prod_rule_hess),f...,g...,h...) + end + + function lazy_map( + ::Broadcasting{typeof(∇∇)}, a::LazyArray{<:Fill{Broadcasting{Operation{typeof($op)}}}}) + + f = a.args + @notimplementedif length(f) != 2 + g = map(i->lazy_map(Broadcasting(∇),i),f) + h = map(i->lazy_map(Broadcasting(∇∇),i),f) + prod_rule_hess(F1,F2,G1,G2,H1,H2) = product_rule_hessian($op,F1,F2,G1,G2,H1,H2) + lazy_map(Broadcasting(Operation(prod_rule_hess)),f...,g...,h...) 
+ end + + end +end + + function lazy_map( ::Broadcasting{typeof(gradient)}, a::LazyArray{<:Fill{Broadcasting{typeof(∘)}}}) From 558339a971de914aa2d7f59d5bc49f8239811490 Mon Sep 17 00:00:00 2001 From: Kishore Nori Date: Sun, 17 Nov 2024 22:24:20 +1100 Subject: [PATCH 3/5] tests for hessian rule for sum and prod of two fields, with broadcast versions --- test/FieldsTests/FieldInterfacesTests.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/FieldsTests/FieldInterfacesTests.jl b/test/FieldsTests/FieldInterfacesTests.jl index 21cfd99f4..fdef98b93 100644 --- a/test/FieldsTests/FieldInterfacesTests.jl +++ b/test/FieldsTests/FieldInterfacesTests.jl @@ -376,4 +376,26 @@ test_field(vf,x,v(x),grad=∇(v)(x)) test_field(vt,x,zero.(v(x)),grad=zero.(∇(v)(x)),gradgrad=zero.(∇∇(v)(x))) test_field(vf,x,v(x),grad=∇(v)(x),gradgrad=∇∇(v)(x)) +# testing hessian rule for sum and product of two fields + +afun(x) = x[1]^3 + x[2]^4 +bfun(x) = sin(x[1])*cos(x[2]) +cfun(x) = exp(x⋅x) + +a = GenericField(afun) +b = GenericField(bfun) +c = GenericField(cfun) + +f = Operation(+)(Operation(*)(a,b), c) +cp = afun(p) * bfun(p) + cfun(p) +∇cp = ∇(afun)(p)*bfun(p) + ∇(bfun)(p)*afun(p) + ∇(cfun)(p) +∇∇cp = ∇∇(afun)(p) * bfun(p) + afun(p) * ∇∇(bfun)(p) + 2*∇(afun)(p)⊗∇(bfun)(p) + ∇∇(cfun)(p) +test_field(f,p,cp) +test_field(f,p,cp, grad=∇cp, gradgrad=∇∇cp) + +test_field(f,x,f.(x)) +test_field(f,x,f.(x),grad=∇(f).(x),gradgrad=∇∇(f).(x)) +test_field(f,z,f.(z)) +test_field(f,z,f.(z),grad=∇(f).(z),gradgrad=∇∇(f).(z)) + end # module From 757af98fa12c05845361629d096396a2632ecfe2 Mon Sep 17 00:00:00 2001 From: Kishore Nori Date: Sun, 17 Nov 2024 22:56:54 +1100 Subject: [PATCH 4/5] fixing mistake in the prod rule of hessian of two fields --- src/Fields/FieldsInterfaces.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Fields/FieldsInterfaces.jl b/src/Fields/FieldsInterfaces.jl index aa6dc70f9..d248fc56a 100644 --- a/src/Fields/FieldsInterfaces.jl +++ 
b/src/Fields/FieldsInterfaces.jl @@ -447,7 +447,7 @@ function product_rule_hessian(fun,f1,f2,∇f1,∇f2,∇∇f1,∇∇f2) end function product_rule_hessian(::typeof(*),f1::Real,f2::Real,∇f1,∇f2,∇∇f1,∇∇f2) - ∇∇f1*f2 + ∇∇f2*f1 + 2*∇f1⊗∇f2 + ∇∇f1*f2 + ∇∇f2*f1 + ∇f1⊗∇f2 + ∇f2⊗∇f1 end for op in (:*,) From 8d4d28fa848c2fbf24fa0289f05fc6342a6b6589 Mon Sep 17 00:00:00 2001 From: Kishore Nori Date: Sun, 17 Nov 2024 22:57:39 +1100 Subject: [PATCH 5/5] adding more rigorous tests for product rule of hessian of two fields --- test/FieldsTests/FieldInterfacesTests.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/FieldsTests/FieldInterfacesTests.jl b/test/FieldsTests/FieldInterfacesTests.jl index fdef98b93..b24c10a8e 100644 --- a/test/FieldsTests/FieldInterfacesTests.jl +++ b/test/FieldsTests/FieldInterfacesTests.jl @@ -387,9 +387,10 @@ b = GenericField(bfun) c = GenericField(cfun) f = Operation(+)(Operation(*)(a,b), c) +∇f = ∇(a)*b + ∇(b)*a + ∇(c) cp = afun(p) * bfun(p) + cfun(p) ∇cp = ∇(afun)(p)*bfun(p) + ∇(bfun)(p)*afun(p) + ∇(cfun)(p) -∇∇cp = ∇∇(afun)(p) * bfun(p) + afun(p) * ∇∇(bfun)(p) + 2*∇(afun)(p)⊗∇(bfun)(p) + ∇∇(cfun)(p) +∇∇cp = ∇∇(afun)(p) * bfun(p) + afun(p) * ∇∇(bfun)(p) + ∇(afun)(p)⊗∇(bfun)(p) + ∇(bfun)(p)⊗∇(afun)(p) + ∇∇(cfun)(p) test_field(f,p,cp) test_field(f,p,cp, grad=∇cp, gradgrad=∇∇cp) @@ -398,4 +399,9 @@ test_field(f,x,f.(x),grad=∇(f).(x),gradgrad=∇∇(f).(x)) test_field(f,z,f.(z)) test_field(f,z,f.(z),grad=∇(f).(z),gradgrad=∇∇(f).(z)) +# this checks, by taking ∇ of ∇f, that the result matches the rule for ∇∇(f) +test_field(∇f, p, ∇cp, grad=∇∇cp) +test_field(∇f, x, ∇(f).(x), grad=∇∇(f).(x)) +test_field(∇f, z, ∇(f).(z), grad=∇∇(f).(z)) + end # module