Adding Training Strategies to dae_solvers.jl #838

Closed · wants to merge 28 commits

Commits (28)
280a43d  Update dae_solve.jl (hippyhippohops, Mar 26, 2024)
911b3a3  Update dae_solve.jl (hippyhippohops, Mar 26, 2024)
9d98ae1  Update NNDAE_tests.jl (hippyhippohops, Mar 28, 2024)
35700c4  Added strategy::QuadratureTraining (hippyhippohops, Mar 28, 2024)
5773e91  Formatted indentation in strategy::WeightedIntervalTraining (hippyhippohops, Mar 28, 2024)
c9f54de  Formatted indentation in strategy::QuadratureTraining (hippyhippohops, Mar 28, 2024)
abc61bb  Refactored generate_loss in ode_solve.jl (hippyhippohops, Mar 31, 2024)
e4f06d5  Reverted ode_solve.jl to the previous code (hippyhippohops, Apr 1, 2024)
0f4cfde  Edits to dae_solve.jl and NNDAE_tests.jl (hippyhippohops, Apr 9, 2024)
911d68c  Modified dae_solve.jl and NNDAE_tests.jl (hippyhippohops, Apr 12, 2024)
2f4c505  Removed param_estim (hippyhippohops, Apr 21, 2024)
52cdea8  Update dae_solve.jl (hippyhippohops, Apr 25, 2024)
95457b9  Merge branch 'SciML:master' into patch-1 (hippyhippohops, Apr 30, 2024)
d6f2e5f  Reset the code to match master. Planning to start from scratch a… (hippyhippohops, May 3, 2024)
1ed7683  Implemented WeightedIntervalTraining and its test (hippyhippohops, May 6, 2024)
2f9db68  Formatted code (hippyhippohops, May 6, 2024)
c2453d2  Added in failed Quadrature training (hippyhippohops, May 8, 2024)
7c6c2bf  Trying to work out the quadrature training strategy (hippyhippohops, May 16, 2024)
70e0657  Stochastic training passes (hippyhippohops, May 16, 2024)
0098c6d  Updates on NNDAE_tests.jl (hippyhippohops, May 26, 2024)
3e9473e  Updates (hippyhippohops, May 26, 2024)
92ec11c  Merge branch 'SciML:master' into patch-1 (hippyhippohops, Jun 4, 2024)
b00c8cf  Removing empty line (hippyhippohops, Jun 4, 2024)
afd05ee  Merge branch 'patch-1' of https://github.com/hippyhippohops/NeuralPDE… (hippyhippohops, Jun 6, 2024)
a4e2877  Changes to quadrature training (hippyhippohops, Jun 7, 2024)
41dbf62  Added Quadrature training (hippyhippohops, Jun 10, 2024)
9c72fee  Changing to Float64 (hippyhippohops, Jul 8, 2024)
47b5aea  Merge branch 'SciML:master' into patch-1 (hippyhippohops, Jul 8, 2024)
48 changes: 48 additions & 0 deletions src/dae_solve.jl
@@ -72,6 +72,18 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
sum(abs2, loss) / length(t)
end

function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, differential_vars))
    # batched method: `ts` is the vector of quadrature nodes that the batch
    # integrator passes in, one residual evaluation per node
    integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, differential_vars)) for t in ts]
function loss(θ, _)
intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters)
sol.u
end
return loss
end
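For reference, a minimal caller-side sketch (not part of this diff) of how this branch gets dispatched; the `QuadratureTraining()` defaults and the network shape are assumptions:

# Hypothetical sketch, not part of this PR: selecting QuadratureTraining for an
# NNDAE solve dispatches to the generate_loss method above.
using NeuralPDE, Lux, OptimizationOptimisers

chain = Lux.Chain(Lux.Dense(1, 12, Lux.σ), Lux.Dense(12, 2))
opt = OptimizationOptimisers.Adam(0.01)
alg = NNDAE(chain, opt; autodiff = false,
    strategy = NeuralPDE.QuadratureTraining())
# solve(prob, alg; maxiters = 5000) then minimizes the integrated squared
# DAE residual over tspan rather than a loss on a fixed grid of points.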

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
ts = tspan[1]:(strategy.dx):tspan[2]
@@ -82,6 +94,42 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
        differential_vars::AbstractVector)
autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
function loss(θ, _)
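        # draw `strategy.points` samples uniformly from `tspan` on each evaluation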
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan,
        p, differential_vars::AbstractVector)
autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
minT = tspan[1]
maxT = tspan[2]

weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
points = strategy.points

difference = (maxT - minT) / N

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ ((index - 1) * difference)
data = append!(data, temp_data)
end

ts = data
function loss(θ, _)
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
end
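As an aside, the sampling step above can be checked in isolation; the weights, point count, and tspan below are made-up values, not taken from this PR:

# Standalone illustration of the weighted sampling loop above, with made-up
# inputs: normalized weights [0.7, 0.2, 0.1] split tspan = (0.0, 3.0) into
# three equal subintervals receiving ~70, ~20, and ~10 of the 100 points.
tspan = (0.0, 3.0)
weights = [0.7, 0.2, 0.1]
weights = weights ./ sum(weights)
points = 100
difference = (tspan[2] - tspan[1]) / length(weights)

data = Float64[]
for (index, item) in enumerate(weights)
    n = trunc(Int, points * item)
    # shift each batch of uniform samples into its own subinterval
    append!(data, rand(n) .* difference .+ tspan[1] .+ (index - 1) * difference)
end
@assert all(tspan[1] .<= data .<= tspan[2])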

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
alg::NNDAE,
args...;
143 changes: 74 additions & 69 deletions src/ode_solve.jl
@@ -219,79 +219,87 @@ end

Representation of the loss function, parametric on the training strategy `strategy`.
"""
function generate_loss(strategy, phi, f, autodiff::Bool, tspan, p,
        batch, param_estim::Bool)
    if strategy isa QuadratureTraining
        integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))
        function integrand(ts, θ)
            [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts]
        end

        function loss(θ, _)
            intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
            intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
            sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol,
                reltol = strategy.reltol, maxiters = strategy.maxiters)
            sol.u
        end
        return loss

    elseif strategy isa GridTraining
        ts = tspan[1]:(strategy.dx):tspan[2]
        autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
        function loss(θ, _)
            if batch
                inner_loss(phi, f, autodiff, ts, θ, p, param_estim)
            else
                sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
            end
        end
        return loss

    elseif strategy isa StochasticTraining
        autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
        function loss(θ, _)
            ts = adapt(parameterless_type(θ),
                [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
            if batch
                inner_loss(phi, f, autodiff, ts, θ, p, param_estim)
            else
                sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
            end
        end
        return loss

    elseif strategy isa WeightedIntervalTraining
        autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
        minT = tspan[1]
        maxT = tspan[2]

        weights = strategy.weights ./ sum(strategy.weights)

        N = length(weights)
        points = strategy.points

        difference = (maxT - minT) / N

        data = Float64[]
        for (index, item) in enumerate(weights)
            temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
                        ((index - 1) * difference)
            data = append!(data, temp_data)
        end

        ts = data
        function loss(θ, _)
            if batch
                inner_loss(phi, f, autodiff, ts, θ, p, param_estim)
            else
                sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
            end
        end
        return loss

    elseif strategy isa QuasiRandomTraining
        error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
    end
end
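A minimal caller-side sketch (assumed public API, not part of this diff): each branch above corresponds to one `strategy` choice passed to NNODE; the network size and optimizer settings below are placeholders.

# Hypothetical sketch: the strategy passed to NNODE selects one branch of
# generate_loss above.
using NeuralPDE, Lux, OptimizationOptimisers

chain = Lux.Chain(Lux.Dense(1, 16, Lux.σ), Lux.Dense(16, 1))
alg = NNODE(chain, OptimizationOptimisers.Adam(0.01);
    strategy = StochasticTraining(100))  # or GridTraining(0.01), QuadratureTraining()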

function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool)
function loss(θ, _)
@@ -304,9 +304,6 @@ function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_es
return loss
end

struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
37 changes: 37 additions & 0 deletions test/NNDAE_tests.jl
@@ -62,3 +62,40 @@ end

@test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4
end

@testset "WeightedIntervalTraining" begin
    println("WeightedIntervalTraining")
    function f(out, du, u, p, t)
        out[1] = du[1] - u[2]
        out[2] = u[1] + u[2]
    end
    p = []
    u0 = [1.0 / 4.0, 1.0 / 4.0]
    # consistent initial derivatives: du[1](0) = u[2](0) = -1/4, du[2](0) = 1/4
    du0 = [-1.0 / 4.0, 1.0 / 4.0]
    tspan = (0.0, 100000.0)
    differential_vars = [true, false]
    prob = DAEProblem(f, du0, u0, tspan; differential_vars = differential_vars)
    true_out_1(t) = exp(-t) / 4.0
    true_out_2(t) = -exp(-t) / 4.0
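    # Sanity check on the analytic solution: du[1] = -exp(-t)/4 = u[2] and
    # u[1] + u[2] = 0, so both DAE residuals vanish identically.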
    func = Lux.σ
    N = 12
    chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
        Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
    opt = OptimizationOptimisers.Adam(0.01)
    weights = [0.7, 0.2, 0.1]
    points = 200
    alg = NNDAE(chain, opt; autodiff = false,
        strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
    sol = solve(prob, alg, verbose = false, maxiters = 5000, saveat = 0.01)
    # `sol` has two components, one for u[1] and one for u[2]; accumulate the
    # total error against the analytic solution over sampled times in tspan.
    total_error = 0.0
    for t in range(tspan[1], tspan[2]; length = 1000)
        total_error += sum(abs, sol(t) .- [true_out_1(t), true_out_2(t)])
    end
    @test total_error < 0.01
end