Add Deep Ritz to NeuralPDE.jl #857

Draft: wants to merge 5 commits into base: master
4 changes: 3 additions & 1 deletion src/NeuralPDE.jl
@@ -48,13 +48,15 @@ include("ode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("transform_inf_integral.jl")
include("deep_ritz.jl")
include("discretize.jl")
include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")


export NNODE, NNDAE,
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
@@ -66,6 +68,6 @@ export NNODE, NNDAE,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution, BayesianPINN,
DeepGalerkin
DeepGalerkin, DeepRitz

end # module
135 changes: 135 additions & 0 deletions src/deep_ritz.jl
@@ -0,0 +1,135 @@
"""
DeepRitz(chain,
strategy;
init_params = nothing,
phi = nothing,
param_estim = false,
additional_loss = nothing,
adaptive_loss = nothing,
logger = nothing,
log_options = LogOptions(),
iteration = nothing,
kwargs...)

A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a
`PDESystem` into an `OptimizationProblem` for the Deep Ritz method.

## Positional Arguments

* `chain`: a vector of Lux/Flux chains with a d-dimensional input and a
1-dimensional output corresponding to each of the dependent variables. Note that this
specification respects the order of the dependent variables as specified in the PDESystem.
Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
* `strategy`: determines which training strategy will be used. See the Training Strategy
documentation for more details.

## Keyword Arguments

* `init_params`: the initial parameters of the neural networks. If `init_params` is not
given, then the neural network default parameters are used. Note that for Lux, the default
will convert to Float64.
* `phi`: a trial solution, specified as `phi(x,p)` where `x` is the coordinates vector for
the dependent variable and `p` are the weights of the phi function (generally the weights
of the neural network defining `phi`). By default, this is generated from the `chain`. This
should only be used to more directly impose functional information in the training problem,
for example imposing the boundary condition by the test function formulation.
* `adaptive_loss`: the choice for the adaptive loss function. See the
[adaptive loss page](@ref adaptive_loss) for more details. Defaults to no adaptivity.
* `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural
network trial solutions, `θ` are the weights of the neural network(s), and `p_` are the
hyperparameters of the `OptimizationProblem`. If `param_estim = true`, then `θ` additionally
contains the parameters of the differential equation appended to the end of the vector.
* `param_estim`: whether the parameters of the differential equation should be included in
the values sent to the `additional_loss` function. Defaults to `false`.
* `logger`: an optional logger for recording loss values during training. Defaults to
  `nothing` (no logging).
* `log_options`: a `LogOptions` struct controlling what is logged and how often; only
  used when a `logger` is provided.
* `iteration`: a one-element vector used as a shared iteration counter for logging and
  adaptive losses. If not given, a counter is created and incremented internally.
"""
struct DeepRitz{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
chain::Any
strategy::T
init_params::P
phi::PH
derivative::DER
param_estim::PE
additional_loss::AL
adaptive_loss::ADA
logger::LOG
log_options::LogOptions
iteration::Vector{Int64}
self_increment::Bool
multioutput::Bool
kwargs::K
end

function DeepRitz(chain, strategy; kwargs...)
    pinn = NeuralPDE.PhysicsInformedNN(chain, strategy)
    DeepRitz([getfield(pinn, k) for k in propertynames(pinn)]...)
end

"""
prob = discretize(pde_system::PDESystem, discretization::DeepRitz)

For 2nd-order PDEs, transforms a symbolic description of a ModelingToolkit-defined `PDESystem`
using the Deep Ritz method and generates an `OptimizationProblem` for
[Optimization.jl](https://docs.sciml.ai/Optimization/stable/) whose solution is the solution
to the PDE.
"""
function SciMLBase.discretize(pde_system::PDESystem, discretization::DeepRitz)
    modify_deep_ritz!(pde_system)
    pinnrep = symbolic_discretize(pde_system, discretization)
    f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
        Optimization.AutoZygote())
    Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end


"""
modify_deep_ritz!(pde_system::PDESystem)

Performs the checks for the Deep Ritz method and modifies the PDE in the `pde_system` in place.
"""
function modify_deep_ritz!(pde_system::PDESystem)
    if length(pde_system.eqs) > 1
        error("Deep Ritz solves for only one dependent variable")
    end

    ind_vars = pde_system.ivs
    dep_var = pde_system.dvs[1]

    n_vars = length(ind_vars)

    expr = first(pde_system.eqs).lhs - first(pde_system.eqs).rhs

    Ds = [Differential(ind_var) for ind_var in ind_vars]
    D²s = [Differential(ind_var)^2 for ind_var in ind_vars]
    laplacian = (sum([d²s(dep_var) for d²s in D²s]) ~ 0).lhs

    expr_new = modify_laplacian(expr, laplacian, n_vars)

    rhs = -expr_new * dep_var
    lhs = (sum([(ds(dep_var))^2 for ds in Ds]) ~ 0).lhs

    pde_system.eqs[1] = lhs ~ rhs
    return nothing
end


function modify_laplacian(expr, Δ, n_vars)
    expr_new = expr - Δ
    if (operation(expr_new) != +) || (length(expr_new.dict) + n_vars == length(expr.dict))
        # positive coefficient of the Laplacian
        return expr_new
    else
        expr_new = expr + Δ
        if length(expr_new.dict) == n_vars + length(expr.dict)
            # negative coefficient of the Laplacian
            return expr_new
        else
            error("Incorrect form of PDE given")
        end
    end
end
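For context (a sketch of the rationale, not part of the diff): the Deep Ritz method replaces the strong-form residual of a Poisson-type problem with the minimization of a variational energy, which is what the symbolic rewrite above constructs. For the model problem the standard formulation is:

```latex
% Deep Ritz replaces the strong form  -\Delta u = f  on  \Omega
% by minimization of the variational (Ritz) energy:
E[u] \;=\; \int_{\Omega} \left( \tfrac{1}{2}\,\lVert \nabla u \rVert^{2} \;-\; f\,u \right) dx
```

`modify_deep_ritz!` builds the integrand symbolically: the `lhs` it assembles, the sum of squared first derivatives, is $\lVert \nabla u \rVert^{2}$, while the `rhs` multiplies the remaining (non-Laplacian) terms of the PDE by the dependent variable to supply the source term. Exact sign and scaling conventions follow the code, not this sketch.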
2 changes: 1 addition & 1 deletion src/discretize.jl
@@ -534,7 +534,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
pde_loss_functions,
bc_loss_functions)

function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
function get_likelihood_estimate_function(discretization::Union{PhysicsInformedNN, DeepRitz})
function full_loss_function(θ, p)
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
54 changes: 54 additions & 0 deletions test/deep_ritz_test.jl
@@ -0,0 +1,54 @@
using NeuralPDE, Test

using ModelingToolkit, Optimization, OptimizationOptimisers, Distributions, MethodOfLines,
OrdinaryDiffEq
import ModelingToolkit: Interval, infimum, supremum
using Lux

@testset "Poisson's equation" begin
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D PDE
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)

# Initial and boundary conditions
bcs = [u(0, y) ~ 0.0, u(1, y) ~ -sin(pi * 1) * sin(pi * y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
# Space and time domains
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]

strategy = QuasiRandomTraining(256, minibatch = 32)
hid = 40
chain_ = Lux.Chain(Lux.Dense(2, hid, Lux.σ), Lux.Dense(hid, hid, Lux.σ),
Lux.Dense(hid, 1))
discretization = DeepRitz(chain_, strategy);

@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
global iter += 1
if iter % 50 == 0
println("$iter => $l")
end
return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 500)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 200)
phi = discretization.phi

xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)

u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys],
(length(xs), length(ys)))
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
@test u_predict≈u_real atol=0.1
end
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -100,4 +100,10 @@ end
include("dgm_test.jl")
end
end

if GROUP == "All" || GROUP == "Deep-Ritz"
@time @safetestset "Deep Ritz method" begin
include("deep_ritz_test.jl")
end
end
end