Skip to content

Commit

Permalink
fix a bug to ensure Tucker-bf is anti-symmetric
Browse files Browse the repository at this point in the history
  • Loading branch information
DexuanZhou committed Nov 26, 2023
1 parent b006acd commit 40dfac0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/bflow3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function BFwf_lux(Nel::Integer, bRnl, bYlm, nuclei, TD::Tucker; totdeg = 15,
pooling_layer = ACEpsi.lux(pooling)

# P <= length(nuclei) * length(prodbasis_layer.sparsebasis)
tucker_layer = ACEpsi.TD.TuckerLayer(TD.P, Nel, length(nuclei), length(pooling.basis.prodbasis.sparsebasis))
tucker_layer = ACEpsi.TD.TuckerLayer(TD.P, length(nuclei), length(pooling.basis.prodbasis.sparsebasis))

spec1p = get_spec(TD)
# define sparse for n-correlations
Expand Down
7 changes: 3 additions & 4 deletions src/tensordecomposition/Tucker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,22 @@ end
"""
struct TuckerLayer <: AbstractExplicitLayer
P::Integer # reduced dimension
N::Integer # number of electron
M::Integer # number of nuclei
K::Integer # spec1p
@reqfields()
end

TuckerLayer(P::Integer, N::Integer, M::Integer, K::Integer) = TuckerLayer(P, N, M, K, _make_reqfields()...)
TuckerLayer(P::Integer, M::Integer, K::Integer) = TuckerLayer(P, M, K, _make_reqfields()...)

_valtype(l::TuckerLayer, x::AbstractArray, ps) = promote_type(eltype(x), eltype(ps.W))

function (l::TuckerLayer)(x::AbstractArray, ps, st)
@tullio out[i, j, p] := ps.W[i, j, p, m, k] * x[i, j, m, k]
@tullio out[i, j, p] := ps.W[j, p, m, k] * x[i, j, m, k]
ignore_derivatives() do
release!(x)
end
return out, st
end

LuxCore.initialparameters(rng::AbstractRNG, l::TuckerLayer) = ( W = randn(rng, l.N, 3, l.P, l.M, l.K), )
LuxCore.initialparameters(rng::AbstractRNG, l::TuckerLayer) = ( W = randn(rng, 3, l.P, l.M, l.K), )
LuxCore.initialstates(rng::AbstractRNG, l::TuckerLayer) = NamedTuple()
2 changes: 1 addition & 1 deletion src/vmc/multilevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)

if :TK in keys(ps.branch.bf)
ps2.branch.bf.TK.W .= 0
ps2.branch.bf.TK.W[:,:,1:size(ps.branch.bf.TK.W)[3],:,1:size(ps.branch.bf.TK.W)[5]] .= ps.branch.bf.TK.W
ps2.branch.bf.TK.W[:,1:size(ps.branch.bf.TK.W)[2],:,1:size(ps.branch.bf.TK.W)[4]] .= ps.branch.bf.TK.W
end
return ps2
end
Expand Down

0 comments on commit 40dfac0

Please sign in to comment.