Commit 3cac1fa (parent 286cc64): sc version of opt
jmsull committed Jul 25, 2023

Showing 1 changed file with 241 additions and 0 deletions.
scripts/ude_fwddiff_deltac_sc.jl (+241, -0)
@@ -0,0 +1,241 @@
using OrdinaryDiffEq
using Random
using LinearAlgebra, Statistics,Lux
rng = Xoshiro(123)
using Bolt
using Plots
using Optimization,SciMLSensitivity,OptimizationOptimisers,ComponentArrays,OptimizationOptimJL
using AbstractDifferentiation
import AbstractDifferentiation as AD, ForwardDiff
using SimpleChains
using AdvancedHMC,LogDensityProblems

# set up the k grid: nk log-spaced modes from the fundamental 2π/L up to k = 0.2 (we won't use all of them here; only kk[1] is used below)
L=2f3
lkmi,lkmax,nk = log10(2.0f0*π/L),log10(0.2f0),8
kk = 10.0f0.^(collect(lkmi:(lkmax-lkmi)/(nk-1):lkmax))
ℓᵧ=15
pertlen=2(ℓᵧ+1)+5
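# state layout: (ℓᵧ+1) temperature multipoles, (ℓᵧ+1) polarization multipoles, then the five scalars (Φ, δ_c, v_c, δ_b, v_b)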

# define network
width=8
# U = Lux.Chain(Lux.Dense(pertlen+1, width, tanh), #input is u,t
# Lux.Dense(width, width,tanh),
# Lux.Dense(width, 2))
# p, st = Lux.setup(rng, U)
function get_nn(m,d_in)
NN = SimpleChain(static(d_in),
TurboDense{true}(tanh, m),
TurboDense{true}(tanh, m),
TurboDense{false}(identity, 2) #have not tested non-scalar output
);
p = SimpleChains.init_params(NN;rng);
G = SimpleChains.alloc_threaded_grad(NN);
return NN,p,G
end

NN,p,G = get_nn(width,pertlen+1)
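# NN is the network and p its flat parameter vector; G is a preallocated threaded
# gradient buffer from SimpleChains (unused below, since gradients come from ForwardDiff)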


# copy the hierarchy function so it works for the nn - for this to work you need the Hierarchy_nn struct and unpack_nn in perturbations.jl
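# For reference, a minimal sketch of what unpack_nn is assumed to do (the actual
# definition lives in perturbations.jl): view the flat state vector as the two
# photon hierarchies followed by the five scalar variables.
# function unpack_nn(u, hierarchy)
#     ℓᵧ = hierarchy.ℓᵧ
#     Θ  = OffsetVector(view(u, 1:(ℓᵧ+1)), 0:ℓᵧ)          # temperature multipoles Θ[0:ℓᵧ]
#     Θᵖ = OffsetVector(view(u, (ℓᵧ+2):2(ℓᵧ+1)), 0:ℓᵧ)    # polarization multipoles Θᵖ[0:ℓᵧ]
#     Φ, δ_c, v_c, δ_b, v_b = u[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5]
#     return Θ, Θᵖ, Φ, δ_c, v_c, δ_b, v_b
# end
# Similarly, Hierarchy_nn is assumed to hold, in constructor order:
# integrator, par, bg, ih, k, p (the network parameters), ℓᵧ.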

function hierarchy_nn!(du, u, hierarchy::Hierarchy_nn{T, BasicNewtonian}, x) where T
    # compute cosmological quantities at time x, and do some unpacking
    k, ℓᵧ, par, bg, ih = hierarchy.k, hierarchy.ℓᵧ, hierarchy.par, hierarchy.bg, hierarchy.ih
    Ω_r, Ω_b, Ω_c, H₀² = par.Ω_r, par.Ω_b, par.Ω_c, bg.H₀^2
    ℋₓ, ℋₓ′, ηₓ, τₓ′, τₓ′′ = bg.ℋ(x), bg.ℋ′(x), bg.η(x), ih.τ′(x), ih.τ′′(x)
    a = x2a(x)
    R = 4Ω_r / (3Ω_b * a)
    csb² = ih.csb²(x)

    # the new free cdm tilt index (enters the Φ′ source term below)
    α_c = par.α_c
    # get the nn params
    p_nn = hierarchy.p

    Θ, Θᵖ, Φ, δ_c, v_c, δ_b, v_b = unpack_nn(u, hierarchy)
    Θ′, Θᵖ′, _, _, _, _, _ = unpack_nn(du, hierarchy)

    # Here I am throwing away the neutrinos entirely, which is probably bad
    Ψ = -Φ - 12H₀² / k^2 / a^2 * (Ω_r * Θ[2]
        # + Ω_c * a^(4+α_c) * σ_c #ignore this
        )

    Φ′ = Ψ - k^2 / (3ℋₓ^2) * Φ + H₀² / (2ℋₓ^2) * (
        Ω_c * a^(2+α_c) * δ_c
        + Ω_b * a^(-1) * δ_b
        + 4Ω_r * a^(-2) * Θ[0]
        )

    # matter: the network output replaces the usual (δ_c′, v_c′) equations
    nnin = hcat([u..., x])
    # û = U(nnin,p_nn,st)[1]
    # û = sc_noloss(nnin,p_nn)' .* std(true_u′,dims=1) .+ mean(true_u′,dims=1)
    û = NN(nnin, p_nn)'
    δ′ = û[1]
    v′ = û[2]
    # here we implicitly assume σ_c = 0

    δ_b′ = k / ℋₓ * v_b - 3Φ′
    v_b′ = -v_b - k / ℋₓ * (Ψ + csb² * δ_b) + τₓ′ * R * (3Θ[1] + v_b)
    # photons
    Π = Θ[2] + Θᵖ[2] + Θᵖ[0]
    Θ′[0] = -k / ℋₓ * Θ[1] - Φ′
    Θ′[1] = k / (3ℋₓ) * Θ[0] - 2k / (3ℋₓ) * Θ[2] + k / (3ℋₓ) * Ψ + τₓ′ * (Θ[1] + v_b/3)
    for ℓ in 2:(ℓᵧ-1)
        Θ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ-1] -
            (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θ[ℓ+1] + τₓ′ * (Θ[ℓ] - Π * Bolt.δ_kron(ℓ, 2) / 10)
    end
    Θᵖ′[0] = -k / ℋₓ * Θᵖ[1] + τₓ′ * (Θᵖ[0] - Π / 2)
    for ℓ in 1:(ℓᵧ-1)
        Θᵖ′[ℓ] = ℓ * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ-1] -
            (ℓ+1) * k / ((2ℓ+1) * ℋₓ) * Θᵖ[ℓ+1] + τₓ′ * (Θᵖ[ℓ] - Π * Bolt.δ_kron(ℓ, 2) / 10)
    end
    # truncate both photon hierarchies at ℓᵧ
    Θ′[ℓᵧ] = k / ℋₓ * Θ[ℓᵧ-1] - ((ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′) * Θ[ℓᵧ]
    Θᵖ′[ℓᵧ] = k / ℋₓ * Θᵖ[ℓᵧ-1] - ((ℓᵧ + 1) / (ℋₓ * ηₓ) - τₓ′) * Θᵖ[ℓᵧ]
    du[2(ℓᵧ+1)+1:2(ℓᵧ+1)+5] .= Φ′, δ′, v′, δ_b′, v_b′
    return nothing
end

# use only the longest-wavelength k mode, kk[1]
function hierarchy_nnu!(du, u, p, x)
    # rebuild the Hierarchy_nn wrapper each call so the solver treats p
    # (the network weights) as the ODE parameters
    hierarchy = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1], p, 15)
    hierarchy_nn!(du, u, hierarchy, x)
end



# some setup
tspan = (-20.0f0, 0.0f0)
𝕡test = CosmoParams(Ω_c=0.3,α_c=-3.0);
bgtest = Background(𝕡test; x_grid=-20.0f0:1f-1:0.0f0);
ihtest = Bolt.get_saha_ih(𝕡test, bgtest);
hierarchytest = Hierarchy(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],15);
hierarchytestnn = Hierarchy_nn(BasicNewtonian(), 𝕡test, bgtest, ihtest, kk[1],p,15);
u0 = Bolt.initial_conditions_nn(tspan[1],hierarchytestnn)

# quick check that the network evaluates on the initial state
NN(hcat([u0..., -20.f0]), p)
# NN([-20.f0],p)

# problem for truth and one we will remake
prob_trueode = ODEProblem(Bolt.hierarchy!, u0, tspan,hierarchytest)
# prob_nn = ODEProblem(hierarchy_nnu!, u0, (bgtest.x_grid[1],bgtest.x_grid[end]), ComponentArray{Float64}(p))

# Generate some noisy data (at reduced solver tolerance)
ode_data = Array(solve(prob_trueode, KenCarp4(), saveat = bgtest.x_grid,
abstol = 1f-3, reltol = 1f-3))
δ_true, v_true = ode_data[end-3, :], ode_data[end-2, :]  # the δ_c and v_c rows of the state
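# multiplicative noise: perturb δ and v by 10% of their local values;
# noise_both holds the matching per-point σ used to weight the loss below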
σfakeode = 0.1f0
noise_fakeode = δ_true .* randn(rng,size(δ_true)).*σfakeode
noise_fakeode_v = v_true .* randn(rng,size(v_true)).*σfakeode
Ytrain_ode = hcat([δ_true .+ noise_fakeode,v_true .+ noise_fakeode_v]...)
noise_both_old = Float32.(hcat([noise_fakeode,noise_fakeode_v]...))
noise_both = Float32.(hcat([δ_true.*σfakeode,v_true.*σfakeode]...))

#float conversion
fl_xgrid = Float32.(bgtest.x_grid)
fu0 = Float32.(u0)
prob_nn = ODEProblem(hierarchy_nnu!, fu0, (fl_xgrid[1],fl_xgrid[end]), p)

function predict(θ, T = fl_xgrid)
    _prob = remake(prob_nn, u0 = fu0, tspan = (T[1], T[end]), p = θ)
    res = Array(solve(_prob, KenCarp4(), saveat = T,
                abstol = 1f-3, reltol = 1f-3))
    return hcat(res[end-3, :], res[end-2, :])  # columns are (δ_c, v_c)
end

# log of the noise-weighted squared error (a log-χ² loss)
function loss(θ)
    X̂ = predict(θ)
    log(sum(abs2, (Ytrain_ode - X̂) ./ noise_both))
end


# adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
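# forward-mode AD pushes dual numbers through the entire ODE solve; workable for
# the few hundred network parameters here, but part of why iterations are slow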
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf,p)



# Training
function make_gif_plot(prediction, iter)
    p = Plots.plot()
    Plots.scatter!(p, bgtest.x_grid, Ytrain_ode[:,1], label="data")
    Plots.plot!(p, bgtest.x_grid, prediction, lw=2, label="iter = $(iter)")
    Plots.xlabel!(p, raw"$\log(a)$")
    Plots.ylabel!(p, raw"$\delta_{c}(a)$")
    savefig("../plots/learning_v1_multnoise$(σfakeode)_Adam$(iter)_s123_sc_chisq_to.png")
    p
end

losses = [];
pp = [];
callback = function (p, l)
    push!(losses, l)
    push!(pp, p)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
        make_gif_plot(predict(p)[:,1], length(losses))
    end
    return false
end
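# note: the (p, l) callback signature matches older Optimization.jl releases;
# newer versions pass an optimization state as the first argument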
niter_1, niter_2, niter_3 = 50, 50, 20
η_1, η_2, η_3 = 1.f0, 0.1f0, 0.01f0
# loss doesn't go down very much at all with η_1 = 1f-3

res1 = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = niter_1)
# this is pretty slow, on the order of half an hour, but it is at least running!
#Current loss after 50 iterations: 3.8709838260084415
test_predict_o1 = predict(res1.u)

optprob2 = remake(optprob,u0 = res1.u);
res2 = Optimization.solve(optprob2, ADAM(η_2), callback = callback, maxiters = niter_2)
test_predict_o2 = predict(res2.u)
#Current loss after 100 iterations: 1.1319215191941459

optprob3 = remake(optprob,u0 = res2.u);
res3 = Optimization.solve(optprob3, ADAM(η_3), callback = callback, maxiters = niter_3)
test_predict_o3 = predict(res3.u)
#Current loss after 120 iterations: 1.0862281116475199

# FIXME: try a more involved ADAM schedule
# for the heck of it try BFGS, not sure what parameters it usually takes
optprob4 = remake(optprob3, u0 = res3.u);
# res4 = Optimization.solve(optprob4, BFGS(), callback = callback, maxiters = 10)
res4 = Optimization.solve(optprob4, ADAM(η_3/1000), callback = callback,
    maxiters = 100)
test_predict_o4 = predict(res4.u)
# wow, BFGS is slow... I guess due to the Hessian approximation?
# We already supply the gradient, so it is unclear why it takes so much longer than ADAM.
# The loss actually goes up as well - maybe we overshoot; can try again with a specified initial stepsize...
pc = res3.u  # parameter vector used in the "opt" curves below (the stage-3 ADAM result)
# Plots of the learned perturbations
Plots.scatter(bgtest.x_grid,Ytrain_ode[:,1],label="data")
Plots.plot!(bgtest.x_grid,δ_true,label="truth",lw=2.5)#,yscale=:log10)
Plots.plot!(bgtest.x_grid,predict(pc)[:,1],label="opt-v1",lw=2.5,ls=:dash)
Plots.title!(raw"$\delta_{c}$")
Plots.xlabel!(raw"$\log(a)$")
Plots.ylabel!(raw"$\delta_{c}(a)$")
savefig("../plots/deltac_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png")


# It *SEEMS LIKE* the model isn't flexible enough - i.e. when I shift to the
# weighted loss from the square loss, the late exponential part gets better
# while the early part gets worse.
# For the earlier SC training, regularization helped a bit, but not much...

Plots.scatter(bgtest.x_grid,Ytrain_ode[:,2],label="data")#,legend=:bottomright)
Plots.plot!(bgtest.x_grid,v_true,label="truth")
Plots.plot!(bgtest.x_grid,test_predict_o1[:,2],label="opt-v1")
Plots.plot!(bgtest.x_grid,test_predict_o4[:,2],label="opt-v1-full",lw=2.5,ls=:dot)
Plots.plot!(bgtest.x_grid,predict(pc)[:,2],label="opt-pc",lw=2.5,ls=:dash)
Plots.title!(raw"$v_{c}$")
Plots.xlabel!(raw"$\log(a)$")
Plots.ylabel!(raw"$v_{c}(a)$")
savefig("../plots/vc_learning_v1_multnoise$(σfakeode)_Adam$(niter_1)_$(niter_2)_$(niter_3)_$(η_1)_$(η_2)_$(η_3)_bfgs.png")
