Skip to content

Commit

Permalink
modifications to embedding and Tucker function
Browse files Browse the repository at this point in the history
  • Loading branch information
DexuanZhou committed Dec 1, 2023
1 parent ad516e8 commit 931f8f5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/bflow3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ function BFwf_lux(Nel::Integer, bRnl, bYlm, nuclei, TD::Tucker; totdeg = 15,
jastrow_layer = ACEpsi.lux(js)

BFwf_chain = Chain(; ϕnlm = aobasis_layer, bA = pooling_layer, TK = tucker_layer,
reshape = Lux.Parallel(nothing, (myReshapeLayer((Nel, 3 * TD.P)) for i = 1:Nel)...),
#reshape = Lux.Parallel(nothing, (myReshapeLayer((Nel, 3 * TD.P)) for i = 1:Nel)...),
bAA = Lux.Parallel(nothing, (deepcopy(corr_layer) for i = 1:Nel)...),
hidden1 = Lux.Parallel(nothing, (LinearLayer(length(corr1), 1) for i = 1:Nel)...),
l_concat = WrappedFunction(x -> hcat(x...)),
Expand Down
4 changes: 3 additions & 1 deletion src/tensordecomposition/Tucker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ _valtype(l::TuckerLayer, x::AbstractArray, ps) = promote_type(eltype(x), eltype

function (l::TuckerLayer)(x::AbstractArray, ps, st)
#@tullio out[i, j, p] := ps.W[j, p, m, k] * x[i, j, m, k]
out = ntuple(i -> (@tullio out[i, j, p] := ps.W[i, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)), l.Nel)
#out = ntuple(a -> (@tullio out[i, j, p] := ps.W[a, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)), l.Nel)
A = @tullio out[a, i, j, p] := ps.W[a, j, p, m, k] * x[i, j, m, k] (m in 1:l.M, k in 1:l.K)
out = ntuple(a -> reshape(A[a,:,:,:], l.Nel, :), l.Nel)
ignore_derivatives() do
release!(x)
end
Expand Down
25 changes: 16 additions & 9 deletions src/vmc/multilevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,29 @@ end
function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
readable_spec = displayspec(spec, spec1p)
readable_spec2 = displayspec(spec2, spec1p2)
@assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
@assert size(ps.branch.bf.hidden1.W, 2) size(ps2.branch.bf.hidden1.W, 2)
#@assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
#@assert size(ps.branch.bf.hidden1.W, 2) ≤ size(ps2.branch.bf.hidden1.W, 2)
@assert size(ps.branch.bf.hidden1.layer_1.W, 1) == size(ps2.branch.bf.hidden1.layer_1.W, 1)
@assert size(ps.branch.bf.hidden1.layer_1.W, 2) size(ps2.branch.bf.hidden1.layer_1.W, 2)

@assert all(t in readable_spec2 for t in readable_spec)
@assert all(t in specAO2 for t in specAO)

# set all parameters to zero
ps2.branch.bf.hidden1.W .= 0.0
#ps2.branch.bf.hidden1.W .= 0.0
for i in keys(ps.branch.bf.hidden1)
ps2.branch.bf.hidden1[i].W .= 0.0
end

# _map[spect] = index in readable_spec2
_map = _invmap(readable_spec2)
_mapAO = _invmapAO(specAO2)
# embed
for (idx, t) in enumerate(readable_spec)
ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
#ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
for i in keys(ps.branch.bf.hidden1)
ps2.branch.bf.hidden1[i].W[:, _map[t]] = ps.branch.bf.hidden1[i].W[:, idx]
end
end
if :ϕnlm in keys(ps.branch.bf)
if in keys(ps.branch.bf.ϕnlm)
Expand All @@ -58,7 +67,8 @@ function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)

if :TK in keys(ps.branch.bf)
ps2.branch.bf.TK.W .= 0
ps2.branch.bf.TK.W[:,1:size(ps.branch.bf.TK.W)[2],:,1:size(ps.branch.bf.TK.W)[4]] .= ps.branch.bf.TK.W
# W = randn(rng, l.Nel, 3, l.P, l.M, l.K)
ps2.branch.bf.TK.W[:,:,1:size(ps.branch.bf.TK.W)[3],:,1:size(ps.branch.bf.TK.W)[5]] .= ps.branch.bf.TK.W
end
return ps2
end
Expand Down Expand Up @@ -108,10 +118,7 @@ function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ha
end
ν = maximum(length.(spec))
if :hidden1 in keys(ps.branch.bf)
_basis_size = size(ps.branch.bf.hidden1.W, 2)
@info("level = $l, order = , size of basis = $_basis_size")
elseif :hidden1 in keys(ps.branch.bf.Pds.layer_1)
_basis_size = size(ps.branch.bf.Pds.layer_1.hidden1.W, 2)
_basis_size = size(ps.branch.bf.hidden1[1].W, 2)
@info("level = $l, order = , size of basis = $_basis_size")
else
@info("level = $l, order = ")
Expand Down

0 comments on commit 931f8f5

Please sign in to comment.