From 931f8f5e530a60d250b5905a2c3a32a821d50571 Mon Sep 17 00:00:00 2001
From: DexuanZhou
Date: Sat, 2 Dec 2023 00:20:12 +0800
Subject: [PATCH] modifications to embedding and Tucker function

---
 src/bflow3d.jl                    |  2 +-
 src/tensordecomposition/Tucker.jl |  4 +++-
 src/vmc/multilevel.jl             | 25 ++++++++++++++++---------
 3 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/src/bflow3d.jl b/src/bflow3d.jl
index eee153e..35eb111 100644
--- a/src/bflow3d.jl
+++ b/src/bflow3d.jl
@@ -257,7 +257,7 @@ function BFwf_lux(Nel::Integer, bRnl, bYlm, nuclei, TD::Tucker; totdeg = 15,
    jastrow_layer = ACEpsi.lux(js)
 
    BFwf_chain = Chain(; ϕnlm = aobasis_layer, bA = pooling_layer, TK = tucker_layer,
-                        reshape = Lux.Parallel(nothing, (myReshapeLayer((Nel, 3 * TD.P)) for i = 1:Nel)...),
+                        #reshape = Lux.Parallel(nothing, (myReshapeLayer((Nel, 3 * TD.P)) for i = 1:Nel)...),
                         bAA = Lux.Parallel(nothing, (deepcopy(corr_layer) for i = 1:Nel)...),
                         hidden1 = Lux.Parallel(nothing, (LinearLayer(length(corr1), 1) for i = 1:Nel)...),
                         l_concat = WrappedFunction(x -> hcat(x...)),
diff --git a/src/tensordecomposition/Tucker.jl b/src/tensordecomposition/Tucker.jl
index 624ae8f..e713922 100644
--- a/src/tensordecomposition/Tucker.jl
+++ b/src/tensordecomposition/Tucker.jl
@@ -41,7 +41,9 @@ _valtype(l::TuckerLayer, x::AbstractArray, ps) = promote_type(eltype(x), eltype
 
 function (l::TuckerLayer)(x::AbstractArray, ps, st)
    #@tullio out[i, j, p] := ps.W[j, p, m, k] * x[i, j, m, k]
-   out = ntuple(i -> (@tullio out[i, j, p] := ps.W[i, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)), l.Nel)
+   #out = ntuple(a -> (@tullio out[i, j, p] := ps.W[a, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)), l.Nel)
+   A = @tullio out[a, i, j, p] := ps.W[a, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)
+   out = ntuple(a -> reshape(A[a,:,:,:], l.Nel, :), l.Nel)
    ignore_derivatives() do
       release!(x)
    end
diff --git a/src/vmc/multilevel.jl b/src/vmc/multilevel.jl
index 3490eb5..5d089ea 100644
--- a/src/vmc/multilevel.jl
+++ b/src/vmc/multilevel.jl
@@ -32,20 +32,29 @@ end
 function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
    readable_spec = displayspec(spec, spec1p)
    readable_spec2 = displayspec(spec2, spec1p2)
-   @assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
-   @assert size(ps.branch.bf.hidden1.W, 2) ≤ size(ps2.branch.bf.hidden1.W, 2)
+   #@assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
+   #@assert size(ps.branch.bf.hidden1.W, 2) ≤ size(ps2.branch.bf.hidden1.W, 2)
+   @assert size(ps.branch.bf.hidden1.layer_1.W, 1) == size(ps2.branch.bf.hidden1.layer_1.W, 1)
+   @assert size(ps.branch.bf.hidden1.layer_1.W, 2) ≤ size(ps2.branch.bf.hidden1.layer_1.W, 2)
+
    @assert all(t in readable_spec2 for t in readable_spec)
    @assert all(t in specAO2 for t in specAO)
 
    # set all parameters to zero
-   ps2.branch.bf.hidden1.W .= 0.0
+   #ps2.branch.bf.hidden1.W .= 0.0
+   for i in keys(ps.branch.bf.hidden1)
+      ps2.branch.bf.hidden1[i].W .= 0.0
+   end
 
    # _map[spect] = index in readable_spec2
    _map = _invmap(readable_spec2)
    _mapAO = _invmapAO(specAO2)
    # embed
    for (idx, t) in enumerate(readable_spec)
-      ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
+      #ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
+      for i in keys(ps.branch.bf.hidden1)
+         ps2.branch.bf.hidden1[i].W[:, _map[t]] = ps.branch.bf.hidden1[i].W[:, idx]
+      end
    end
    if :ϕnlm in keys(ps.branch.bf)
       if :ζ in keys(ps.branch.bf.ϕnlm)
@@ -58,7 +67,8 @@ function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
 
    if :TK in keys(ps.branch.bf)
       ps2.branch.bf.TK.W .= 0
-      ps2.branch.bf.TK.W[:,1:size(ps.branch.bf.TK.W)[2],:,1:size(ps.branch.bf.TK.W)[4]] .= ps.branch.bf.TK.W
+      # W = randn(rng, l.Nel, 3, l.P, l.M, l.K)
+      ps2.branch.bf.TK.W[:,:,1:size(ps.branch.bf.TK.W)[3],:,1:size(ps.branch.bf.TK.W)[5]] .= ps.branch.bf.TK.W
    end
    return ps2
 end
@@ -108,10 +118,7 @@ function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ha
       end
       ν = maximum(length.(spec))
      if :hidden1 in keys(ps.branch.bf)
-         _basis_size = size(ps.branch.bf.hidden1.W, 2)
-         @info("level = $l, order = $ν, size of basis = $_basis_size")
-      elseif :hidden1 in keys(ps.branch.bf.Pds.layer_1)
-         _basis_size = size(ps.branch.bf.Pds.layer_1.hidden1.W, 2)
+         _basis_size = size(ps.branch.bf.hidden1[1].W, 2)
          @info("level = $l, order = $ν, size of basis = $_basis_size")
       else
          @info("level = $l, order = $ν")
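
For reference, a minimal runnable sketch of the new Tucker contraction in `(l::TuckerLayer)(x, ps, st)`, assuming Tullio.jl and the weight shape stated in the added comment (`W = randn(rng, l.Nel, 3, l.P, l.M, l.K)`). The sizes and variable names below are illustrative only and are not part of the patch:

    using Tullio

    # illustrative sizes only; in the layer these are l.Nel, l.P, l.M, l.K
    Nel, P, M, K = 4, 5, 6, 7
    W = randn(Nel, 3, P, M, K)   # per-electron Tucker weights W[a, j, p, m, k]
    x = randn(Nel, 3, M, K)      # pooled basis x[i, j, m, k]

    # one contraction over (m, k) for all electrons a at once,
    # as in the patched @tullio call
    @tullio A[a, i, j, p] := W[a, j, p, m, k] * x[i, j, m, k]

    # one Nel × (3P) matrix per electron a, presumably taking over the role
    # of the myReshapeLayer that bflow3d.jl now comments out
    out = ntuple(a -> reshape(A[a, :, :, :], Nel, :), Nel)

Each element of `out` is then an Nel × 3P matrix, one per electron, which appears to match what the per-electron `bAA`/`hidden1` branches of `BFwf_chain` consume.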
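
Similarly, a minimal sketch of the per-layer embedding copy now done in `EmbeddingW!`. Since `hidden1` is a `Lux.Parallel` of per-electron `LinearLayer`s (see the context lines in the bflow3d.jl hunk), its parameters form a NamedTuple of sub-layers that are zeroed and embedded one at a time. The toy NamedTuples and the integer-keyed `_map` below are stand-ins for `ps.branch.bf.hidden1` and the spec-based `_invmap`, not the actual data structures:

    # toy stand-ins for the smaller (ps) and larger (ps2) parameter trees
    old = (layer_1 = (W = randn(1, 3),), layer_2 = (W = randn(1, 3),))
    new = (layer_1 = (W = randn(1, 5),), layer_2 = (W = randn(1, 5),))

    # stand-in for _invmap: old basis-function column => column in the larger basis
    _map = Dict(1 => 1, 2 => 3, 3 => 5)

    for i in keys(old)              # one sub-layer per electron
        new[i].W .= 0.0             # zero the larger layer first
        for idx in 1:size(old[i].W, 2)
            new[i].W[:, _map[idx]] = old[i].W[:, idx]   # embed the old columns
        end
    end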