Merge pull request #905 from AstitvaAggarwal/BPINN_pde
improved BPINN solvers
ChrisRackauckas authored Nov 5, 2024
2 parents 1555142 + 19c074c commit 6f4580f
Showing 7 changed files with 576 additions and 136 deletions.
18 changes: 10 additions & 8 deletions src/BPINN_ode.jl
@@ -3,7 +3,7 @@
"""
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
- phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
+ phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
@@ -86,6 +86,7 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs <: NamedTuple
@@ -100,18 +101,18 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
verbose::Bool
end

- function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
+ function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
- dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,),
- nchains = 1, init_params = nothing,
+ phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
+ MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false, autodiff = false, progress = false, verbose = false)
chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
return BNNODE(chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
- dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
+ phynewstd, dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose)
end
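For orientation, a minimal usage sketch of the updated constructor (the problem, network, and dataset below are illustrative, not part of this commit):

    using NeuralPDE, Lux, OrdinaryDiffEq

    # out-of-place ODE, as the BPINN solver requires
    f(u, p, t) = -1.5 * u
    prob = ODEProblem(f, 1.0, (0.0, 2.0))
    chain = Chain(Dense(1, 8, tanh), Dense(8, 1))

    # hypothetical dataset: [depvar values, time points]
    ts = collect(0.0:0.1:2.0)
    us = exp.(-1.5 .* ts) .+ 0.05 .* randn(length(ts))

    # phynewstd weights the new collocation loss, which is active when estim_collocate = true
    alg = BNNODE(chain; draw_samples = 1000, dataset = [us, ts], l2std = [0.05],
        phystd = [0.05], phynewstd = [0.05], estim_collocate = true)
    sol = solve(prob, alg)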

@@ -157,7 +158,7 @@ end
function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing,
timeseries_errors = true, save_everystep = true, adaptive = false,
abstol = 1.0f-6, reltol = 1.0f-3, verbose = false, saveat = 1 / 50.0,
- maxiters = nothing, numensemble = floor(Int, alg.draw_samples / 3))
+ maxiters = nothing)
(; chain, param, strategy, draw_samples, numensemble, verbose) = alg

# ahmc_bayesian_pinn_ode needs param = [] for an easier vcat operation over the full parameter vector
@@ -168,7 +169,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt

mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(
prob, chain; strategy, alg.dataset, alg.draw_samples, alg.init_params,
- alg.physdt, alg.l2std, alg.phystd, alg.priorsNNw, param, alg.nchains, alg.autodiff,
+ alg.physdt, alg.l2std, alg.phystd, alg.phynewstd,
+ alg.priorsNNw, param, alg.nchains, alg.autodiff,
Kernel = alg.kernel, alg.Adaptorkwargs, alg.Integratorkwargs,
alg.MCMCkwargs, alg.progress, alg.verbose, alg.estim_collocate)

@@ -178,7 +180,7 @@

θinit, st = LuxCore.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
- for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
+ for i in (draw_samples - numensemble):draw_samples]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only needed for size
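The new ensemble window keeps the last numensemble + 1 posterior draws rather than a fixed 10% tail; a quick check of the indexing:

    draw_samples, numensemble = 1000, 333    # numensemble defaults to draw_samples ÷ 3
    idxs = (draw_samples - numensemble):draw_samples
    first(idxs), last(idxs), length(idxs)    # (667, 1000, 334), i.e. numensemble + 1 draws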
149 changes: 136 additions & 13 deletions src/PDE_BPINN.jl
@@ -4,17 +4,90 @@
dataset <: Union{Nothing, Vector{<:Matrix{<:Real}}}
priors <: Vector{<:Distribution}
allstd::Vector{Vector{Float64}}
phynewstd::Vector{Float64}
names::Tuple
extraparams::Int
init_params <: Union{AbstractVector, NamedTuple, ComponentArray}
full_loglikelihood
L2_loss2
Φ
end

function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
# for parameter estimation it is necessary to use the multioutput case
- return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) + priorlogpdf(ltd, θ) +
- L2LossData(ltd, θ)
+ if ltd.L2_loss2 === nothing
+ return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
+ priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
+ else
+ return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
+ priorlogpdf(ltd, θ) + L2LossData(ltd, θ) + ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
+ end
end
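In Bayesian terms, the density sampled above decomposes as (schematically)

    log p(θ | data) ∝ full_loglikelihood(θ) + priorlogpdf(θ) + L2LossData(θ) [+ L2_loss2(θ, phynewstd)]

where the bracketed collocation term enters only when a Dict_differentials is supplied.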

# you get a vector of losses
function get_lossy(pinnrep, dataset, Dict_differentials)
eqs = pinnrep.eqs
depvars = pinnrep.depvars # depvar order is the same as in the dataset

# Dict_differentials is filled with Differential operator => diff_i key-value pairs
# masking operation
eqs_new = SymbolicUtils.substitute.(eqs, Ref(Dict_differentials))

to_subs, tobe_subs = get_symbols(dataset, depvars, eqs)

# for the values of all depvars at the corresponding indvar values in the dataset, create dictionaries like Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337)
# in each Dict, the symbolic form of a depvar is the key to its value at given indvar coords; n_dicts = n_rows_dataset (or n_indvar_coords_dataset)
eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars)
for i in 1:size(dataset[1][:, 1])[1]]

# for each dataset point (eq_sub dictionary), substitute into the masked equations
# n_collocated_equations = n_rows_dataset (or n_indvar_coords_dataset)
masked_colloc_equations = [[Symbolics.substitute(eq, eq_sub) for eq in eqs_new]
for eq_sub in eq_subs]
# now we have a vector of collocated equations for the dataset's depvars

# reverse dict for re-substituting values of Differential(t)(u(t)) etc
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)

# unmask Differential terms in masked_colloc_equations
colloc_equations = [Symbolics.substitute.(
masked_colloc_equation, Ref(rev_Dict_differentials))
for masked_colloc_equation in masked_colloc_equations]
# nested vector of data_pde_loss_functions (as in discretize.jl)
# each sub-vector holds the datafree_colloc_loss_function for one indvar coord of the dataset; n_subvectors = n_rows_dataset (or n_indvar_coords_dataset)
# zip each colloc equation with the args for each build_loss call per equation vector
data_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(
colloc_equation,
pinnrep.pde_indvars,
pinnrep.pde_integration_vars)]
for colloc_equation in colloc_equations]

return data_colloc_loss_functions
end

function get_symbols(dataset, depvars, eqs)
# take only values of depvars from dataset
depvar_vals = [dataset_i[:, 1] for dataset_i in dataset]
# order of pinnrep.depvars, depvar_vals and BayesianPINN.dataset must be the same
to_subs = Dict(depvars .=> depvar_vals)

numform_vars = Symbolics.get_variables.(eqs)
Eq_vars = unique(reduce(vcat, numform_vars))
# the equations' depvars in symbolic form, e.g. x(t), for use in substitute()

tobe_subs = Dict()
for a in depvars
for i in Eq_vars
expr = toexpr(i)
if (expr isa Expr) && (expr.args[1] == a)
tobe_subs[a] = i
end
end
end
# depvar symbols mapped to their symbolic form, e.g. tobe_subs = Dict{Any, Any}(:y => y(t), :x => x(t))

return to_subs, tobe_subs
end

@views function setparameters(ltd::PDELogTargetDensity, θ)
Expand Down Expand Up @@ -180,8 +253,8 @@ end
"""
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
- priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
- Adaptorkwargs = (Adaptor = StanHMCAdaptor,
+ phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
+ Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
@@ -210,6 +283,7 @@ end
each dependent variable of interest.
* `phystd`: Vector of standard deviations of the BPINN prediction against the chosen underlying PDE
equations.
* `phynewstd`: Vector of standard deviations of the new collocation loss term.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of the chosen PDE parameters' distributions, for inverse problems.
@@ -235,14 +309,53 @@
"""
function ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
- priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
- Adaptorkwargs = (Adaptor = StanHMCAdaptor,
+ phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
+ Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
- numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
+ numensemble = floor(Int, draw_samples / 3), Dict_differentials = nothing, progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)
dataset_pde, dataset_bc = discretization.dataset

newloss = if Dict_differentials isa Nothing
nothing
else
data_colloc_loss_functions = get_lossy(pinnrep, dataset_pde, Dict_differentials)
# size = number of indvar coords in dataset
# add a case for parameters present in the bcs?

train_sets_pde = get_dataset_train_points(pde_system.eqs,
dataset_pde,
pinnrep)
# j is number of indvar coords in dataset, i is number of PDE equations in system
# -1 is a placeholder, removed inside the merge_strategy_with_loglikelihood_function call (train_sets[:, 2:end])
colloc_train_sets = [[hcat([-1], train_sets_pde[i][:, j]...)
for i in eachindex(data_colloc_loss_functions[1])]
for j in eachindex(data_colloc_loss_functions)]

# create loss functions using the dataset's indvar coords as train_sets_pde, paired with each coord's datafree_colloc_loss_function
# GridTraining(0.1) is a placeholder strategy; datafree_bc_loss_function and train_sets_bc must be nothing
# the order of indvar coords matches the corresponding depvar values in the dataset passed to the get_lossy() call
pde_loss_function_points = [merge_strategy_with_loglikelihood_function(
pinnrep,
GridTraining(0.1),
data_colloc_loss_functions[i],
nothing;
train_sets_pde = colloc_train_sets[i],
train_sets_bc = nothing)[1]
for i in eachindex(data_colloc_loss_functions)]

function L2_loss2(θ, phynewstd)
# the inner sum is over the losses of all equations at the same point
# the outer sum is over all dataset points
pde_loglikelihoods = sum([sum([pde_loss_function(θ,
phynewstd[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points])
end
end
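Each pde_loss_function(θ, phynewstd[i]) above evaluates a Gaussian log-likelihood of one equation's residuals at the dataset points; the shape of that term, as a standalone sketch:

    using Distributions, LinearAlgebra

    # r: residuals of one PDE equation at the dataset points; σ: the matching phynewstd entry
    residual_loglik(r, σ) = logpdf(MvNormal(zeros(length(r)), Diagonal(fill(abs2(σ), length(r)))), r)

    residual_loglik([0.01, -0.02, 0.005], 0.05)  # near-zero residuals ⇒ higher log-likelihood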

# add overall functionality for BC dataset points (case of parametric BC) ?
if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
dataset = nothing
elseif dataset_bc isa Nothing
@@ -306,8 +419,8 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

# dimension is the total number of params; initial_nnθ is used for Lux NamedTuples
ℓπ = PDELogTargetDensity(
- nparameters, strategy, dataset, priors, [phystd, bcstd, l2std],
- names, ninv, initial_nnθ, full_weighted_loglikelihood, Φ)
+ nparameters, strategy, dataset, priors, [phystd, bcstd, l2std], phynewstd,
+ names, ninv, initial_nnθ, full_weighted_loglikelihood, newloss, Φ)

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -320,8 +433,13 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
@printf("Current Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, initial_θ), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Current SSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, initial_θ))
if !(newloss isa Nothing)
@printf("Current new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, initial_θ),
ℓπ.phynewstd))
end
end

# parallel sampling option
@@ -370,11 +488,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

if verbose
@printf("Sampling Complete.\n")
@printf("Current Physics Log-likelihood : %g\n",
@printf("Final Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Final Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Final SSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@printf("Final new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.phynewstd))
end
end

fullsolution = BPINNstats(mcmc_chain, samples, stats)
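An end-to-end sketch of the new Dict_differentials path (system, network, and data are illustrative; keyword names follow the signature above):

    using NeuralPDE, ModelingToolkit, Lux
    import ModelingToolkit: Interval

    @parameters t
    @variables x(..) diff_1
    Dt = Differential(t)

    eqs = [Dt(x(t)) ~ -x(t)]
    bcs = [x(0) ~ 1.0]
    domains = [t ∈ Interval(0.0, 2.0)]
    @named pde_system = PDESystem(eqs, bcs, domains, [t], [x(t)])

    # dataset: one matrix per depvar; first column depvar values, remaining columns indvar coords
    ts = collect(0.0:0.1:2.0)
    data = [hcat(exp.(-ts), ts)]

    chain = Chain(Dense(1, 8, tanh), Dense(8, 1))
    discretization = BayesianPINN([chain], GridTraining(0.1); dataset = [data, nothing])

    # masking dictionary: each Differential term mapped to a fresh symbol
    sol = ahmc_bayesian_pinn_pde(pde_system, discretization;
        draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
        phynewstd = [0.05], saveats = [0.05],
        Dict_differentials = Dict(Dt(x(t)) => diff_1))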
27 changes: 14 additions & 13 deletions src/advancedHMC_MCMC.jl
@@ -6,6 +6,7 @@
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
priors <: Vector{<:Distribution}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
@@ -97,7 +98,7 @@ suggested extra loss function for ODE solver case
for i in 1:length(ltd.prob.u0)
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
- Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(nnsol[i, :]))))),
+ Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
nnsol[i, :]
)
end
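With this change the extra gradient/collocation term gets its own scale: phystd still weights the main physics likelihood, while phynewstd independently tempers the new term. A standalone sketch of how the std controls that weight (illustrative numbers):

    using Distributions, LinearAlgebra

    r = [0.02, -0.01, 0.03]    # residuals of the collocation term
    ll(σ) = logpdf(MvNormal(zeros(3), Diagonal(fill(abs2(σ), 3))), r)
    ll(0.05), ll(0.5)   # smaller σ concentrates the likelihood, enforcing the residuals harder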
@@ -263,7 +264,7 @@ end
"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
- l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 2.0),
+ l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
@@ -336,6 +337,7 @@ In case you are only solving the equations for the solution, do not provide a dataset
~2/3 of draw samples)
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
* `phystd`: standard deviation of the BPINN prediction against the chosen underlying ODE system
* `phynewstd`: standard deviation of the new gradient/collocation loss term
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of the chosen ODE parameters' distributions, for inverse problems.
@@ -366,10 +368,10 @@ In case you are only solving the equations for the solution, do not provide a dataset
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
- phystd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false,
- Kernel = HMC,
- Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric,
- targetacceptancerate = 0.8),
+ phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
+ autodiff = false, Kernel = HMC,
+ Adaptorkwargs = (Adaptor = StanHMCAdaptor,
+ Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false, estim_collocate = false)
@assert !isinplace(prob) "The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."
@@ -415,16 +417,15 @@ function ahmc_bayesian_pinn_ode(
nparameters += ninv
end

- t0 = prob.tspan[1]
smodel = StatefulLuxLayer{true}(chain, nothing, st)
# dimension is the total number of params; initial_nnθ is used for Lux NamedTuples
ℓπ = LogTargetDensity(nparameters, prob, smodel, strategy, dataset, priors,
- phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
+ phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

if verbose
@printf("Current Physics Log-likelihood: %g\n", physloglikelihood(ℓπ, initial_θ))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood: %g\n",
@printf("Current SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, initial_θ))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
@@ -483,13 +484,13 @@

if verbose
println("Sampling Complete.")
@printf("Current Physics Log-likelihood: %g\n",
@printf("Final Physics Log-likelihood: %g\n",
physloglikelihood(ℓπ, samples[end]))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood: %g\n",
@printf("Final Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Final SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, samples[end]))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
@printf("Final gradient loss against dataset Log-likelihood: %g\n",
L2loss2(ℓπ, samples[end]))
end
end
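For comparison with the BNNODE wrapper earlier, a direct call of this low-level API (illustrative problem and network; keywords as in the docstring above):

    using NeuralPDE, Lux, OrdinaryDiffEq

    f(u, p, t) = cos(2π * t)    # out-of-place scalar ODE
    prob = ODEProblem(f, 0.0, (0.0, 2.0))
    chain = Chain(Dense(1, 6, tanh), Dense(6, 1))

    mcmcchain, samples, stats = ahmc_bayesian_pinn_ode(
        prob, chain; draw_samples = 1500, phystd = [0.01], phynewstd = [0.05],
        priorsNNw = (0.0, 3.0))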