diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 076759e..0cc9b80 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -18,8 +18,8 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.0'
           - '1.8'
+          - '1.9'
           - 'nightly'
         os:
           - ubuntu-latest
diff --git a/Project.toml b/Project.toml
index a8e4139..64e0790 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,25 +5,37 @@ version = "0.0.1"
 
 [deps]
 ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
-ACEcore = "44c1e890-45d1-48ea-94d6-c2ea5b573f71"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
+IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d"
+Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ACEcore = "0.0.1"
+ACEbase = "0.4.2"
 BenchmarkTools = "1"
 ForwardDiff = "0.10"
-ObjectPools = "0.0.2"
-Polynomials4ML = "0.0.2"
-julia = "1"
 JSON = "0.21"
-ACEbase = "0.2"
-
+ObjectPools = "0.3.1"
+julia = "1"
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
diff --git a/benchmarks/benchmark_bflow_lux.jl b/benchmarks/benchmark_bflow_lux.jl
new file mode 100644
index 0000000..a7c0002
--- /dev/null
+++ b/benchmarks/benchmark_bflow_lux.jl
@@ -0,0 +1,58 @@
+using ACEpsi, Polynomials4ML, StaticArrays, Test
+using Polynomials4ML: natural_indices, degree, SparseProduct
+using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate
+using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow
+using ACEpsi.vmc: gradient, laplacian, grad_params
+using ACEbase.Testing: print_tf, fdtest
+using LuxCore
+using Lux
+using Zygote
+using Optimisers # mainly for the destructure(ps) function
+using Random
+using Printf
+using LinearAlgebra
+using BenchmarkTools
+using HyperDualNumbers: Hyper
+
+Rnldegree = n1 = 2
+Ylmdegree = 3
+totdegree = 3
+Nel = 2
+X = randn(SVector{3, Float64}, Nel)
+Σ = rand(spins(), Nel)
+nuclei = [ Nuc(3 * rand(SVector{3, Float64}), 1.0) for _=1:3 ]
+
+# wrap it as HyperDualNumbers
+x2dualwrtj(x, j) = SVector{3}([Hyper(x[i], i == j, i == j, 0) for i = 1:3])
+hX = [x2dualwrtj(x, 0) for x in X]
+hX[1] = x2dualwrtj(X[1], 1) # test eval for grad wrt x coord of first elec
+
+##
+
+# Defining AtomicOrbitalsBasis
+n2 = 2
+Pn = Polynomials4ML.legendre_basis(n1+1)
+spec = [(n1 = n1, n2 = n2, l = l) for n1 = 1:n1 for n2 = 1:n2 for l = 0:n1-1]
+ζ = rand(length(spec))
+Dn = GaussianBasis(ζ)
+bRnl = AtomicOrbitalsRadials(Pn, Dn, spec)
+bYlm = RYlmBasis(Ylmdegree)
+
+# setup state
+BFwf_chain, spec, spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = 2)
+ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ)
+
+##
+
+@info("Test evaluate")
+
+@btime BFwf_chain($X, $ps, $st)
+@btime gradient($BFwf_chain, $X, $ps, $st)
+@btime laplacian($BFwf_chain, $X, $ps, $st)
+
+
+@profview let BFwf_chain = BFwf_chain, X = X, ps = ps, st = st
+    for i = 1:10_000
+        BFwf_chain(X, ps, st)
+    end
+end
\ No newline at end of file
diff --git a/examples/3D/Be.jl b/examples/3D/Be.jl
new file mode 100644
index 0000000..a8e184a
--- /dev/null
+++ b/examples/3D/Be.jl
@@ -0,0 +1,63 @@
+using ACEpsi, Polynomials4ML, StaticArrays, Test
+using Polynomials4ML: natural_indices, degree, SparseProduct
+using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate
+using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow, displayspec
+using ACEpsi.vmc: gradient, laplacian, grad_params, SumH, MHSampler, VMC, gd_GradientByVMC
+using ACEbase.Testing: print_tf, fdtest
+using LuxCore
+using Lux
+using Zygote
+using Optimisers # mainly for the destructure(ps) function
+using Random
+using Statistics: mean # needed by the running average of the energy trace below
+using Printf
+using LinearAlgebra
+using BenchmarkTools
+using HyperDualNumbers: Hyper
+
+n1 = Rnldegree = 5
+Ylmdegree = 2
+totdegree = 5
+Nel = 4
+X = randn(SVector{3, Float64}, Nel)
+Σ = [↑,↑,↓,↓]
+nuclei = [ Nuc(zeros(SVector{3, Float64}), Nel * 1.0)]
+##
+
+# Defining AtomicOrbitalsBasis
+n2 = 1
+Pn = Polynomials4ML.legendre_basis(n1+1)
+spec = [(n1 = n1, n2 = n2, l = l) for n1 = 1:n1 for n2 = 1:n2 for l = 0:n1-1]
+ζ = 10 * rand(length(spec))
+Dn = SlaterBasis(ζ)
+bRnl = AtomicOrbitalsRadials(Pn, Dn, spec)
+bYlm = RYlmBasis(Ylmdegree)
+
+# setup state
+wf, spec, spec1p = BFwf_chain, spec, spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = 2)
+displayspec(spec, spec1p)
+
+ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ)
+p, = destructure(ps)
+length(p)
+
+K(wf, X::AbstractVector, ps, st) = -0.5 * laplacian(wf, X, ps, st)
+Vext(wf, X::AbstractVector, ps, st) = -sum(nuclei[i].charge/norm(nuclei[i].rr - X[j]) for i = 1:length(nuclei) for j in 1:length(X))
+Vee(wf, X::AbstractVector, ps, st) = sum(1/norm(X[i]-X[j]) for i = 1:length(X)-1 for j = i+1:length(X))
+
+ham = SumH(nuclei)
+sam = MHSampler(wf, Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
+
+opt_vmc = VMC(3000, 0.1, ACEpsi.vmc.adamW(), lr_dc = 100)
+wf, err_opt, ps = gd_GradientByVMC(opt_vmc, sam, ham, wf, ps, st)
+
+err = err_opt
+per = 0.2
+err1 = zero(err)
+for i in eachindex(err)
+    err1[i] = mean(err[ceil(Int, i - per * i):i]) # average over the last 20% of the trace
+end
+err1
+
+Eref = -14.667
+
diff --git a/examples/3D/He_multi.jl b/examples/3D/He_multi.jl
new file mode 100644
index 0000000..5048a5e
--- /dev/null
+++ b/examples/3D/He_multi.jl
@@ -0,0 +1,45 @@
+using ACEpsi, Polynomials4ML, StaticArrays, Test
+using Polynomials4ML: natural_indices, degree, SparseProduct
+using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate
+using ACEpsi.vmc: gradient, laplacian, grad_params, EmbeddingW!, _invmap, VMC_multilevel, wf_multilevel, VMC, gd_GradientByVMC, gd_GradientByVMC_multilevel, AdamW, SR, SumH, MHSampler
+using ACEbase.Testing: print_tf, fdtest
+using LuxCore
+using Lux
+using Zygote
+using Optimisers # mainly for the destructure(ps) function
+using Random
+using Printf
+using LinearAlgebra
+using BenchmarkTools
+using HyperDualNumbers: Hyper
+
+
+# Define He model
+Nel = 2
+X = randn(SVector{3, Float64}, Nel)
+Σ = [↑,↓]
+nuclei = [ Nuc(zeros(SVector{3, Float64}), 2.0)]
+
+K(wf, X::AbstractVector, ps, st) = -0.5 * laplacian(wf, X, ps, st)
+Vext(wf, X::AbstractVector, ps, st) = -sum(nuclei[i].charge/norm(nuclei[i].rr - X[j]) for i = 1:length(nuclei) for j in 1:length(X))
+Vee(wf, X::AbstractVector, ps, st) = sum(1/norm(X[i]-X[j]) for i = 1:length(X)-1 for j = i+1:length(X))
+
+ham = SumH(nuclei)
+
+# Defining Multilevel
+Rnldegree = [4, 6, 6, 7]
+Ylmdegree = [2, 2, 3, 4]
+totdegree = [2, 3, 3, 4]
+n2 = [1, 1, 2, 2]
+ν = [1, 1, 2, 2]
+MaxIters = [3, 3, 3, 3]
+##
+
+#
+wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Rnldegree, Ylmdegree, totdegree, n2, ν)
+
+sam = MHSampler(wf_list[1], Nel, nuclei, Δt = 0.5, burnin = 1, nchains = 20)
+opt_vmc = VMC_multilevel(MaxIters, 0.0015, SR(1e-4, 0.015), lr_dc = 50.0)
+
+wf, err_opt, ps = gd_GradientByVMC_multilevel(opt_vmc, sam, ham, wf_list, ps_list, st_list, spec_list, spec1p_list, specAO_list)
+
diff --git a/examples/3D/ccpvdz_H10.jl b/examples/3D/ccpvdz_H10.jl
new file mode 100644
index 0000000..1b1c91b
--- /dev/null
+++ b/examples/3D/ccpvdz_H10.jl
@@ -0,0 +1,80 @@
+using ACEpsi, Polynomials4ML, StaticArrays, Test
+using Polynomials4ML: natural_indices, degree, SparseProduct
+using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate
+using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow, displayspec
+using ACEpsi.vmc: gradient, laplacian, grad_params, SumH, MHSampler, VMC, gd_GradientByVMC
+using ACEbase.Testing: print_tf, fdtest
+using LuxCore
+using Lux
+using Zygote
+using Optimisers
+using Random
+using Printf
+using LinearAlgebra
+using BenchmarkTools
+using HyperDualNumbers: Hyper
+
+n1 = Rnldegree = 2
+Ylmdegree = 1
+totdegree = 30
+Nel = 10
+X = randn(SVector{3, Float64}, Nel)
+Σ = [↑,↑,↑,↑,↑,↓,↓,↓,↓,↓]
+spacing = 1.0
+nuclei = [Nuc(SVector(0.0,0.0,(i-1/2-Nel/2) * spacing), 1.0) for i = 1:Nel]
+Pn = Polynomials4ML.legendre_basis(n1)
+spec = [(n1 = 1, n2 = 1, l = 0), (n1 = 1, n2 = 2, l = 0), (n1 = 2, n2 = 1, l = 1)]
+
+# Ref: http://www.grant-hill.group.shef.ac.uk/ccrepo/hydrogen/hbasis.php
+# (4s,1p) -> [2s,1p]
+# H S
+# 1.301000E+01 1.968500E-02 0.000000E+00
+# 1.962000E+00 1.379770E-01 0.000000E+00
+# 4.446000E-01 4.781480E-01 0.000000E+00
+# 1.220000E-01 5.012400E-01 1.000000E+00
+# H P
+# 7.270000E-01 1.0000000
+
+ζ = [[1.301000E+01, 1.962000E+00, 4.446000E-01, 1.220000E-01], [1.220000E-01], [7.270000E-01]]
+D = [[1.968500E-02, 1.379770E-01, 4.781480E-01, 5.012400E-01], [1.0000000], [1.0000000]]
+D[1] = [(2 * ζ[1][i]/pi)^(3/4) * D[1][i] for i = 1:length(ζ[1])] * sqrt(2) * 2 * sqrt(pi)
+D[2] = [(2 * ζ[2][i]/pi)^(3/4) * D[2][i] for i = 1:length(ζ[2])] * sqrt(2) * 2 * sqrt(pi)
+
+Dn = STO_NG((ζ, D))
+bRnl = AtomicOrbitalsRadials(Pn, Dn, spec)
+bYlm = RYlmBasis(Ylmdegree)
+
+ord = 1
+wf, spec, spec1p = BFwf_chain, spec, spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = ord)
+
+ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ)
+p, = destructure(ps)
+length(p)
+wf(X, ps, st)
+@profview begin for i = 1:100000 wf(X, ps, st) end end
+
+ham = SumH(nuclei)
+sam = MHSampler(wf, Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
+
+#using BenchmarkTools # N = 10, nuc = 10 # 50, 6325, 408425
+#@btime $wf($X, $ps, $st) # ord = 1: 22.764 μs (17 allocations: 19.67 KiB)
+                          # ord = 2: 119.143 μs (17 allocations: 19.67 KiB)
+                          # ord = 3: 9.800 ms (17 allocations: 19.67 KiB)
+
+#@btime $gradient($wf, $X, $ps, $st) # ord = 1: 81.785 μs (165 allocations: 153.27 KiB)
+                                     # ord = 2: 742.233 μs (167 allocations: 1.11 MiB)
+                                     # ord = 3: 60.314 ms (167 allocations: 62.47 MiB)
+
+#@btime $grad_params($wf, $X, $ps, $st) # ord = 1: 74.807 μs (142 allocations: 94.42 KiB)
+                                        # ord = 2: 732.846 μs (144 allocations: 1.05 MiB)
+                                        # ord = 3: 59.758 ms (144 allocations: 62.41 MiB)
+
+#@btime $laplacian($wf, $X, $ps, $st) # ord = 1: 1.215 ms (611 allocations: 2.20 MiB)
+                                      # ord = 2: 25.875 ms (611 allocations: 2.20 MiB)
+                                      # ord = 3: 2.648 s (611 allocations: 2.20 MiB)
+
+opt_vmc = VMC(5000, 0.015, ACEpsi.vmc.adamW(); lr_dc = 300.0)
+#wf, err_opt, ps = gd_GradientByVMC(opt_vmc, sam, ham, wf, ps, st)
+
+## MRCI+Q: -23.5092
+## UHF: -23.2997
diff --git a/examples/3D/sto-6g_H10.jl b/examples/3D/sto-6g_H10.jl
new file mode 100644
index 0000000..90c21c6
--- /dev/null
+++ b/examples/3D/sto-6g_H10.jl
@@ -0,0 +1,94 @@
+using ACEpsi, Polynomials4ML, StaticArrays, Test
+using Polynomials4ML: natural_indices, degree, SparseProduct
+using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate
+using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow, displayspec
+using ACEpsi.vmc: gradient, laplacian, grad_params, SumH, MHSampler, VMC, gd_GradientByVMC
+using ACEbase.Testing: print_tf, fdtest
+using LuxCore
+using Lux
+using Zygote
+using Optimisers
+using Random
+using Printf
+using LinearAlgebra
+using BenchmarkTools
+using HyperDualNumbers: Hyper
+
+n1 = Rnldegree = 1
+Ylmdegree = 0
+totdegree = 20
+Nel = 10
+X = randn(SVector{3, Float64}, Nel)
+Σ = [↑,↑,↑,↑,↑,↓,↓,↓,↓,↓]
+spacing = 1.0
+nuclei = [Nuc(SVector(0.0,0.0,(i-1/2-Nel/2) * spacing), 1.0) for i = 1:Nel]
+Pn = Polynomials4ML.legendre_basis(n1+1)
+spec = [(n1 = 1, n2 = 1, l = 0)]
+
+# Ref: https://link.springer.com/book/10.1007/978-90-481-3862-3: P235
+# STO: 0.7790 * e^(-1.24 * r)
+# ϕ_1s(1, r) = \sum_(k = 1)^K d_1s,k g_1s(α_1k, r)
+# g_1s(α, r) = (2α/π)^(3/4) * exp(-αr^2): α ∼ ζ, g ∼ D
+# sto-3g: Ref: https://www.basissetexchange.org/
+# BASIS SET: (3s) -> [1s]
+# H S
+# 0.3425250914E+01 0.1543289673E+00
+# 0.6239137298E+00 0.5353281423E+00
+# 0.1688554040E+00 0.4446345422E+00
+
+# ζ = [[0.3425250914E+01, 0.6239137298E+00, 0.1688554040E+00]]
+# D = [[0.1543289673E+00, 0.5353281423E+00, 0.4446345422E+00]]
+# D[1] = [(2 * ζ[1][i]/pi)^(3/4) * D[1][i] for i = 1:length(ζ[1])]
+# P_0(x) = 1/sqrt(2)
+# Y_0(x) = 1/(2*sqrt(pi))
+
+# sto-6g: Ref: https://www.basissetexchange.org/
+# BASIS SET: (6s) -> [1s]
+# H S
+# 0.3552322122E+02 0.9163596281E-02
+# 0.6513143725E+01 0.4936149294E-01
+# 0.1822142904E+01 0.1685383049E+00
+# 0.6259552659E+00 0.3705627997E+00
+# 0.2430767471E+00 0.4164915298E+00
+# 0.1001124280E+00 0.1303340841E+00
+
+ζ = [[0.3552322122E+02, 0.6513143725E+01, 0.1822142904E+01,0.6259552659E+00, 0.2430767471E+00, 0.1001124280E+00]]
+D = [[0.9163596281E-02, 0.4936149294E-01,0.1685383049E+00,0.3705627997E+00, 0.4164915298E+00, 0.1303340841E+00]]
+D[1] = [(2 * ζ[1][i]/pi)^(3/4) * D[1][i] for i = 1:length(ζ[1])] * sqrt(2) * 2 * sqrt(pi)
+
+Dn = STO_NG((ζ, D))
+bRnl = AtomicOrbitalsRadials(Pn, Dn, spec)
+bYlm = RYlmBasis(Ylmdegree)
+
+ord = 3
+wf, spec, spec1p = BFwf_chain, spec, spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = ord)
+
+ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ) # ps.hidden1.W: Nels * basis
+p, = destructure(ps)
+length(p)
+
+ham = SumH(nuclei)
+sam = MHSampler(wf, Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
+
+#using BenchmarkTools # N = 10, nuc = 10 # 10, 265, 3685
+#@btime $wf($X, $ps, $st) # ord = 1: 17.197 μs (17 allocations: 6.86 KiB)
+                          # ord = 2: 23.086 μs (17 allocations: 6.86 KiB)
+                          # ord = 3: 78.264 μs (17 allocations: 6.86 KiB)
+
+#@btime $gradient($wf, $X, $ps, $st) # ord = 1: 56.969 μs (165 allocations: 100.73 KiB)
+                                     # ord = 2: 75.596 μs (167 allocations: 143.98 KiB)
+                                     # ord = 3: 889.464 μs (167 allocations: 681.77 KiB)
+
+#@btime $grad_params($wf, $X, $ps, $st) # ord = 1: 48.483 μs (142 allocations: 41.89 KiB)
+                                        # ord = 2: 67.085 μs (144 allocations: 81.73 KiB)
+                                        # ord = 3: 887.024 μs (144 allocations: 616.11 KiB)
+
+#@btime $laplacian($wf, $X, $ps, $st) # ord = 1: 738.616 μs (581 allocations: 720.28 KiB)
+                                      # ord = 2: 1.747 ms (581 allocations: 720.28 KiB)
+                                      # ord = 3: 20.752 ms (581 allocations: 720.28 KiB)
+
+opt_vmc = VMC(5000, 0.015, ACEpsi.vmc.adamW(); lr_dc = 100.0)
+wf, err_opt, ps = gd_GradientByVMC(opt_vmc, sam, ham, wf, ps, st)
+
+## FCI: -23.1140: ord = 2: -23.3829
+## UHF: -23.0414: ord = 1: -23.0432
diff --git a/profile/profile_bflow1.jl b/profile/profile_bflow1.jl
index f3c3f42..2673150 100644
--- a/profile/profile_bflow1.jl
+++ b/profile/profile_bflow1.jl
@@ -1,6 +1,6 @@
 
 
-using Polynomials4ML, ACEcore, ACEpsi, ACEbase, Printf
+using Polynomials4ML, ACEpsi, ACEbase, Printf
 using ACEpsi: BFwf, gradient, evaluate, laplacian
 using LinearAlgebra
 using BenchmarkTools
diff --git a/src/ACEpsi.jl b/src/ACEpsi.jl
index 6d3c2e5..73b6440 100644
--- a/src/ACEpsi.jl
+++ b/src/ACEpsi.jl
@@ -1,6 +1,26 @@
 module ACEpsi
 
+# define operation on HyperDualNumbers
+include("hyper.jl")
+
+# define spin symbols and some basic functionality
+include("spins.jl")
+
+# the old 1d backflow code, keep around for now...
include("bflow.jl") include("envelope.jl") +# the new 3d backflow code +include("atomicorbitals/atomicorbitals.jl") +include("jastrow.jl") +include("bflow3d.jl") + +include("backflowpooling.jl") + +# lux utils for bflow +include("lux_utils.jl") + +# vmc +include("vmc/opt.jl") + end diff --git a/src/atomicorbitals/atomicorbitals.jl b/src/atomicorbitals/atomicorbitals.jl new file mode 100644 index 0000000..64c5b13 --- /dev/null +++ b/src/atomicorbitals/atomicorbitals.jl @@ -0,0 +1,9 @@ +module AtomicOrbitals + +import Polynomials4ML: degree + +include("atorbbasis.jl") +include("rnlexample.jl") +include("productbasis.jl") + +end \ No newline at end of file diff --git a/src/atomicorbitals/atorbbasis.jl b/src/atomicorbitals/atorbbasis.jl new file mode 100644 index 0000000..320f15f --- /dev/null +++ b/src/atomicorbitals/atorbbasis.jl @@ -0,0 +1,276 @@ +import Polynomials4ML: evaluate + +using ACEpsi: ↑, ↓, ∅, spins, extspins, Spin, spin2idx, idx2spin +using Polynomials4ML: SparseProduct, _make_reqfields +using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer +using Random: AbstractRNG + +using Polynomials4ML: _make_reqfields, @reqfields, POOL, TMP, META + +using StaticArrays +using LinearAlgebra: norm + +using ChainRulesCore +using ChainRulesCore: NoTangent +using Zygote + +using Lux: Chain, apply +using ObjectPools: acquire! + +struct Nuc{T} + rr::SVector{3, T} + charge::T # should this be an integer? +end + +# +# Ordering of the embedding +# nuc | 1 2 3 1 2 3 1 2 3 +# k | 1 1 1 2 2 2 2 2 2 +# + +const NTRNL1 = NamedTuple{(:n, :l, :m), Tuple{Int, Int, Int}} +const NTRNLIS = NamedTuple{(:I, :s, :n, :l, :m), Tuple{Int, Spin, Int, Int, Int}} + +""" +This constructs the specification of all the atomic orbitals for one +nucleus. 
+ +* bRnl : radial basis +* Ylm : angular basis, assumed to be spherical harmonics +* admissible : a filter, default is a total degree +""" +function make_nlms_spec(bRnl, bYlm; + totaldegree::Integer = -1, + admissible = ( (br, by) -> degree(bRnl, br) + + degree(bYlm, by) <= totaldegree), + nnuc = 0) + + spec_Rnl = natural_indices(bRnl) + spec_Ylm = natural_indices(bYlm) + + spec1 = [] + for (iR, br) in enumerate(spec_Rnl), (iY, by) in enumerate(spec_Ylm) + if br.l != by.l + continue + end + if admissible(br, by) + push!(spec1, (br..., m = by.m)) + end + end + return spec1 +end + + +# mutable struct AtomicOrbitalsBasis{NB, T} +# prodbasis::ProductBasis{NB} +# nuclei::Vector{Nuc{T}} # nuclei (defines the shifted orbitals) +# end + +# (aobasis::AtomicOrbitalsBasis)(args...) = evaluate(aobasis, args...) +# Base.length(aobasis::AtomicOrbitalsBasis) = length(aobasis.prodbasis.spec1) * length(aobasis.nuclei) + +# function AtomicOrbitalsBasis(bRnl, bYlm; +# totaldegree=3, +# nuclei = Nuc{Float64}[], +# ) +# spec1 = make_nlms_spec(bRnl, bYlm; +# totaldegree = totaldegree) +# prodbasis = ProductBasis(spec1, bRnl, bYlm) +# return AtomicOrbitalsBasis(prodbasis, nuclei) +# end + + +# function evaluate(basis::AtomicOrbitalsBasis, X::AbstractVector{<: AbstractVector}, Σ) +# nuc = basis.nuclei +# Nnuc = length(nuc) +# Nel = size(X, 1) +# T = promote_type(eltype(X[1])) + +# # XX = zeros(VT, (Nnuc, Nel)) + +# # # trans +# # for I = 1:Nnuc, i = 1:Nel +# # XX[I, i] = X[i] - nuc[I].rr +# # end + +# Nnlm = length(basis.prodbasis.sparsebasis.spec) +# ϕnlm = Zygote.Buffer(zeros(T, (Nnuc, Nel, Nnlm))) + +# # Think how to prevent this extra FLOPS here while keeping it Zygote-friendly +# for I = 1:Nnuc +# ϕnlm[I,:,:] = evaluate(basis.prodbasis, map(x -> x - nuc[I].rr, X)) +# end + +# return copy(ϕnlm) +# end + +# # ------------ utils for AtomicOrbitalsBasis ------------ +# function set_nuclei!(basis::AtomicOrbitalsBasis, nuclei::AbstractVector{<: Nuc}) +# basis.nuclei = 
copy(collect(nuclei)) +# return nothing +# end + + +# function get_spec(basis::AtomicOrbitalsBasis) +# spec = [] +# Nnuc = length(basis.nuclei) + +# spec = Array{Any}(undef, (3, Nnuc, length(basis.prodbasis.spec1))) + +# for (k, nlm) in enumerate(basis.prodbasis.spec1) +# for I = 1:Nnuc +# for (is, s) in enumerate(extspins()) +# spec[is, I, k] = (I = I, s=s, nlm...) +# end +# end +# end + +# return spec +# end + + +# # ------------ Evaluation kernels + + +# # ------------ connect with ChainRulesCore + +# # Placeholder for now, fix this later after making sure Zygote is done correct with Lux +# # function ChainRulesCore.rrule(::typeof(evaluate), basis::AtomicOrbitalsBasis, X::AbstractVector{<: AbstractVector}, Σ) +# # val = evaluate(basis, X, Σ) +# # dB = similar(X) +# # function pb(dA) +# # return NoTangent(), NoTangent(), dB, NoTangent() +# # end +# # return val, pb +# # end + +# # ------------ connect with Lux +# struct AtomicOrbitalsBasisLayer{TB} <: AbstractExplicitLayer +# basis::TB +# meta::Dict{String, Any} +# end + +# Base.length(l::AtomicOrbitalsBasisLayer) = length(l.basis) + +# lux(basis::AtomicOrbitalsBasis) = AtomicOrbitalsBasisLayer(basis, Dict{String, Any}()) + +# initialparameters(rng::AbstractRNG, l::AtomicOrbitalsBasisLayer) = _init_luxparams(rng, l.basis) + +# initialstates(rng::AbstractRNG, l::AtomicOrbitalsBasisLayer) = _init_luxstate(rng, l.basis) + +# (l::AtomicOrbitalsBasisLayer)(X, ps, st) = +# evaluate(l.basis, X, st.Σ), st + + +# This can be done using ObjectPools, but for simplicity I didn't do that for now since I +# don't want lux layers storing ObjectPools stuffs + +struct AtomicOrbitalsBasisLayer{L, T} <: AbstractExplicitContainerLayer{(:prodbasis, )} + prodbasis::L + nuclei::Vector{Nuc{T}} + @reqfields() +end + +#Base.length(l::AtomicOrbitalsBasisLayer) = length(l.prodbasis.layer.ϕnlms.basis.spec) * length(l.nuclei) + +function AtomicOrbitalsBasisLayer(prodbasis, nuclei) + return AtomicOrbitalsBasisLayer(prodbasis, nuclei, 
_make_reqfields()...) +end + +function evaluate(l::AtomicOrbitalsBasisLayer, X, ps, st) + nuc = l.nuclei + Nnuc = length(nuc) + Nel = size(X, 1) + T = promote_type(eltype(X[1])) + # acquire FlexArray/FlexArrayCached from state + Nnlm = l.prodbasis.L + ϕnlm = acquire!(l.pool, :ϕnlm, (Nnuc, Nel, Nnlm), T) + # inplace evaluation X + @inbounds for I = 1:Nnuc + @simd ivdep for i = 1:Nel + X[i] = X[i] - nuc[I].rr + end + ϕnlm[I,:,:], _ = l.prodbasis(X, ps, st) + @simd ivdep for i = 1:Nel + X[i] = X[i] + nuc[I].rr + end + end + # ϕnlm should be released in the next layer + return ϕnlm, st +end + +function ChainRulesCore.rrule(::typeof(apply), l::AtomicOrbitalsBasisLayer{L, T}, X::Vector{SVector{3, TX}}, ps, st) where {L, T, TX} + val = evaluate(l, X, ps, st) + nuc = l.nuclei + Nnuc = length(nuc) + function pb(dϕnlm) # dA is of a tuple (dAmat, st), dAmat is of size (Nnuc, Nel, Nnlm) + # first we pullback up to each Xts, which should be of size (Nnuc, Nel, 3) + dXts = Vector{SVector{3, TX}}[] + dps = deepcopy(ps) + if :ζ in keys(dps) + for t in 1:length(dps.ζ) + dps.ζ[t] = zero(TX) + end + for I = 1:Nnuc + # inplace trans X + X .-= Ref(nuc[I].rr) + # pullback of productbasis[I], now I used productbasis but generalized to specified atom-dependent basis later + # pbI : (Nel, Nnlm) -> vector of length Nel of SVector{3, T} + _out, pbI = Zygote.pullback(l.prodbasis::L, X, ps, st) + # write to dXts + Xts, _dp = pbI((dϕnlm[1][I,:,:], _out[2])) + push!(dXts, Xts) # out[2] is the state + for t in 1:length(dps.ζ) + dps.ζ[t] += _dp.ζ[t] + end + # get back to original X + X .+= Ref(nuc[I].rr) + end + # finally sum all contributions from different I channel, reduces to vector of length Nel of SVector{3, T} again + return NoTangent(), NoTangent(), sum(dXts), dps, NoTangent() + else + for I = 1:Nnuc + # inplace trans X + X .-= Ref(nuc[I].rr) + # pullback of productbasis[I], now I used productbasis but generalized to specified atom-dependent basis later + # pbI : (Nel, Nnlm) -> vector of 
length Nel of SVector{3, T} + _out, pbI = Zygote.pullback(l.prodbasis::L, X, ps, st) + # write to dXts + Xts, = pbI((dϕnlm[1][I,:,:], _out[2])) + push!(dXts, Xts) # out[2] is the state + # get back to original X + X .+= Ref(nuc[I].rr) + end + # finally sum all contributions from different I channel, reduces to vector of length Nel of SVector{3, T} again + return NoTangent(), NoTangent(), sum(dXts), NoTangent(), NoTangent() + end + end + return val, pb +end + +(l::AtomicOrbitalsBasisLayer)(X, ps, st) = + evaluate(l, X, ps, st) + +# ------------ utils for AtomicOrbitalsBasis ------------ +function set_nuclei!(basis::AtomicOrbitalsBasisLayer, nuclei::AbstractVector{<: Nuc}) + basis.nuclei = copy(collect(nuclei)) + return nothing +end + + +function get_spec(l::AtomicOrbitalsBasisLayer, spec1p) + spec = [] + Nnuc = length(l.nuclei) + + spec = Array{Any}(undef, (3, Nnuc, length(spec1p))) + + for (k, nlm) in enumerate(spec1p) + for I = 1:Nnuc + for (is, s) in enumerate(extspins()) + spec[is, I, k] = (I = I, s=s, nlm...) 
+ end + end + end + + return spec +end \ No newline at end of file diff --git a/src/atomicorbitals/productbasis.jl b/src/atomicorbitals/productbasis.jl new file mode 100644 index 0000000..b2d29ac --- /dev/null +++ b/src/atomicorbitals/productbasis.jl @@ -0,0 +1,191 @@ +using Lux: WrappedFunction +using Lux +using Polynomials4ML: SparseProduct, AbstractPoly4MLBasis, release!, GaussianBasis, SlaterBasis, STO_NG, AtomicOrbitalsRadials, SVecPoly4MLBasis +import LuxCore +import LuxCore: initialparameters, initialstates, AbstractExplicitLayer +using Random: AbstractRNG + +AOR_type(TP, T, TI, Dn) = AtomicOrbitalsRadials{TP, Dn{T}, TI} + +function _invmap(a::AbstractVector) + inva = Dict{eltype(a), Int}() + for i = 1:length(a) + inva[a[i]] = i + end + return inva +end + +function dropnames(namedtuple::NamedTuple, names::Tuple{Vararg{Symbol}}) + keepnames = Base.diff_names(Base._nt_names(namedtuple), names) + return NamedTuple{keepnames}(namedtuple) +end + +struct ProductBasis{TDN, TP, T, TT, TI, NB} <: AbstractExplicitLayer + sparsebasis::SparseProduct{NB} + bRnl::Union{AOR_type(TP, T, TI, STO_NG), AOR_type(TP, T, TI, GaussianBasis), AOR_type(TP, T, TI, SlaterBasis)} + bYlm::Union{RYlmBasis{TT}, RRlmBasis{TT}} + L::Int + Dn::TDN +end + +ProductBasisLayer(spec1::Vector, bRnl::AbstractPoly4MLBasis, bYlm::AbstractPoly4MLBasis) = begin + spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) + spec_Rnl = natural_indices(bRnl); inv_Rnl = _invmap(spec_Rnl) + spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) + + spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b,(:m,))], inv_Ylm[(l=b.l, m=b.m)]) + end + sparsebasis = SparseProduct(spec1idx) + return ProductBasis(sparsebasis, bRnl, bYlm, length(spec1), bRnl.Dn) +end + +initialparameters(rng::AbstractRNG, l::ProductBasis{GaussianBasis{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = ( ζ = l.bRnl.Dn.ζ, ) +initialparameters(rng::AbstractRNG, 
l::ProductBasis{SlaterBasis{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = ( ζ = l.bRnl.Dn.ζ, ) +initialparameters(rng::AbstractRNG, l::ProductBasis{STO_NG{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = NamedTuple() + +initialstates(rng::AbstractRNG, l::ProductBasis{GaussianBasis{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = NamedTuple() +initialstates(rng::AbstractRNG, l::ProductBasis{SlaterBasis{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = NamedTuple() +initialstates(rng::AbstractRNG, l::ProductBasis{STO_NG{T}, TP, T, TT, TI, NB}) where {TP, T, TT, TI, NB} = (ζ = l.bRnl.Dn.ζ, ) + +(l::ProductBasis)(X, ps, st) = evaluate(l, X, ps, st) + +function evaluate(l::ProductBasis{GaussianBasis{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + l.bRnl.Dn.ζ = ps[1] + _bRnl = evaluate(l.bRnl, R) + _Ylm = evaluate(l.bYlm, X) + _ϕnlm = evaluate(l.sparsebasis,(_bRnl, _Ylm)) + release!(_bRnl) + release!(_Ylm) + return _ϕnlm, st +end + +function evaluate(l::ProductBasis{SlaterBasis{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + l.bRnl.Dn.ζ = ps[1] + _bRnl = evaluate(l.bRnl, R) + _Ylm = evaluate(l.bYlm, X) + _ϕnlm = evaluate(l.sparsebasis,(_bRnl, _Ylm)) + release!(_bRnl) + release!(_Ylm) + return _ϕnlm, st +end + +function evaluate(l::ProductBasis{STO_NG{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + l.bRnl.Dn.ζ = st[1] + _bRnl = evaluate(l.bRnl, R) + _Ylm = evaluate(l.bYlm, X) + _ϕnlm = 
evaluate(l.sparsebasis,(_bRnl, _Ylm)) + release!(R) + release!(_bRnl) + release!(_Ylm) + return _ϕnlm, st +end + +using ChainRulesCore + +function ChainRulesCore.rrule(::typeof(evaluate), l::ProductBasis{GaussianBasis{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + dnorm = X ./ R + _bRnl, dR, dζ = Polynomials4ML.evaluate_ed_dp(l.bRnl, R) + _bYlm, dX = evaluate_ed(l.bYlm, X) + val = evaluate(l.sparsebasis, (_bRnl, _bYlm)) + release!(_bRnl); release!(_bYlm) + ∂X = similar(X) + ∂ζ = similar(l.bRnl.Dn.ζ) + function pb(Δ) + ∂BB = Polynomials4ML._pullback_evaluate(Δ[1], l.sparsebasis, (_bRnl, _bYlm)) + for i = 1:length(X) + ∂X[i] = dot(@view(∂BB[1][i, :]), @view(dR[i, :])) * dnorm[i] + for j = 1:length(dX[i,:]) + ∂X[i] = muladd(∂BB[2][i,j], dX[i,j], ∂X[i]) + end + end + for i = 1:length(l.bRnl.Dn.ζ) + ∂ζ[i] = dot(@view(∂BB[1][:, i]), @view(dζ[:, i])) + end + return NoTangent(), NoTangent(), ∂X, (ζ = ∂ζ,), NoTangent() + end + release!(dX);release!(dR);release!(dζ); + return (val, st), pb +end + +function ChainRulesCore.rrule(::typeof(evaluate), l::ProductBasis{SlaterBasis{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + dnorm = X ./ R + _bRnl, dR, dζ = Polynomials4ML.evaluate_ed_dp(l.bRnl, R) + _bYlm, dX = evaluate_ed(l.bYlm, X) + val = evaluate(l.sparsebasis, (_bRnl, _bYlm)) + release!(_bRnl); release!(_bYlm) + ∂X = similar(X) + ∂ζ = similar(l.bRnl.Dn.ζ) + function pb(Δ) + ∂BB = Polynomials4ML._pullback_evaluate(Δ[1], l.sparsebasis, (_bRnl, _bYlm)) + for i = 1:length(X) + ∂X[i] = dot(@view(∂BB[1][i, :]), @view(dR[i, :])) * dnorm[i] + for j = 1:length(dX[i,:]) + ∂X[i] = muladd(∂BB[2][i,j], dX[i,j], 
∂X[i]) + end + end + for i = 1:length(l.bRnl.Dn.ζ) + ∂ζ[i] = dot(@view(∂BB[1][:, i]), @view(dζ[:, i])) + end + return NoTangent(), NoTangent(), ∂X, (ζ = ∂ζ,), NoTangent() + end + release!(dX);release!(dR);release!(dζ); + return (val, st), pb +end + + +function ChainRulesCore.rrule(::typeof(evaluate), l::ProductBasis{STO_NG{T}, TP, T, TT, TI, NB}, X::Vector{SVector{3, TX}}, ps, st) where {TP, T, TT, TI, NB, TX} + RT = promote_type(T, TT, TX) + Nel = length(X) + R = acquire!(l.bRnl.pool, :R, (Nel,), RT) + @simd ivdep for i = 1:Nel + R[i] = norm(X[i]) + end + dnorm = X ./ R + _bRnl, dR = evaluate_ed(l.bRnl, R) + _bYlm, dX = evaluate_ed(l.bYlm, X) + val = evaluate(l.sparsebasis, (_bRnl, _bYlm)) + ∂X = similar(X) + function pb(Δ) + ∂BB = Polynomials4ML._pullback_evaluate(Δ[1], l.sparsebasis, (_bRnl, _bYlm)) + for i = 1:length(X) + ∂X[i] = dot(@view(∂BB[1][i, :]), @view(dR[i, :])) * dnorm[i] + for j = 1:length(dX[i,:]) + ∂X[i] = muladd(∂BB[2][i,j], dX[i,j], ∂X[i]) + end + end + return NoTangent(), NoTangent(), ∂X, NoTangent(), NoTangent() + end + return (val, st), pb +end diff --git a/src/atomicorbitals/rnlexample.jl b/src/atomicorbitals/rnlexample.jl new file mode 100644 index 0000000..e9a3d68 --- /dev/null +++ b/src/atomicorbitals/rnlexample.jl @@ -0,0 +1,146 @@ +using Polynomials4ML, ForwardDiff +import Polynomials4ML: evaluate, evaluate_ed, evaluate_ed2, + natural_indices +using ChainRulesCore +using ChainRulesCore: NoTangent +const NLM{T} = NamedTuple{(:n, :l, :m), Tuple{T, T, T}} +const NL{T} = NamedTuple{(:n, :l), Tuple{T, T}} + +struct RnlExample{TP, TI} <: Polynomials4ML.AbstractPoly4MLBasis + Pn::TP + spec::Vector{NL{TI}} +end + +function RnlExample(totaldegree::Integer) + bPn = legendre_basis(totaldegree+1) + maxn = length(bPn) + spec = [ (n=n, l=l) for n = 1:maxn for l = 0:(totaldegree-n+1)] + return RnlExample(bPn, spec) +end + +Base.length(basis::RnlExample) = length(basis.spec) + +natural_indices(basis::RnlExample) = copy(basis.spec) + 
+degree(basis::RnlExample, i::Integer) = degree(basis, basis.spec[i])  # fix: `spec` was an undefined name; index the basis' own spec
+degree(basis::RnlExample, b::NamedTuple) = b.n + b.l  # degree of r^l * P_n, see `evaluate`
+
+# -------- Evaluation Code
+
+_alloc(basis::RnlExample, r::T) where T =
+    zeros(T, length(basis))  # fix: was `length(Rnl)` (undefined); one entry per basis function
+
+_alloc(basis::RnlExample, rr::Vector{T}) where T =
+    zeros(T, length(rr), length(basis))  # one row per point, one column per basis function
+
+# Scalar input: evaluate as a one-point batch and flatten the result to a vector.
+evaluate(basis::RnlExample, r::Number) = evaluate(basis, [r,])[:]
+
+function evaluate(basis::RnlExample, R::AbstractVector)  # Rnl[j, i] = P_n(R[j]) * R[j]^l for spec[i] = (n, l)
+    nR = length(R)
+    Pn = Polynomials4ML.evaluate(basis.Pn, R)
+    Rnl = _alloc(basis, R)
+
+    maxL = maximum(b.l for b in basis.spec)
+    rL = ones(eltype(R), length(R), maxL+1)   # column l+1 holds r^l; column 1 stays 1
+
+    # @inbounds begin
+
+    for l = 1:maxL
+        # @simd ivdep
+        for j = 1:nR
+            rL[j, l+1] = R[j] * rL[j, l]    # r^l
+        end
+    end
+
+    for (i, b) in enumerate(basis.spec)
+        # @simd ivdep
+        for j = 1:nR
+            Rnl[j, i] = Pn[j, b.n] * rL[j, b.l+1]   # r^l * P_n -> degree l+n
+        end
+    end
+
+    # end
+
+    return Rnl
+end
+
+# Values and first derivatives (w.r.t. r) of every basis function at the points R.
+function evaluate_ed(basis::RnlExample, R)
+    nR = length(R)
+    Pn, dPn = Polynomials4ML.evaluate_ed(basis.Pn, R)
+    Rnl = _alloc(basis, R)
+    dRnl = _alloc(basis, R)
+
+    maxL = maximum(b.l for b in basis.spec)
+    rL = ones(eltype(R), length(R), maxL+1)
+    drL = zeros(eltype(R), length(R), maxL+1)   # derivative of the constant column is 0
+    for l = 1:maxL
+        # @simd ivdep
+        for j = 1:nR
+            rL[j, l+1] = R[j] * rL[j, l]     # r^l
+            drL[j, l+1] = l * rL[j, l]       # l * r^(l-1)
+        end
+    end
+
+    for (i, b) in enumerate(basis.spec)
+        # @simd ivdep
+        for j = 1:nR
+            Rnl[j, i] = Pn[j, b.n] * rL[j, b.l+1]
+            dRnl[j, i] = dPn[j, b.n] * rL[j, b.l+1] + Pn[j, b.n] * drL[j, b.l+1]   # product rule
+        end
+    end
+
+    return Rnl, dRnl
+end
+
+# Values, first and second derivatives (w.r.t. r) of every basis function at the points R.
+function evaluate_ed2(basis::RnlExample, R)
+    nR = length(R)
+    Pn, dPn, ddPn = Polynomials4ML.evaluate_ed2(basis.Pn, R)
+    Rnl = _alloc(basis, R)
+    dRnl = _alloc(basis, R)
+    ddRnl = _alloc(basis, R)
+
+    maxL = maximum(b.l for b in basis.spec)
+    rL = ones(eltype(R), length(R), maxL+1)
+    drL = zeros(eltype(R), length(R), maxL+1)
+    ddrL = zeros(eltype(R), length(R), maxL+1)
+    for l = 1:maxL
+        # @simd ivdep
+        for j = 1:nR
+            rL[j, l+1] = R[j] * rL[j, l]     # r^l
+            drL[j, l+1] = l * rL[j, l]       # l * r^(l-1)
+            ddrL[j, l+1] = l * drL[j, l]     # l * (l-1) * r^(l-2), since drL[j, l] = (l-1) * r^(l-2)
+        end
+    end
+
+    for (i, b) in enumerate(basis.spec)
+        # @simd ivdep
+        for j = 1:nR
+            Rnl[j, i] = Pn[j, b.n] * rL[j, b.l+1]
+            dRnl[j, i] = dPn[j, b.n] * rL[j, b.l+1] + Pn[j, b.n] * drL[j, b.l+1]
+            ddRnl[j, i] = (   ddPn[j, b.n] * rL[j, b.l+1]
+                          + 2 * dPn[j, b.n] * drL[j, b.l+1]
+                          + Pn[j, b.n] * ddrL[j, b.l+1] )
+        end
+    end
+
+    return Rnl, dRnl, ddRnl
+end
+
+using LinearAlgebra:dot
+
+function ChainRulesCore.rrule(::typeof(evaluate), basis::RnlExample, R)  # pullback: ∂R[i] = ∑_k ∂A[i, k] * dRnl[i, k]
+    A = evaluate(basis, R)
+    ∂R = similar(R)   # NOTE(review): allocated once, so repeated pullback calls return the same buffer
+    dR = evaluate_ed(basis, R)[2]
+    function pb(∂A)
+        @assert size(∂A) == (length(R), length(basis))
+        for i = 1:length(R)
+            ∂R[i] = dot(@view(∂A[i, :]), @view(dR[i, :]))
+        end
+        return NoTangent(), NoTangent(), ∂R   # no tangent for `evaluate` itself or for the basis
+    end
+    return A, pb
+end
\ No newline at end of file
diff --git a/src/backflowpooling.jl b/src/backflowpooling.jl
new file mode 100644
index 0000000..b6d6796
--- /dev/null
+++ b/src/backflowpooling.jl
@@ -0,0 +1,151 @@
+using ACEpsi.AtomicOrbitals: AtomicOrbitalsBasisLayer
+using LuxCore: AbstractExplicitLayer
+using Random: AbstractRNG
+using ChainRulesCore: NoTangent
+
+using Polynomials4ML: _make_reqfields, @reqfields, POOL, TMP, META
+using ObjectPools: acquire!
+
+import ChainRulesCore: rrule
+
+mutable struct BackflowPooling
+    basis::AtomicOrbitalsBasisLayer
+    @reqfields   # presumably injects the pool / tmp / meta fields filled by _make_reqfields() - confirm in Polynomials4ML
+end
+
+function BackflowPooling(basis::AtomicOrbitalsBasisLayer)
+    return BackflowPooling(basis, _make_reqfields()...)
+end
+
+(pooling::BackflowPooling)(args...) = evaluate(pooling, args...)   # calling the object forwards to `evaluate`
+
+Base.length(pooling::BackflowPooling) = 3 * length(pooling.basis) # length(spin()) * length(1pbasis)
+
+function evaluate(pooling::BackflowPooling, ϕnlm::AbstractArray, Σ::AbstractVector)
+    Nnuc, Nel, Nnlm = size(ϕnlm)
+    @assert Nel == length(Σ)   # same check as _pullback_evaluate!; the @inbounds loops below index Σ[i] and ϕnlm[I, i, k] together
+    T = promote_type(eltype(ϕnlm))
+
+    # evaluate the pooling operation: Aall[σ, I, k] = ∑_{i : Σ[i] == σ} ϕnlm[I, i, k]
+    #                spin  I    k = (nlm)
+
+    Aall = acquire!(pooling.tmp, :Aall, (2, Nnuc, Nnlm), T)
+    fill!(Aall, 0)
+
+    @inbounds begin
+        for k = 1:Nnlm
+            for I = 1:Nnuc
+                for i = 1:Nel   # no `@simd ivdep`: electrons with equal spin update the same Aall entry (loop-carried dependence)
+                    iσ = spin2idx(Σ[i])
+                    Aall[iσ, I, k] += ϕnlm[I, i, k]
+                end
+            end
+        end
+    end # inbounds
+
+    # now correct the pooling Aall and write into A^(i)
+    # we do it with i leading so that the N-correlations can
+    # be parallelized over i
+    #
+    # A[i, :] = A-basis for electron i, with channels, s, I, k=nlm
+    # A[i, ∅, I, k] = ϕnlm[I, i, k]
+    # for σ = ↑ or ↓ we have
+    # A[i, σ, I, k] = ∑_{j ≠ i : Σ[j] == σ} ϕnlm[I, j, k]
+    #               = ∑_{j : Σ[j] == σ} ϕnlm[I, j, k] - (Σ[i] == σ) * ϕnlm[I, i, k]
+    #
+    #
+    # TODO: discuss - this could be stored much more efficiently as a
+    #       lazy array. Could that have advantages?
+    #
+
+    @assert spin2idx(↑) == 1
+    @assert spin2idx(↓) == 2
+    @assert spin2idx(∅) == 3
+
+    A = acquire!(pooling.pool, :Aall, (Nel, 3, Nnuc, Nnlm), T)
+    fill!(A, 0)
+
+    @inbounds begin
+        for k = 1:Nnlm
+            for I = 1:Nnuc
+                @simd ivdep for i = 1:Nel
+                    A[i, 3, I, k] = ϕnlm[I, i, k]
+                end
+                for iσ = 1:2   # `@simd` moved to the inner loop: it is meant for innermost loops only
+                    σ = idx2spin(iσ)
+                    @simd ivdep for i = 1:Nel
+                        A[i, iσ, I, k] = Aall[iσ, I, k] - (Σ[i] == σ) * ϕnlm[I, i, k]
+                    end
+                end
+            end
+        end
+    end # inbounds
+
+    release!(Aall)
+    release!(ϕnlm)   # NOTE(review): this releases the caller's input array - confirm callers expect that
+
+    return A
+end
+
+# --------------------- connect with ChainRule
+function rrule(::typeof(evaluate), pooling::BackflowPooling, ϕnlm, Σ::AbstractVector)
+    A = pooling(ϕnlm, Σ)
+    function pb(∂A)   # only ϕnlm receives a tangent; function, pooling object and Σ get NoTangent()
+        return NoTangent(), NoTangent(), _pullback_evaluate(∂A, pooling, ϕnlm, Σ), NoTangent()
+    end
+    return A, pb
+end
+
+function _rrule_evaluate(pooling::BackflowPooling, ϕnlm, Σ)
+    A = pooling(ϕnlm, Σ)
+    return A, ∂A -> _pullback_evaluate(∂A, pooling, ϕnlm, Σ)
+end
+
+function _pullback_evaluate(∂A, pooling::BackflowPooling, ϕnlm, Σ)
+    TA = eltype(ϕnlm)
+    ∂ϕnlm = acquire!(pooling.pool, :∂ϕnlm, size(ϕnlm), TA)
+    fill!(∂ϕnlm, zero(TA))   # the in-place kernel below accumulates with +=
+    _pullback_evaluate!(∂ϕnlm, ∂A, pooling, ϕnlm, Σ)
+    return ∂ϕnlm
+end
+
+
+function _pullback_evaluate!(∂ϕnlm, ∂A, pooling::BackflowPooling, ϕnlm, Σ)
+    Nnuc, Nel, Nnlm = size(ϕnlm)
+    #basis = pooling.basis
+
+    #@assert Nnlm == length(basis.prodbasis.layers.ϕnlms.basis.spec)
+    @assert Nel == length(Σ)
+    @assert size(∂ϕnlm) == (Nnuc, Nel, Nnlm)
+    @assert size(∂A) == (Nel, 3, Nnuc, Nnlm)
+
+    for I = 1:Nnuc
+        for i = 1:Nel
+            for k = 1:Nnlm
+                ∂ϕnlm[I, i, k] += ∂A[i, 3, I, k]   # ∅ channel: A[i, 3, I, k] = ϕnlm[I, i, k]
+                for ii = 1:Nel
+                    ∂ϕnlm[I, i, k] += ∂A[ii, spin2idx(Σ[i]), I, k] .* (i != ii)   # electron i feeds its own spin channel of every other electron
+                end
+            end
+        end
+    end
+
+    return nothing
+end
+
+# --------------------- connect with Lux
+
+struct BackflowPoolingLayer <: AbstractExplicitLayer
+    basis::BackflowPooling
+end
+
+lux(basis::BackflowPooling) = BackflowPoolingLayer(basis)
+
+initialparameters(rng::AbstractRNG, l::BackflowPoolingLayer) =
_init_luxparams(rng, l.basis) + +initialstates(rng::AbstractRNG, l::BackflowPoolingLayer) = _init_luxstate(rng, l.basis) + +# This should be removed later and replace by ObejctPools +(l::BackflowPoolingLayer)(ϕnlm, ps, st) = + evaluate(l.basis, ϕnlm, st.Σ), st + diff --git a/src/bflow.jl b/src/bflow.jl index 2b9286a..07b5971 100644 --- a/src/bflow.jl +++ b/src/bflow.jl @@ -1,19 +1,20 @@ -using ACEcore, Polynomials4ML -using Polynomials4ML: OrthPolyBasis1D3T -using ACEcore: PooledSparseProduct, SparseSymmProdDAG, SparseSymmProd, release! -using ACEcore.Utils: gensparse +using Polynomials4ML +using Polynomials4ML: OrthPolyBasis1D3T, PooledSparseProduct, SparseSymmProdDAG, SparseSymmProd, release! +using Polynomials4ML.Utils: gensparse using LinearAlgebra: qr, I, logabsdet, pinv, mul!, dot , tr +using ObjectPools: unwrap + import ForwardDiff -mutable struct BFwf{T, TT, TPOLY, TE} +mutable struct BFwf1{T, TT, TPOLY, TE} trans::TT polys::TPOLY pooling::PooledSparseProduct{2} - corr::SparseSymmProdDAG{T} + corr::SparseSymmProdDAG W::Matrix{T} envelope::TE - spec::Vector{Vector{Int64}} # corr.spec TODO: this needs to be remove + spec::AbstractArray # ---------------- Temporaries P::Matrix{T} ∂P::Matrix{T} @@ -30,9 +31,9 @@ mutable struct BFwf{T, TT, TPOLY, TE} ∇AA::Array{T, 3} end -(Φ::BFwf)(args...) = evaluate(Φ, args...) +(Φ::BFwf1)(args...) = evaluate(Φ, args...) 
-function BFwf(Nel::Integer, polys; totdeg = length(polys), +function BFwf1(Nel::Integer, polys; totdeg = length(polys), ν = 3, T = Float64, trans = identity, sd_admissible = bb -> (true), @@ -58,14 +59,16 @@ function BFwf(Nel::Integer, polys; totdeg = length(polys), # further restrict spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])] - corr1 = SparseSymmProd(spec; T = Float64) - corr = corr1.dag + corr1 = SparseSymmProdDAG(spec) + corr = corr1 + + spec = Tuple.(spec) # initial guess for weights Q, _ = qr(randn(T, length(corr), Nel)) W = Matrix(Q) - return BFwf(trans, polys, pooling, corr, W, envelope, spec, + return BFwf1(trans, polys, pooling, corr, W, envelope, spec, zeros(T, Nel, length(polys)), zeros(T, Nel, length(polys)), zeros(T, Nel, length(polys)), @@ -88,7 +91,7 @@ This function return correct Si for pooling operation. function onehot!(Si, i, Σ) Si .= 0 for k = 1:length(Σ) - Si[k, spin2num(Σ[k])] = 1 + Si[k, spin2num1d(Σ[k])] = 1 end # set current electron to ϕ, also remove their contribution in the sum of ↑ or ↓ basis Si[i, 1] = 1 @@ -99,7 +102,7 @@ end """ This function convert spin to corresponding integer value used in spec """ -function spin2num(σ) +function spin2num1d(σ) if σ == '↑' return 2 elseif σ == '↓' @@ -110,25 +113,11 @@ function spin2num(σ) error("illegal spin char for spin2num") end -""" -This function convert num to corresponding spin string. -""" -function num2spin(σ) - if σ == 2 - return '↑' - elseif σ == 3 - return '↓' - elseif σ == 1 - return '∅' - end - error("illegal integer value for num2spin") -end - """ This function returns a nice version of spec. 
""" -function displayspec(wf::BFwf) +function displayspec(wf::BFwf1) K = length(wf.polys) spec1p = [ (k, σ) for σ in [1, 2, 3] for k in 1:K] spec1p = sort(spec1p, by = b -> b[1]) @@ -141,7 +130,7 @@ function displayspec(wf::BFwf) end -function assemble_A(wf::BFwf, X::AbstractVector, Σ, Pnn=nothing) +function assemble_A(wf::BFwf1, X::AbstractVector, Σ, Pnn=nothing) nX = length(X) # position embedding @@ -155,16 +144,16 @@ function assemble_A(wf::BFwf, X::AbstractVector, Σ, Pnn=nothing) for i = 1:nX onehot!(Si, i, Σ) - ACEcore.evalpool!(Ai, wf.pooling, (parent(P), Si)) + Polynomials4ML.evaluate!(Ai, wf.pooling, (unwrap(P), Si)) A[i, :] .= Ai end return A end -function evaluate(wf::BFwf, X::AbstractVector, Σ, Pnn=nothing) +function evaluate(wf::BFwf1, X::AbstractVector, Σ, Pnn=nothing) nX = length(X) A = assemble_A(wf, X, Σ) - AA = ACEcore.evaluate(wf.corr, A) # nX x length(wf.corr) + AA = Polynomials4ML.evaluate(wf.corr, A) # nX x length(wf.corr) # the only basis to be purified are those with same spin # scan through all corr basis, if they comes from same spin, remove self interation by using basis @@ -175,7 +164,7 @@ function evaluate(wf::BFwf, X::AbstractVector, Σ, Pnn=nothing) # === # Φ = wf.Φ - mul!(Φ, parent(AA), wf.W) # nX x nX + mul!(Φ, unwrap(AA), wf.W) # nX x nX Φ = Φ .* [Σ[i] == Σ[j] for j = 1:nX, i = 1:nX] # the resulting matrix should contains two block each comes from each spin release!(AA) @@ -184,13 +173,13 @@ function evaluate(wf::BFwf, X::AbstractVector, Σ, Pnn=nothing) end -function gradp_evaluate(wf::BFwf, X::AbstractVector, Σ) +function gradp_evaluate(wf::BFwf1, X::AbstractVector, Σ) nX = length(X) A = assemble_A(wf, X, Σ) - AA = ACEcore.evaluate(wf.corr, A) # nX x length(wf.corr) + AA = Polynomials4ML.evaluate(wf.corr, A) # nX x length(wf.corr) Φ = wf.Φ - mul!(Φ, parent(AA), wf.W) + mul!(Φ, unwrap(AA), wf.W) Φ = Φ .* [Σ[i] == Σ[j] for j = 1:nX, i = 1:nX] # the resulting matrix should contains two block each comes from each spin @@ -202,7 +191,7 
@@ function gradp_evaluate(wf::BFwf, X::AbstractVector, Σ) # ∂Wij = ∑_ab ∂Φab * ∂_Wij( ∑_k AA_ak W_kb ) # = ∑_ab ∂Φab * ∑_k δ_ik δ_bj AA_ak # = ∑_a ∂Φaj AA_ai = ∂Φaj' * AA_ai - ∇p = transpose(parent(AA)) * ∂Φ + ∇p = transpose(unwrap(AA)) * ∂Φ release!(AA) ∇p = ∇p * 2 @@ -227,7 +216,7 @@ Base.setindex!(A::ZeroNoEffect, args...) = nothing Base.getindex(A::ZeroNoEffect, args...) = Bool(0) -function gradient(wf::BFwf, X, Σ) +function gradient(wf::BFwf1, X, Σ) nX = length(X) # ------ forward pass ----- @@ -257,16 +246,16 @@ function gradient(wf::BFwf, X, Σ) for i = 1:nX onehot!(Si, i, Σ) - ACEcore.evalpool!(Ai, wf.pooling, (parent(P), Si)) + Polynomials4ML.evaluate!(Ai, wf.pooling, (unwrap(P), Si)) A[i, :] .= Ai end # n-correlations - AA = ACEcore.evaluate(wf.corr, A) # nX x length(wf.corr) + AA = Polynomials4ML.evaluate(wf.corr, A) # nX x length(wf.corr) # generalized orbitals Φ = wf.Φ - mul!(Φ, parent(AA), wf.W) + mul!(Φ, unwrap(AA), wf.W) # the resulting matrix should contains two block each comes from each spin Φ = Φ .* [Σ[i] == Σ[j] for j = 1:nX, i = 1:nX] @@ -287,7 +276,7 @@ function gradient(wf::BFwf, X, Σ) # ∂A = ∂ψ/∂A = ∂ψ/∂AA * ∂AA/∂A -> use custom pullback ∂A = wf.∂A # zeros(size(A)) - ACEcore.pullback_arg!(∂A, ∂AA, wf.corr, parent(AA)) + Polynomials4ML.pullback_arg!(∂A, ∂AA, wf.corr, unwrap(AA)) release!(AA) # ∂P = ∂ψ/∂P = ∂ψ/∂A * ∂A/∂P -> use custom pullback @@ -301,7 +290,7 @@ function gradient(wf::BFwf, X, Σ) onehot!(Si_, i, Σ) # note this line ADDS the pullback into ∂P, not overwrite the content!! 
∂Ai = @view ∂A[i, :] - ACEcore._pullback_evalpool!((∂P, ∂Si), ∂Ai, wf.pooling, (P, Si_)) + Polynomials4ML._pullback_evaluate!((∂P, ∂Si), ∂Ai, wf.pooling, (P, Si_)) end # ∂X = ∂ψ/∂X = ∂ψ/∂P * ∂P/∂X @@ -326,7 +315,7 @@ end # ------------------ Laplacian implementation -function laplacian(wf::BFwf, X, Σ) +function laplacian(wf::BFwf1, X, Σ) A, ∇A, ΔA = _assemble_A_∇A_ΔA(wf, X, Σ) AA, ∇AA, ΔAA = _assemble_AA_∇AA_ΔAA(A, ∇A, ΔA, wf) @@ -370,7 +359,7 @@ function _assemble_A_∇A_ΔA(wf, X, Σ) @inbounds for i = 1:nX # loop over orbital bases (which i becomes ∅) fill!(Si_, 0) onehot!(Si_, i, Σ) - ACEcore.evalpool!(Ai, wf.pooling, (P, Si_)) + Polynomials4ML.evaluate!(Ai, wf.pooling, (P, Si_)) @. A[i, :] .= Ai for (iA, (k, σ)) in enumerate(spec_A) for a = 1:nX @@ -424,7 +413,7 @@ function _laplacian_inner(AA, ∇AA, ΔAA, wf, Σ) # the wf, and the first layer of derivatives Φ = wf.Φ - mul!(Φ, parent(AA), wf.W) + mul!(Φ, unwrap(AA), wf.W) Φ = Φ .* [Σ[i] == Σ[j] for j = 1:nX, i = 1:nX] # the resulting matrix should contains two block each comes from each spin Φ⁻ᵀ = transpose(pinv(Φ)) @@ -453,7 +442,7 @@ end # ------------------ gradp of Laplacian -function gradp_laplacian(wf::BFwf, X, Σ) +function gradp_laplacian(wf::BFwf1, X, Σ) # ---- gradp of Laplacian of Ψ ---- @@ -466,7 +455,7 @@ function gradp_laplacian(wf::BFwf, X, Σ) # the wf, and the first layer of derivatives Φ = wf.Φ - mul!(Φ, parent(AA), wf.W) + mul!(Φ, unwrap(AA), wf.W) Φ = Φ .* [Σ[i] == Σ[j] for j = 1:nX, i = 1:nX] # the resulting matrix should contains two block each comes from each spin Φ⁻¹ = pinv(Φ) @@ -527,13 +516,13 @@ end -# ----------------- BFwf parameter wraging +# ----------------- BFwf1 parameter wraging -function get_params(U::BFwf) +function get_params(U::BFwf1) return (U.W, U.envelope.ξ) end -function set_params!(U::BFwf, para) +function set_params!(U::BFwf1, para) U.W = para[1] set_params!(U.envelope, para[2]) return U diff --git a/src/bflow3d.jl b/src/bflow3d.jl new file mode 100644 index 0000000..140aaea 
--- /dev/null +++ b/src/bflow3d.jl @@ -0,0 +1,133 @@ +using LinearAlgebra: det +using LuxCore: AbstractExplicitLayer +using Lux: Chain, WrappedFunction, BranchLayer +using ChainRulesCore: NoTangent + +using ACEpsi: ↑, ↓, ∅, spins, extspins, Spin, spin2idx, idx2spin +using ACEpsi.AtomicOrbitals: make_nlms_spec +using Polynomials4ML: LinearLayer, PooledSparseProduct, SparseSymmProd +using Polynomials4ML.Utils: gensparse +using ObjectPools: release! + +using Polynomials4ML, Random, ACEpsi, Lux, ChainRulesCore + +import ForwardDiff + +# ----------------- custom layers ------------------ +import ChainRulesCore: rrule + +struct MaskLayer <: AbstractExplicitLayer + nX::Int64 +end + +(l::MaskLayer)(Φ, ps, st) = begin + T = eltype(Φ) + A::Matrix{Bool} = [st.Σ[i] == st.Σ[j] for j = 1:l.nX, i = 1:l.nX] + val::Matrix{T} = Φ .* A + release!(Φ) + return val, st +end + +function rrule(::typeof(Lux.apply), l::MaskLayer, Φ, ps, st) + T = eltype(Φ) + A::Matrix{Bool} = [st.Σ[i] == st.Σ[j] for j = 1:l.nX, i = 1:l.nX] + val::Matrix{T} = Φ .* A + function pb(dΦ) + return NoTangent(), NoTangent(), dΦ[1] .* A, NoTangent(), NoTangent() + end + release!(Φ) + return (val, st), pb +end + +struct myReshapeLayer{N} <: AbstractExplicitLayer + dims::NTuple{N, Int} +end + +@inline function (r::myReshapeLayer)(x::AbstractArray, ps, st::NamedTuple) + return reshape(unwrap(x), r.dims), st +end + +function rrule(::typeof(Lux.apply), l::myReshapeLayer{N}, X, ps, st) where {N} + val = l(X, ps, st) + function pb(dϕnlm) # dA is of a tuple (dAmat, st), dAmat is of size (Nnuc, Nel, Nnlm) + A = reshape(unwrap(dϕnlm[1]), size(X)) + return NoTangent(), NoTangent(), A, NoTangent(), NoTangent() + end + return val, pb +end + +# ----------------- wf utils ------------------ +function get_spec(nuclei, spec1p) + spec = [] + Nnuc = length(nuclei) + + spec = Array{Any}(undef, (3, Nnuc, length(spec1p))) + + for (k, nlm) in enumerate(spec1p) + for I = 1:Nnuc + for (is, s) in enumerate(extspins()) + spec[is, I, k] = 
(s=s, I = I, nlm...) + end + end + end + + return spec[:] +end + +function displayspec(spec, spec1p) + nicespec = [] + for k = 1:length(spec) + push!(nicespec, ([spec1p[spec[k][j]] for j = 1:length(spec[k])])) + end + return nicespec +end + +# ---------------- BFlux ---------------------- +function BFwf_lux(Nel::Integer, bRnl, bYlm, nuclei; totdeg = 15, + ν = 3, T = Float64, + sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0) + + spec1p = make_nlms_spec(bRnl, bYlm; + totaldegree = totdeg) + + # ----------- Lux connections --------- + # AtomicOrbitalsBasis: (X, Σ) -> (length(nuclei), nX, length(spec1)) + prodbasis_layer = ACEpsi.AtomicOrbitals.ProductBasisLayer(spec1p, bRnl, bYlm) + aobasis_layer = ACEpsi.AtomicOrbitals.AtomicOrbitalsBasisLayer(prodbasis_layer, nuclei) + + # BackFlowPooling: (length(nuclei), nX, length(spec1 from totaldegree)) -> (nX, 3, length(nuclei), length(spec1)) + pooling = BackflowPooling(aobasis_layer) + pooling_layer = ACEpsi.lux(pooling) + + spec1p = get_spec(nuclei, spec1p) + # define sparse for n-correlations + tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] + default_admissible = bb -> (length(bb) == 0) || (sum(b.n1 - 1 for b in bb ) <= totdeg) + + specAA = gensparse(; NU = ν, tup2b = tup2b, admissible = default_admissible, + minvv = fill(0, ν), + maxvv = fill(length(spec1p), ν), + ordered = true) + spec = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))] + + # further restrict + spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])] + + # define n-correlation + corr1 = Polynomials4ML.SparseSymmProd(spec) + + # (nX, 3, length(nuclei), length(spec1 from totaldegree)) -> (nX, length(spec)) + corr_layer = Polynomials4ML.lux(corr1; use_cache = false) + + #js = Jastrow(nuclei) + #jastrow_layer = ACEpsi.lux(js) + + #reshape_func = x -> reshape(x, (size(x, 1), prod(size(x)[2:end]))) + + #_det = x -> size(x) == (1, 1) ? 
x[1,1] : det(Matrix(x)) + BFwf_chain = Chain(; ϕnlm = aobasis_layer, bA = pooling_layer, reshape = myReshapeLayer((Nel, 3 * length(nuclei) * length(prodbasis_layer.sparsebasis))), + bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel), + Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> det(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))) ) + # return Chain(; branch = BranchLayer(; js = jastrow_layer, bf = BFwf_chain, ), prod = WrappedFunction(x -> x[1] * x[2]), logabs = WrappedFunction(x -> 2 * log(abs(x))) ), spec, spec1p + return BFwf_chain, spec, spec1p +end diff --git a/src/hyper.jl b/src/hyper.jl new file mode 100644 index 0000000..b0d4d50 --- /dev/null +++ b/src/hyper.jl @@ -0,0 +1,99 @@ +using HyperDualNumbers: Hyper + +Base.real(x::Hyper{<:Number}) = Hyper(real(x.value), real(x.epsilon1), real(x.epsilon2), real(x.epsilon12)) + +struct NTarr{NTT} + nt::NTT +end + +export array + +array(nt::NamedTuple) = NTarr(nt) + +# ------------------------------ +# 0 + +zero!(a::AbstractArray) = fill!(a, zero(eltype(a))) +zero!(a::Nothing) = nothing + +function zero!(nt::NamedTuple) + for k in keys(nt) + zero!(nt[k]) + end + return nt +end + +Base.zero(nt::NamedTuple) = zero!(deepcopy(nt)) + +Base.zero(nt::NTarr) = NTarr(zero(nt.nt)) + +# ------------------------------ +# + + + +function _add!(a1::AbstractArray, a2::AbstractArray) + a1[:] .= a1[:] .+ a2[:] + return nothing +end + +_add!(at::Nothing, args...) = nothing + +function _add!(nt1::NamedTuple, nt2) + for k in keys(nt1) + _add!(nt1[k], nt2[k]) + end + return nothing +end + +function _add(nt1::NamedTuple, nt2::NamedTuple) + nt = deepcopy(nt1) + _add!(nt, nt2) + return nt +end + +Base.:+(nt1::NTarr, nt2::NTarr) = NTarr(_add(nt1.nt, nt2.nt)) + +# ------------------------------ +# * + +_mul!(::Nothing, args... 
) = nothing + +function _mul!(a::AbstractArray, λ::Number) + a[:] .= a[:] .* λ + return nothing +end + +function _mul!(nt::NamedTuple, λ::Number) + for k in keys(nt) + _mul!(nt[k], λ) + end + return nothing +end + +function _mul(nt::NamedTuple, λ::Number) + nt = deepcopy(nt) + _mul!(nt, λ) + return nt +end + +Base.:*(λ::Number, nt::NTarr) = NTarr(_mul(nt.nt, λ)) +Base.:*(nt::NTarr, λ::Number) = NTarr(_mul(nt.nt, λ)) + +# ------------------------------ +# map + +_map!(f, a::AbstractArray) = map!(f, a, a) + +_map!(f, ::Nothing) = nothing + +function _map!(f, nt::NamedTuple) + for k in keys(nt) + _map!(f, nt[k]) + end + return nothing +end + +function Base.map!(f, dest::NTarr, src::NTarr) + _map!(f, nt.nt) + return nt +end diff --git a/src/jastrow.jl b/src/jastrow.jl new file mode 100644 index 0000000..e110103 --- /dev/null +++ b/src/jastrow.jl @@ -0,0 +1,84 @@ +# This should probably be re-implemented properly and be made performant. +# the current version is just for testing +using ACEpsi.AtomicOrbitals: Nuc +using LuxCore +using LuxCore: AbstractExplicitLayer +using Random: AbstractRNG +using Zygote: Buffer + +mutable struct Jastrow{T} + nuclei::Vector{Nuc{T}} # nuclei +end + +(f::Jastrow)(args...) = evaluate(f, args...) + +## F_2(x) = -1/2\sum_{l=1}^L \sum_{i=1}^N Z_l|yi,l|+1/4\sum_{1\leq i (Σ = Σ, ))...) 
+ else + rp_st = (; rp_st..., keys(nt)[i] => replace_namedtuples(nt[i], (;), Σ)) + end + end + return rp_st + end +end + +function setupBFState(rng, bf, Σ) + ps, st = LuxCore.setup(rng, bf) + rp_st = replace_namedtuples(st, (;), Σ) + return ps, rp_st +end diff --git a/src/spins.jl b/src/spins.jl new file mode 100644 index 0000000..6e1fad9 --- /dev/null +++ b/src/spins.jl @@ -0,0 +1,50 @@ + +using StaticArrays: SA + +export ↑, ↓, spins + + +# Define the spin types and variables +const Spin = Char +const ↑ = '↑' +const ↓ = '↓' +const ∅ = '∅' # this is only for internal use +_spins = SA[↑, ↓] +_extspins = SA[↑, ↓, ∅] + +spins() = _spins +extspins() = _extspins + + +""" +This function convert spin to corresponding integer value used in spec +""" +function spin2idx(σ) + if σ == ↑ + return 1 + elseif σ == ↓ + return 2 + elseif σ == ∅ + return 3 + end + error("illegal spin char for spin2idx") +end + +""" +This function convert idx to corresponding spin string. +""" +function idx2spin(i) + if i == 1 + return ↑ + elseif i == 2 + return ↓ + elseif i == 3 + return ∅ + end + error("illegal integer value for idx2spin") +end + + +# TODO : deprecate these +const spin2num = spin2idx +const num2spin = idx2spin diff --git a/src/vmc/Eloc.jl b/src/vmc/Eloc.jl new file mode 100644 index 0000000..194b951 --- /dev/null +++ b/src/vmc/Eloc.jl @@ -0,0 +1,58 @@ +export SumH +using ACEpsi.AtomicOrbitals: Nuc +using StaticArrays + +# INTERFACE FOR HAMILTIANS H ψ -> H(psi, X) +struct SumH{T} + nuclei::Vector{Nuc{T}} +end + +function Vee(wf, X::Vector{SVector{3, T}}, ps, st) where {T} + nX = length(X) + v = zero(T) + r = zero(T) + @inbounds begin + for i = 1:nX-1 + @simd ivdep for j = i+1:nX + r = norm(X[i]-X[j]) + v = muladd(1, 1/r, v) + end + end + end + return v +end + +function Vext(wf, X::Vector{SVector{3, T}}, nuclei::Vector{Nuc{TT}}, ps, st) where {T, TT} + nX = length(X) + v = zero(T) + r = zero(T) + @inbounds begin + for i = 1:length(nuclei) + @simd ivdep for j = 1:nX + r = 
norm(nuclei[i].rr - X[j]) + v = muladd(nuclei[i].charge, 1/r, v) + end + end + end + return -v +end + +K(wf, X::Vector{SVector{3, T}}, ps, st) where {T} = -0.5 * laplacian(wf, X, ps, st) + +(H::SumH)(wf, X::Vector{SVector{3, T}}, ps, st) where {T} = + K(wf, X, ps, st) + (Vext(wf, X, H.nuclei, ps, st) + Vee(wf, X, ps, st)) * evaluate(wf, X, ps, st) + + +# evaluate local energy with SumH + +""" +E_loc = E_pot - 1/4 ∇²ᵣ ϕ(r) - 1/8 (∇ᵣ ϕ)²(r) +https://arxiv.org/abs/2105.08351 +""" + +function Elocal(H::SumH, wf, X::AbstractVector, ps, st) + gra = gradient(wf, X, ps, st) + val = Vext(wf, X, H.nuclei, ps, st) + Vee(wf, X, ps, st) - 1/4 * laplacian(wf, X, ps, st) - 1/8 * gra' * gra + return val +end + diff --git a/src/vmc/gradient.jl b/src/vmc/gradient.jl new file mode 100644 index 0000000..d5cb390 --- /dev/null +++ b/src/vmc/gradient.jl @@ -0,0 +1,22 @@ +using Zygote +using HyperDualNumbers: Hyper + +x2dualwrtj(x, j) = SVector{3}([Hyper(x[i], i == j, i == j, 0) for i = 1:3]) + +gradient(wf, x, ps, st) = Zygote.gradient(x -> wf(x, ps, st)[1], x)[1] + +grad_params(wf, x, ps, st) = Zygote.gradient(p -> wf(x, p, st)[1], ps)[1] + +function laplacian(wf, x, ps, st) + ΔΨ = 0.0 + hX = [x2dualwrtj(xx, 0) for xx in x] + Nel = length(x) + for i = 1:3 + for j = 1:Nel + hX[j] = x2dualwrtj(x[j], i) # ∂Φ/∂xj_{i} + ΔΨ += wf(hX, ps, st)[1].epsilon12 + hX[j] = x2dualwrtj(x[j], 0) + end + end + return ΔΨ +end diff --git a/src/vmc/metropolis.jl b/src/vmc/metropolis.jl new file mode 100644 index 0000000..ee6c895 --- /dev/null +++ b/src/vmc/metropolis.jl @@ -0,0 +1,166 @@ +using StatsBase +using StaticArrays +using Optimisers +export MHSampler +using ACEpsi.AtomicOrbitals: Nuc +using Lux: Chain + +""" +`MHSampler` +Metropolis-Hastings sampling algorithm. 
+""" +mutable struct MHSampler{T} + Nel::Int64 + nuclei::Vector{Nuc{T}} + Δt::Float64 # step size (of Gaussian proposal) + burnin::Int64 # burn-in iterations + lag::Int64 # iterations between successive samples + N_batch::Int64 # batch size + nchains::Int64 # Number of chains + Ψ::Chain # many-body wavefunction for sampling + x0::Vector # initial sampling + walkerType::String # walker type: "unbiased", "Langevin" + bc::String # boundary condition + type::Int64 # move how many electron one time +end + +MHSampler(Ψ, Nel, nuclei; Δt = 0.1, + burnin = 100, + lag = 10, + N_batch = 1, + nchains = 1000, + x0 = Vector{Vector{SVector{3, Float64}}}(undef, nchains), + wT = "unbiased", + bc = "periodic", + type = 1) = + MHSampler(Nel, nuclei, Δt, burnin, lag, N_batch, nchains, Ψ, x0, wT, bc, type) + + +""" +unbiased random walk: R_n+1 = R_n + Δ⋅Wn +biased random walk: R_n+1 = R_n + Δ⋅Wn + Δ⋅∇(log Ψ)(R_n) +""" + +eval(wf, X::AbstractVector, ps, st) = wf(X, ps, st)[1] + +function MHstep(r0::Vector{Vector{SVector{3, TT}}}, + Ψx0::Vector{T}, + Nels::Int64, + sam::MHSampler, ps::NamedTuple, st::NamedTuple) where {T, TT} + rand_sample(X::Vector{SVector{3, TX}}, Nels::Int, Δt::Float64) where {TX}= begin + return X + Δt * randn(SVector{3, TX}, Nels) + end + rp = rand_sample.(r0, Ref(Nels), Ref(sam.Δt)) + Ψxp::Vector{T} = eval.(Ref(sam.Ψ), rp, Ref(ps), Ref(st)) + accprob = accfcn(Ψx0, Ψxp) + u = rand(sam.nchains) + acc = u .<= accprob[:] + r::Vector{Vector{SVector{3, TT}}} = acc .* rp + (1.0 .- acc) .* r0 + Ψ = acc .* Ψxp + (1.0 .- acc) .* Ψx0 + return r, Ψ, acc +end + +""" +acceptance rate for log|Ψ| +ψₜ₊₁²/ψₜ² = exp((log|Ψₜ₊₁|^2-log |ψₜ|^2)) +""" + +function accfcn(Ψx0::Vector{T}, Ψxp::Vector{T}) where {T} + acc = exp.(Ψxp .- Ψx0) + return acc +end + +"""============== Metropolis sampling algorithm ============ +type = "restart" +""" + +function pos(sam::MHSampler) + T = eltype(sam.nuclei[1].rr) + M = length(sam.nuclei) + rr = zeros(SVector{3, T}, sam.Nel) + tt = zeros(Int, 1) + 
@inbounds begin + for i = 1:M + @simd ivdep for j = Int(ceil(sam.nuclei[i].charge)) + tt[1] += 1 + rr[tt[1]] = sam.nuclei[i].rr + end + end + end + return rr +end + +function sampler_restart(sam::MHSampler, ps, st) + r = pos(sam) + T = eltype(r[1]) + r0 = sam.x0 + r0 = [sam.Δt * randn(SVector{3, T}, sam.Nel) + r for _ = 1:sam.nchains] + Ψx0 = eval.(Ref(sam.Ψ), r0, Ref(ps), Ref(st)) + acc = zeros(T, sam.burnin) + for i = 1 : sam.burnin + r0, Ψx0, a = MHstep(r0, Ψx0, sam.Nel, sam, ps, st); + acc[i] = mean(a) + end + return r0, Ψx0, mean(acc) +end + +""" +type = "continue" +start from the previous sampling x0 +""" +function sampler(sam::MHSampler, ps, st) + r0 = sam.x0 + Ψx0 = eval.(Ref(sam.Ψ), r0, Ref(ps), Ref(st)) + T = eltype(r0[1][1]) + acc = zeros(T, sam.lag) + for i = 1:sam.lag + r0, Ψx0, a = MHstep(r0, Ψx0, sam.Nel, sam, ps, st); + acc[i] = mean(a) + end + return r0, Ψx0, mean(acc) +end + + + +""" +Rayleigh quotient by VMC using Metropolis sampling +""" +function rq_MC(Ψ, sam::MHSampler, ham::SumH, ps, st) + r, ~, acc = sampler(sam, ps, st); + Eloc = Elocal.(Ref(ham), Ref(Ψ), r, Ref(sam.Σ)) + val = sum(Eloc) / length(Eloc) + var = sqrt(sum((Eloc .-val).^2)/(length(Eloc)*(length(Eloc)-1))) + return val, var, acc +end + +function Eloc_Exp_TV_clip(wf, ps, st, + sam::MHSampler, + ham::SumH; + clip = 5.) 
+ x, ~, acc = sampler(sam, ps, st) + Eloc = Elocal.(Ref(ham), Ref(wf), x, Ref(ps), Ref(st)) + val = sum(Eloc) / length(Eloc) + var = sqrt(sum((Eloc .-val).^2)/(length(Eloc)*(length(Eloc) -1))) + ΔE = Eloc .- median( Eloc ) + a = clip * mean( abs.(ΔE) ) + ind = findall(x -> abs(x) > a, ΔE) + ΔE[ind] = (a * sign.(ΔE) .* (1 .+ log.((1 .+(abs.(ΔE)/a).^2)/2)))[ind] + E_clip = median(Eloc) .+ ΔE + return val, var, E_clip, x, acc +end + +function params(a::NamedTuple) + p,= destructure(a) + return p +end + +function grad(wf, x, ps, st, E); + dy = grad_params.(Ref(wf), x, Ref(ps), Ref(st)); + N = length(x) + p = params.(dy) + _,t = destructure(dy[1]) + g = 1/N * sum( p .* E) - 1/(N^2) * sum(E) * sum(p) + g = t(g) + return g; +end + diff --git a/src/vmc/multilevel.jl b/src/vmc/multilevel.jl new file mode 100644 index 0000000..18761dd --- /dev/null +++ b/src/vmc/multilevel.jl @@ -0,0 +1,151 @@ +export EmbeddingW!, _invmap, VMC_multilevel, wf_multilevel, gd_GradientByVMC_multilevel +using Printf +using LinearAlgebra +using Optimisers +using Polynomials4ML +using Random +using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow, displayspec +using ACEpsi.AtomicOrbitals: _invmap + +mutable struct VMC_multilevel + tol::Float64 + MaxIter::Vector{Int} + lr::Float64 + lr_dc::Float64 + type::opt +end + +VMC_multilevel(MaxIter::Vector{Int}, lr::Float64, type; tol = 1.0e-3, lr_dc = 50.0) = VMC_multilevel(tol, MaxIter, lr, lr_dc, type); + +# TODO: this should be implemented to recursively embed the wavefunction + +function _invmapAO(a::AbstractVector) + inva = Dict{eltype(a), Int}() + for i = 1:length(a) + inva[a[i]] = i + end + return inva +end + +function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2) + readable_spec = displayspec(spec, spec1p) + readable_spec2 = displayspec(spec2, spec1p2) + @assert size(ps.hidden1.W, 1) == size(ps2.hidden1.W, 1) + @assert size(ps.hidden1.W, 2) ≤ size(ps2.hidden1.W, 2) + @assert all(t in readable_spec2 for t in 
readable_spec) + @assert all(t in specAO2 for t in specAO) + + # set all parameters to zero + ps2.hidden1.W .= 0.0 + + # _map[spect] = index in readable_spec2 + _map = _invmap(readable_spec2) + _mapAO = _invmapAO(specAO2) + # embed + for (idx, t) in enumerate(readable_spec) + ps2.hidden1.W[:, _map[t]] = ps.hidden1.W[:, idx] + end + if :ϕnlm in keys(ps) + if :ζ in keys(ps.ϕnlm) + ps2.ϕnlm.ζ .= 0.0 + for (idx, t) in enumerate(specAO) + ps2.ϕnlm.ζ[_mapAO[t]] = ps.ϕnlm.ζ[idx] + end + end + end + return ps2 +end + +function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ham::SumH, wf_list, ps_list, st_list, spec_list, spec1p_list, specAO_list; verbose = true, accMCMC = [10, [0.45, 0.55]]) + + # first level + wf = wf_list[1] + ps = ps_list[1] + st = st_list[1] + spec = spec_list[1] + spec1p = spec1p_list[1] + specAO = specAO_list[1] + + # burn in + res, λ₀, α = 1.0, 0., opt_vmc.lr + err_opt = [zeros(opt_vmc.MaxIter[i]) for i = 1:length(opt_vmc.MaxIter)] + + x0, ~, acc = sampler_restart(sam, ps, st) + acc_step, acc_range = accMCMC + acc_opt = zeros(acc_step) + + + verbose && @printf("Initialize MCMC: Δt = %.2f, accRate = %.4f \n", sam.Δt, acc) + + verbose && @printf(" k | 𝔼[E_L] | V[E_L] | res | LR |accRate| Δt \n") + for l in 1:length(wf_list) + # do embeddings + if l > 1 + wf = wf_list[l] + # embed + ps = EmbeddingW!(ps, ps_list[l], spec, spec_list[l], spec1p, spec1p_list[l], specAO, specAO_list[l]) + st = st_list[l] + spec = spec_list[l] + spec1p = spec1p_list[l] + sam.Ψ = wf + end + _basis_size = size(ps.hidden1.W, 2) + ν = maximum(length.(spec)) + # optimization + @info("level = $l, order = $ν, size of basis = $_basis_size") + for k = 1 : opt_vmc.MaxIter[l] + sam.x0 = x0 + + # adjust Δt + acc_opt[mod(k,acc_step)+1] = acc + sam.Δt = acc_adjust(k, sam.Δt, acc_opt, acc_range, acc_step) + + # adjust learning rate + α, ν = InverseLR(ν, opt_vmc.lr, opt_vmc.lr_dc) + + # optimization + ps, acc, λ₀, res, σ = Optimization(opt_vmc.type, wf, ps, st, sam, 
ham, α) + + # err + verbose && @printf(" %3.d | %.5f | %.5f | %.5f | %.5f | %.3f | %.3f \n", k, λ₀, σ, res, α, acc, sam.Δt) + err_opt[l][k] = λ₀ + + if res < opt_vmc.tol + ps_list[l] = deepcopy(ps) + break; + end + end + ps_list[l] = deepcopy(ps) + end + + return wf_list, err_opt, ps_list +end + +function wf_multilevel(Nel::Int, Σ::Vector{Char}, nuclei::Vector{Nuc{T}}, Rnldegree::Vector{Int}, Ylmdegree::Vector{Int}, totdegree::Vector{Int}, n2::Vector{Int}, ν::Vector{Int}) where {T} + level = length(Rnldegree) + # init a list of wf + wf = [] + specAO = [] + spec = [] + spec1p = [] + ps = [] + st = [] + for i = 1:level + Pn = Polynomials4ML.legendre_basis(Rnldegree[i]+1) + _spec = [(n1 = n1, n2 = _n2, l = l) for n1 = 1:Rnldegree[i] for _n2 = 1:n2[i] for l = 0:Rnldegree[i]-1] + push!(specAO, _spec) + ζ = 10 * rand(length(_spec)) + Dn = SlaterBasis(ζ) + bRnl = AtomicOrbitalsRadials(Pn, Dn, _spec) + bYlm = RYlmBasis(Ylmdegree[i]) + _wf, _spec, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree[i], ν = ν[i]) + _ps, _st = setupBFState(MersenneTwister(1234), _wf, Σ) + push!(wf, _wf) + push!(spec, _spec) + push!(spec1p, _spec1p) + push!(ps, _ps) + push!(st, _st) + end + return wf, spec, spec1p, specAO, ps, st +end + diff --git a/src/vmc/opt.jl b/src/vmc/opt.jl new file mode 100644 index 0000000..5227c5a --- /dev/null +++ b/src/vmc/opt.jl @@ -0,0 +1,26 @@ +module vmc + +abstract type opt end + +abstract type sr_type end + +struct QGT <: sr_type +end + +struct QGTJacobian <: sr_type +end +struct QGTOnTheFly <: sr_type +end + + +include("Eloc.jl") +include("gradient.jl") +include("metropolis.jl") + +include("vmc_utils.jl") +include("vmc.jl") +include("multilevel.jl") +include("optimisers/adamw.jl") +include("optimisers/sr.jl") + +end \ No newline at end of file diff --git a/src/vmc/optimisers/adamw.jl b/src/vmc/optimisers/adamw.jl new file mode 100644 index 0000000..1edce16 --- /dev/null +++ b/src/vmc/optimisers/adamw.jl @@ -0,0 +1,18 @@ +using Optimisers + 
+mutable struct adamW <: opt + β::Tuple + γ::Number + ϵ::Number +end + +adamW() = adamW((9f-1, 9.99f-1), 0.0, eps()) + +function Optimization(type::adamW, wf, ps, st, sam::MHSampler, ham::SumH, α) + λ₀, σ, E, x0, acc = Eloc_Exp_TV_clip(wf, ps, st, sam, ham) + g = grad(wf, x0, ps, st, E) + st_opt = Optimisers.setup(Optimisers.AdamW(α, type.β, type.γ, type.ϵ), ps) + st_opt, ps = Optimisers.update(st_opt, ps, g) + res = norm(destructure(g)[1]) + return ps, acc, λ₀, res, σ, x0 +end \ No newline at end of file diff --git a/src/vmc/optimisers/sr.jl b/src/vmc/optimisers/sr.jl new file mode 100644 index 0000000..98ff05b --- /dev/null +++ b/src/vmc/optimisers/sr.jl @@ -0,0 +1,127 @@ +using Optimisers +using LinearMaps +using LinearAlgebra +using IterativeSolvers +using Polynomials4ML: _make_reqfields, @reqfields, POOL, TMP, META, release! +using ObjectPools: acquire! +# stochastic reconfiguration + +mutable struct SR <: opt + ϵ1::Number + ϵ2::Number + _sr_type::sr_type +end + +SR() = SR(0., 0.01, QGT()) + +SR(ϵ1::Number, ϵ2::Number) = SR(ϵ1, ϵ2, QGT()) + +_destructure(ps) = destructure(ps)[1] + +function Optimization(type::SR, wf, ps, st, sam::MHSampler, ham::SumH, α) + ϵ1 = type.ϵ1 + ϵ2 = type.ϵ2 + + g, acc, λ₀, σ = grad_sr(type._sr_type, wf, ps, st, sam, ham, ϵ1, ϵ2) + res = norm(g) + + p, s = destructure(ps) + p = p - α * g + ps = s(p) + return ps, acc, λ₀, res, σ, x0 +end + + +# O_kl = ∂ln ψθ(x_k)/∂θ_l : N_ps × N_sample +# Ō_k = 1/N_sample ∑_i=1^N_sample O_ki : N_ps × 1 +# ΔO_ki = O_ki - Ō_k -> ΔO_ki/sqrt(N_sample) +function Jacobian_O(wf, ps, st, sam::MHSampler, ham::SumH) + λ₀, σ, E, x0, acc = Eloc_Exp_TV_clip(wf, ps, st, sam, ham) + dps = grad_params.(Ref(wf), x0, Ref(ps), Ref(st)) + O = 1/2 * reshape(_destructure(dps), (length(_destructure(ps)),sam.nchains)) + Ō = mean(O, dims =2) + ΔO = (O .- Ō)/sqrt(sam.nchains) + return λ₀, σ, E, acc, ΔO +end + +function grad_sr(_sr_type::QGT, wf, ps, st, sam::MHSampler, ham::SumH, ϵ1::Number, ϵ2::Number) + λ₀, σ, E, acc, ΔO = 
Jacobian_O(wf, ps, st, sam, ham) + g0 = 2.0 * ΔO * E/sqrt(sam.nchains) + + # S_ij = 1/N_sample ∑_k=1^N_sample ΔO_ik * ΔO_jk = ΔO * ΔO'/N_sample -> ΔO * ΔO': N_ps × N_ps + # Sx = g0 + S = ΔO * ΔO' + S[diagind(S)] .*= (1+ϵ1) + S[diagind(S)] .+= ϵ2 + g = S \ g0 + return g, acc, λ₀, σ +end + +function grad_sr(_sr_type::QGTJacobian, wf, ps, st, sam::MHSampler, ham::SumH, ϵ1::Number, ϵ2::Number) + λ₀, σ, E, acc, ΔO = Jacobian_O(wf, ps, st, sam, ham) + g0 = 2.0 * ΔO * E/sqrt(sam.nchains) + + # S_ij = 1/N_sample ∑_k=1^N_sample ΔO_ik * ΔO_jk = ΔO * ΔO'/N_sample -> ΔO * ΔO': N_ps × N_ps + # Sx = g0 + function Svp!(w, v) + Δw = v' * ΔO + for i = 1:length(v) + w[i] = ϵ2 * v[i] + end + @inbounds begin + for i = 1:length(v) + @simd ivdep for j = 1:length(Δw) + w[i] += ΔO[i,j] * Δw[j] + ϵ1 * ΔO[i,j]^2 * v[i] + end + end + end + return w + end + LM_S = LinearMap(Svp!, size(ΔO)[1]; issymmetric=true, ismutating=true) + g = gmres(LM_S, g0) + return g, acc, λ₀, σ +end + +function grad_sr(_sr_type::QGTOnTheFly, wf, ps, st, sam::MHSampler, ham::SumH, ϵ1::Number, ϵ2::Number) + λ₀, σ, E, x0, acc = Eloc_Exp_TV_clip(wf, ps, st, sam, ham) + + # w = O * v + function jvp(v::AbstractVector, wf, ps::NamedTuple, x0) + _destructp, = destructure(ps) + w = zero(_destructp) + for i = 1:length(x0) + _, back = Zygote.pullback(p -> wf(x0[i], p, st)[1], ps) + w += 1/2 * destructure(back(v[i]))[1] + end + return w + end + + # w = v' * O + function vjp(v::AbstractVector, wf, ps::NamedTuple, x0) + _destructp, s = destructure(ps) + w = zeros(length(x0)) + for i = 1:length(x0) + f(t) = begin + p = s(_destructp + t * v) + return wf(x0[i], p, st)[1] + end + w[i] = 1/2 * Zygote.gradient(f, 0.0)[1] + end + return w + end + + g0 = 2 * jvp(E .- mean(E), wf, ps, x0)/sam.nchains # 2 * O * E/sam.nchains + + function Svp!(w, v) + w̃ = 1/sam.nchains * vjp(v, wf, ps, x0) + Δw = w̃ .- mean(w̃) + ṽ = jvp(Δw, wf, ps, x0) + for i = 1:length(v) + w[i] = ṽ[i] + ϵ2 * v[i] + end + return w + end + LM_S = LinearMap(Svp!, 
length(g0); issymmetric=true, ismutating=true) + g = gmres(LM_S, g0) + return g, acc, λ₀, σ +end \ No newline at end of file diff --git a/src/vmc/vmc.jl b/src/vmc/vmc.jl new file mode 100644 index 0000000..8353c24 --- /dev/null +++ b/src/vmc/vmc.jl @@ -0,0 +1,53 @@ +export VMC +using Printf +using LinearAlgebra +using Optimisers + +mutable struct VMC + tol::Float64 + MaxIter::Int + lr::Float64 + lr_dc::Float64 + type::opt +end + +VMC(MaxIter::Int, lr::Float64, type; tol = 1.0e-3, lr_dc = 50.0) = VMC(tol, MaxIter, lr, lr_dc, type); + +function gd_GradientByVMC(opt_vmc::VMC, sam::MHSampler, ham::SumH, + wf, ps, st, + ν = 1, verbose = true, accMCMC = [10, [0.45, 0.55]]) + + res, λ₀, α = 1.0, 0., opt_vmc.lr + err_opt = zeros(opt_vmc.MaxIter) + + x0, ~, acc = sampler_restart(sam, ps, st) + acc_step, acc_range = accMCMC + acc_opt = zeros(acc_step) + + verbose && @printf("Initialize MCMC: Δt = %.2f, accRate = %.4f \n", sam.Δt, acc) + verbose && @printf(" k | 𝔼[E_L] | V[E_L] | res | LR |accRate| Δt \n") + for k = 1 : opt_vmc.MaxIter + sam.x0 = x0 + + # adjust Δt + acc_opt[mod(k,acc_step)+1] = acc + sam.Δt = acc_adjust(k, sam.Δt, acc_opt, acc_range, acc_step) + + # adjust learning rate + α, ν = InverseLR(ν, opt_vmc.lr, opt_vmc.lr_dc) + + # optimization + ps, acc, λ₀, res, σ, x0 = Optimization(opt_vmc.type, wf, ps, st, sam, ham, α) + + # err + verbose && @printf(" %3.d | %.5f | %.5f | %.5f | %.5f | %.3f | %.3f \n", k, λ₀, σ, res, α, acc, sam.Δt) + err_opt[k] = λ₀ + + if res < opt_vmc.tol + break; + end + end + return wf, err_opt, ps +end + + diff --git a/src/vmc/vmc_utils.jl b/src/vmc/vmc_utils.jl new file mode 100644 index 0000000..df57e2c --- /dev/null +++ b/src/vmc/vmc_utils.jl @@ -0,0 +1,23 @@ +using LinearAlgebra + +function InverseLR(ν, lr, lr_dc) + return lr / (1 + ν / lr_dc), ν+1 +end + +function acc_adjust(k::Int, Δt::Number, acc_opt::AbstractVector, acc_range::AbstractVector, acc_step::Int) + if mod(k, acc_step) == 0 + if mean(acc_opt) < acc_range[1] + Δt = Δt * 
exp(1/10 * (mean(acc_opt) - acc_range[1])/acc_range[1]) + elseif mean(acc_opt) > acc_range[2] + Δt = Δt * exp(1/10 * (mean(acc_opt) - acc_range[2])/acc_range[2]) + end + end + return Δt +end + + + + + + + \ No newline at end of file diff --git a/test/ACESchrodingerRef/bftest.json b/test/ACESchrodingerRef/bftest.json deleted file mode 100644 index 53b37fc..0000000 --- a/test/ACESchrodingerRef/bftest.json +++ /dev/null @@ -1 +0,0 @@ -[{"X":[0.865674388682704,0.5593067117763364,0.7174253143031734,0.7441985665675926,0.5142646432433161],"P":[[0.0,-0.007273160835590364,0.016835134978959394,-0.05518907618250436,-0.013991696863045475,0.07727358901069811,-0.012208833323334363,0.00734169703086515,-0.010748681761954645,-0.004935955469421341,-0.002051589510234265,-0.007324816327756153,-0.0012150238303724283,-0.0041011296845646,-0.005622401475235145,-0.00033823061468851224,0.0018562963486714315,-0.0021593636943063716,-0.0012960853248113044,-9.257597403594705e-5,0.00015680119226319837,0.0010062258941460165,0.00020016178871157811,-0.001984019458548015,-0.0007327566145175912,0.0002311968723983094,0.0012465424624981008,-0.0002594792759211216,0.00024282324758588263,-0.0009735266760982919,0.0012701285395614488,-0.03120219973811509,-0.08762242223320123,-0.037451458620417544,0.015111096088431522,-0.012946126612891346,0.000558694985035626,0.001817837590677066,-0.010421408920535407,0.0034639646609274455,0.00079623752475448,-0.0048933534161406216,-0.00390375609672264,-0.004111600580489995,-0.0023695204420709254,0.0025733813382050785,-0.0028635128228424334,-7.909063891721852e-5,0.0029030381101404932,-0.0011625503460157905,-0.0013615615980300479,0.00021332697092271057,-0.00109786953861817,-0.0013941167446012129,-0.001490221295716501,-0.0001204696890062787,0.00015039779402193146,0.0015989019630892932,-0.0005821101180100074,0.0002013464602539328,0.0006329917978056687,-0.006915002618838396,-0.007600718250933177,-0.03858143382430678,-0.0175540066031235,-0.03318559508361785,0.008162432125145763,0
.0021664489646505657,0.006806617189295271,-0.01451556369214969,-0.0086280196386515,-0.004818462018391713,0.0033186913105087367,-0.0005654313147579877,0.0022253635879058787,0.0001690587595510904,-0.002728984530868199,-0.0001950587651646309,-0.0011279545714458744,-0.0022226762547764706,0.0004916217375389001,0.0017909659462801437,0.0006744511701872369,-0.001551745035556326,-0.002895207427547411,-0.0006219214244312627,0.0009350389808923484,0.00016914793131505497,-0.0003452791792570605,-0.0003786816803868763,-0.01724391291318526,0.0015827677190570928,0.019830661143682055,-0.0005066998330636165,0.00020379950526259723,0.0034856005396233227,-0.000837864448065468,-0.00020115503547858517,-0.002460130404689392,0.0034914041464041315,0.0012819208111355205,0.0005390847087198669,0.00033834247158049996,0.00480614497351621,-0.006664690149082625,-0.0007453626878329316,0.00481376099061003,-0.0015549819968863534,-0.0008326186090949145,0.0018187397433018623,-0.001052692255136756,-0.0018153414744282587,-0.0003622576389311419,-0.00046909665822451006,-0.0006824731487953,-0.0007674462395043458,0.001152487981805814,0.00043849707611555937,0.0003175049893582542,-0.019043660671018557,-7.236050985906757e-5,-0.006225567859219423,0.0026038963224999487,-0.002046061988076383,-0.008413650104816227,-0.006267158027232268,0.001275688517479131,0.0015130214146165486,0.002907855885308389,0.0030036266149898513,0.0035278519117723705,-0.0017120549385533433,-0.0030989746933174816,0.0007666331347008639,0.0015189226547655697,0.0025970957249300064,-0.0020262326327436463,-0.0008831971546530497,0.00097209645752064,0.001798515115113,0.0005735720344046434,-0.0007705128076893775,0.0011114996248305841,0.00043085091688689853,0.0015350106998117316,-0.01020334057149607,-0.006671474048155625,-0.006691990485942631,-0.006925657242282331,-0.010209286713439432,-0.004499314807209286,-0.00349994458874293,0.004758877985775156,0.002439109930616411,0.0020473949096695148,-0.005942465920297297,0.0016042183709269917,-0.001224977526222
8415,0.00210158086447522,0.00023977686102514238,0.0008672776381895977,0.002307169063277749,0.001541764645656077,0.0016122659784287463,-0.0005156455435007184,0.00021686290322525572,-0.0004924140087772139,-0.00041270045705382236,6.449869413117129e-5,-0.0009506759863607153,0.00015188792094967778,-0.014407666723961108,0.0027924612931630303,0.0028203461532720804,0.0038875131881751265,0.0046656720472962095,0.003121756952884793,0.00430676164985073,-0.00029668458285071677,-9.718441204005839e-5,-0.0032593092302272315,-0.0013437136275793354,7.475862453128584e-5,0.003180114218341685,-0.001050077202390746,0.002276099273747626,0.0001039980097248826,-0.00016037775522721693,-0.002060579020860672,0.00029309920297411837,0.0006729751887474921,0.0013929414579501016,-0.0007953082056825871,-0.0002467560848502057,0.00030325451502967254,1.1477998920399384e-5,-0.0030218974762138856,-0.0006020517613651615,-0.002648457204957628,0.0009168403790160222,-6.209893505350881e-5,-0.0009408360972925062,-0.0006902561796251357,-0.0006267627356047118,-0.0018815950845586325,0.001547085182744839,0.000779925182126374,0.0006746602298233217,0.0006488089538048761,0.0015598699777386066,-0.0003086159427842343,0.0009293670883072223,-0.001067785662864171,-0.0019115196768754617,-0.0003368689546790447,0.0005845479023218864,0.00023023192073515575,0.0013122170365125702,0.00022606709814607,-0.00031715989281457554,0.0016211182624186814,-0.0010148110061284302,0.0026628512691860397,0.003848579387554078,-0.005605156257441095,-0.0033637582833332272,-0.0011445573172373024,-0.0025159277350477885,-0.001585619490199665,0.00029816093712908853,1.341637085903415e-5,0.0004384131197410367,0.0013383489126430812,-0.0024946429295095323,0.00096931791419475,-0.00025100078458357067,0.0018234251508331873,-0.004305014808757479,-0.006040190820391921,0.0017377368014513215,-0.0015739783203371424,-0.003239626275420088,0.00223614893273366,-0.0018881738863567076,-0.0001427796704715549,-0.0012777487600203947,-0.002531410787272651,-0.0008390546235
009237,0.0016817300735055617,0.0021300202652746376,0.00035628850132810394,-0.0009123377635161678,0.000618372094015173,0.0003406999378993869,0.001146975716821958,-0.0013172196655281774,-0.00015263422279787665,0.004042234168403429,0.0037452560245207402,-0.0007213939922429654,-0.0017252343168027335,0.0038763017419173105,0.0006787826109075411,-0.002340911386781388,-0.003326384460095757,0.00020595580781043344,0.0008929421361456529,-0.0015282751325236027,-0.0002532217139647597,0.0011252257475698126,0.0006188154666556426,0.000953459215436979,0.004475627653389925,-0.0032812080234037727,-0.0006571692913155281,-0.0013525487398722332,0.00194060137273032,-0.001620802951929255,-0.003499570981797633,0.0013122309270225882,-0.001942392898749902,-0.0008416427522398282,-0.0003937568759779808,-0.0014885866863977805,0.0011429989853829776,-0.0027537698924907012,-0.000851885168886742,-0.0032782921770560546,0.0032033836275734916,-0.0008045983269435984,-0.0005743553522711756,0.0026320272590839248,-0.0007435890545893948,0.0021025899252853334,0.00012260003559063013,-0.0010216120209657707,0.0007236231626159587,0.0021342688345096826,-0.00043697061634049135,-0.0009202521266326297,0.002879209006791936,-0.002140183433497032,0.0012134517692996014,-0.0015333183624764112,-0.0014382263138208152,0.00029763766886444507,-0.0009531660521184255,-0.0030912111743179337,-0.0003490169388219783,0.00042498437790248664,0.002130058338787467,0.0008957200665292495,-0.00037629573277469765,-0.002196930423843293,0.001466177008307038,0.0011736716521662536,-0.0016482242593388794,-0.0006018734840737325,0.001365207627867828,0.0023229609284834674,0.0022045806817354847,0.0018295904215531792,0.00022096611277266307,0.000903689171183338,-0.0012449085315198315,0.0005612445670031732,-0.0005762671502525746,-0.0022401928230521757],[0.0,-0.1415014645357159,0.03389722562049187,0.023969457534646535,0.0025789586766420767,0.013696179679857092,0.0004874807053287539,0.0022719882019194095,0.0027875258779416377,-0.011640730117812643,0.0009
41077665747104,0.005427135057118886,-0.007537562580797538,-0.004060026476873175,-0.0016552026056099159,-0.00032880518695875555,0.003718385943319974,0.0004029200354172692,-0.0013051176707870667,0.002439405725676743,0.0024972680098622377,-0.0017040225074165846,-0.0007543674847935729,-0.0007526713922301543,0.0004676740036212627,0.000778633211695623,-0.00028153576981056877,0.0010093164611364228,-0.0013198341530317496,6.359782165865776e-5,0.00042617370429111334,-0.06416247614801514,-0.05189577002323411,-0.020171621608307554,-0.0015242825292753033,0.013559890962133081,-0.00838203941094169,0.004836822934091367,-0.0008460654654378996,-0.007200788787867077,0.0034928318232435186,-0.0038034822741939344,0.0015878310378883353,0.0049451699995633445,0.0017350919453703291,-0.0017048357659562123,-0.0006153755612234325,0.0006945811921420385,-0.0038500833247343985,-0.0021193343060258094,0.0012146748354543754,2.1452229103609626e-5,-0.0012620399744945842,-0.0012228058700368208,-0.0014592033806210979,-0.0006834663476424457,-0.0009550066268569219,-0.0010835891300964228,-0.00010903423324790761,0.0006824577594390323,0.0005083487887822505,-0.033625345647078315,-0.032606979368444235,-0.0011417473414791505,0.0014443128564567342,0.017516251977753056,0.0012301184713490356,-2.9154193785463e-6,0.0017577224352692171,0.003653596685516202,0.00023624488454955346,-0.005923242198205954,0.0016740754963943463,0.0009757378573581285,0.0035911253805730875,0.0025650766807768744,-0.002073173176902431,-0.00030656236121647055,0.0004551403777908367,0.00015230199459836165,-0.00032510222335411035,-0.00024121898907764573,-0.0025802716141693784,0.00028484129938606827,-0.0007395547767799947,-0.0017197557176612584,0.0011972500231066186,0.000641256648915454,0.0003212171122332008,6.222078125770343e-5,-0.01731598416682352,0.008431158003897874,-0.009183858926086213,0.010929844500984033,0.008966934939013101,-0.009537970619365253,0.005757335743731189,-0.005865994219580617,0.004662518646578089,0.0038620501661730693,-0.0006168
612883295073,-0.0019255752063450653,-0.0009215986974067855,0.0004839036480230309,0.0017514838331171608,-0.0016028696652950355,-0.0023923985606763205,-0.00036322863478577483,0.0017569820158912077,-0.0003767069453539647,0.0010454414246169365,0.00035013873988570294,0.0007107553790174087,-0.0008070956008169248,0.0009916482969493832,-0.00022074220067070668,0.0006318929074324685,-0.0007172779253403688,0.002891895359074585,0.004077198149168721,0.003284046969469344,-0.0006428534733165632,-0.008736284584531075,-0.00171509465192322,-0.0014401423693541971,-0.003641494027137925,0.002115925923916403,0.0033065335912449466,0.001208199738344879,-0.004226872902893575,0.0002449078414662611,0.0011423199213950479,-0.0012395864430411324,-0.0008738265032233334,0.0007083171010477472,-0.002123536794353635,0.0003612651905357475,-0.0014442473426582903,0.0004654383171098462,-0.00013659926288265266,-1.8936471649763285e-5,0.0007984005565635783,-0.0013169108708772034,-0.0005885073948266743,-4.827692111346017e-5,0.011551030582564722,-0.007112840538444509,5.5344534124386515e-5,0.0029736351126296275,-0.0006095507011469122,-0.0022727056287102317,0.0009589307011599023,0.004103178694479104,0.0042263223063950745,0.004622334458406624,-0.0001882531493094113,0.0003893512457331008,-0.00012795604771846403,-0.00024398719775172622,-0.00039977983796185204,-0.0012038849640740223,0.0010660724838580714,-0.0006960713096831122,6.34316512798159e-5,-0.0006516188817970727,0.0003890438561351152,-0.00032124553636511953,-0.0010239406095511024,0.0008595547264888112,-0.0004144223915150102,-0.0007578193188752278,0.0004967101170510563,0.0020909959048182763,-0.009355120140933933,-0.0008980593622959161,-0.0004574787831058427,9.743256267623031e-5,-0.004466375242725531,-0.0019794750014059926,-0.0027257323849125725,-0.0006720835337924623,-0.0024556313173180427,-0.0015335596306337803,-0.00032581872238008453,-0.0011599460344633354,-0.001211789767266311,-0.0010063867816401564,0.001735351279051978,0.0008116851164913951,0.000629125475
1834421,0.0005378787812638067,0.0002672681452831687,0.0011770798195594497,0.0008020829210705136,-0.00043979814011494653,0.00035531133077323235,-0.00010827529786401046,0.0019433175143146336,-0.003093826080194325,-0.001487235016792136,0.002719009869361629,0.0025709665029451404,-0.0003353466722323951,-0.00044087809769537524,-0.001245792587253495,-0.002094183795271219,0.0004256034652395511,0.001161230349992884,0.004431291641480512,-0.0009633169671075653,0.002496383196031251,-0.00043758554159473097,0.0009687964115699736,0.00022181374778521053,-0.0002987706966402035,0.0007845514605978501,0.00022116849945458478,-0.0033070389014662267,0.004464086415957261,-0.003493931516370567,0.007708158346440653,-0.000589266003260251,-0.0015641110202145882,-0.002631726665078621,-0.002965835266920086,0.0008853309505307582,0.0009687690136981908,-0.0007507214597020755,-0.0008568549146337563,0.0006632181541579786,-0.0004833078325781586,-0.0005910619677281819,-0.0004826278418138978,-0.00019975616258931573,0.0004476934784564265,-0.0002021350888062385,-0.00036466348087067795,0.004283655598752125,-0.0028421946766189784,0.003593317783542888,0.0020459939287765846,0.00046718310113147687,0.003046468262816395,-0.00032920942992507965,0.0005443835212580235,-0.0041374119225665385,-8.10555056969186e-6,-0.002208170701622244,0.0002552947833695971,-0.0008145383461805036,-3.290341649087815e-5,-0.00034519257318151844,-0.0004570810547912395,-0.0012706805719087822,0.0010866498813543461,0.0004163710588487817,0.0011249396511859365,0.0003881711945219168,0.0009846422412594388,-0.0012042746999701285,-0.0020179831211198174,0.0004147143364935039,-0.0002380366188330622,-0.001913689400311321,0.002074737638719054,-0.0007647655994505971,-4.995452494484828e-5,0.00017041213878006095,0.0007414918572514411,-0.0006333078471490984,-0.0006942484942779821,-0.0016052332884127322,-0.0028935671643886447,0.0010183277298312337,-0.002163335912843784,-0.0003381440849861135,-0.0005022036905737118,-0.001062907643126416,-0.00127500050464497
85,-0.00040014264086854007,0.0002353322515053723,-0.001673088881577215,-0.0007378558464321916,0.0007474990947378235,-0.0010324884572754205,-0.0022876920390454856,-0.0024018236692271017,0.0002700832593224041,-0.0018461956442216913,-0.0007187736794087258,0.0008001844891502288,0.0002207302826278403,-5.501999688902741e-5,7.208716865763236e-5,0.00013385539543922452,0.00033188035875307195,0.00023481013905259752,-0.0003766028049852175,0.0009561643056896107,-0.0013338769215958646,0.0007116849028602084,0.002329535787956947,0.0011229278933551979,-0.0004733601238994762,7.431131037837821e-6,-0.0007216073749576395,0.000529714588039374,-0.00039780257411345903,-0.0006131875286034484,-0.00214512239736763,0.0016627360135492277,0.0011382452794842244,6.423837861671132e-5,4.630065694135529e-5,7.802738027131558e-5,0.0019354199815806057,-0.0010431401026664692,-0.0010036577568998315,-0.0007202916468236336,-0.0007820606723768914,-6.7903850190233715e-6,0.0014193085961399115,0.0008515175723127755,5.387625947839787e-5,-0.0002910719631016457,-0.0005685460068819297,0.0008558063573472564,0.0002852396426840457],[0.0,-0.04402805754322954,-0.011142958961571527,-0.08355185615758533,0.007693057836078031,0.008122049065450365,0.014620488905996001,0.013363446056571266,-0.003909192028721517,-0.016326309914148164,-0.0009084404511918344,-0.000909721499046237,-0.013045523953146959,0.007944791735203387,0.005237424863455664,0.009071443576143916,-0.0034181574157606757,-0.000520000238696533,0.0008654754867416369,0.0004901290187083446,0.00011367666729697018,-0.0019599735293295063,0.0027663935251146746,0.001853162249835104,-0.0003921160809921587,-0.0008333880797629452,-0.0011039630020117331,-0.0005380107882431915,0.0022818292851276363,-0.001408490290052521,0.002372085705611642,-0.06501171552417831,0.03931962838066117,0.06671429363118775,0.04110530626743748,-0.016468920504775393,0.05310087125963052,-0.02156144368731037,0.005416850888297784,-0.00666128483988743,4.403586415969488e-5,-0.005247110590249385,0.005303421
958830963,0.0033519452413559594,0.0004935586089650928,0.0011158361580500085,0.0012066733990109494,8.396064438458496e-5,-0.002590415228648016,0.0025816381050334097,0.002647505292393702,-1.5064779562123335e-5,0.0006931461755098325,-0.0011143324107808922,-0.0015215268555085041,-0.0005886855001973576,-0.0024173864641307465,0.0013905041919884602,0.0004926344598753639,-0.00020832945619871546,0.0005620741216666262,0.009783053438224228,0.08697588820022967,-0.004879529526639961,0.007542049198434136,0.03395731855284114,0.0012835565510867614,-0.009375665969310403,0.010148602202229908,0.0006842783448509161,0.004823590277254532,0.008260571872242954,-0.0033154065756634166,-0.0024863028459776036,0.00327411103440807,-0.001703641358365943,-0.002452753235582638,-0.0003952651975097854,-0.0002143458539703357,4.713807205955768e-5,0.000706683133291111,-0.0013234438411582358,0.0006065750590348115,0.000417423763979268,0.00038964776003984586,-0.0015594685766375203,0.0006788145188039321,0.0012682394085866164,-0.0010839331160297466,-0.000942795045233883,-0.0033156086330046438,0.03204771084118612,-0.006811021836947997,0.012649033563862877,0.006662179493776104,-0.007401725213088014,0.004520970556574989,-0.00159477324750003,0.002382057750515347,-0.004979235186786568,-0.00010056772852928918,0.0014916866676000094,0.0035232939692925462,-0.002270295806538275,-6.34926173722664e-5,0.0005768314239320663,-0.0011290210984718353,-0.003583493988041029,0.004257678732062989,-0.002019148804893395,-0.00048054189884131986,-0.001915026372407728,0.002226001017782774,0.0014360832467830821,-0.001735464787315388,0.0013434057579571543,-0.0013057855682751937,-0.0018876222034419265,-0.010258239959763157,-0.003074382724101865,-0.007447637565539284,-0.0050137566004224915,0.012100434782207037,0.012104921299818178,0.0044915070193090115,1.2555084738962484e-5,-0.006631439879646405,0.0018492364775181095,-0.0006667463849852597,-0.0006769973392898922,-0.00407935434449415,-0.0001324921889132793,-0.00012770304314821735,-0.0020790
062014052067,0.00014192073468858067,0.0009680434531452223,0.001206538736714719,-0.0021861293815579577,0.0005333290902102261,-0.00018666261399317058,-0.000361888195218511,-7.909513539306612e-5,0.0002822125825621509,0.0006889852074675617,0.0009057106931889674,-0.03932349434282124,0.0013937456408790441,0.00779167772664732,0.017228823527280477,0.008307462897016827,-0.004276165621028891,0.003931943480262218,-0.002021544886301114,-0.003051605184114296,0.001652990394435074,-0.0019125164883423559,0.00038060930282128987,0.003343192451703172,0.0038836801418774963,0.0017505816599185623,0.0002864246939532309,-0.0014133675128177306,-0.0014283417161188175,0.00016087011763094182,-0.00035983113464459806,0.0005161047686018014,-0.0011238035739506056,0.00047112103207757716,-0.00045293544575067613,-0.00021089288594467178,0.0002148471261830107,-0.0006493604475529478,0.004842703847682515,-0.007813490082150334,0.005381696453756595,-0.0037133052007379133,0.0008532236705529915,-0.00011489177779547837,-0.0015738134977480492,-0.005373261318093009,-0.001869732349100704,7.388011166879513e-5,-0.0009350600458858547,3.013646394007428e-5,0.0002447844079466207,0.00013505054923966734,0.0012969418066866964,-0.0010376854446819991,-0.0011024019622741027,0.0015575854137105018,-1.187748855846456e-5,-0.001334325434432206,-0.0021191347699637997,-0.0015671937894741327,-0.0004955098365748804,0.00018539770164681111,0.0008493915249244286,-0.008175331207764591,-0.010719814223587898,0.0003931661254402116,-0.0017763814290830164,0.003311708935961154,-0.0004330927517330433,0.00017423467351135674,0.002587179395049181,-0.0010821699950854303,0.0008329475910821957,-0.00129349657377583,-0.001429025370821752,0.0007293106349703788,-0.0005955565586078021,-0.002739854580265446,0.0002384200220269314,0.0012496581397800947,0.0016121812292173568,0.0005730163943619895,-0.0002897753484847473,-0.006771639747668315,-0.004766772912445124,0.0003774473805386835,-0.003401937408364154,0.00112752753245001,-0.0017589574503238396,0.00073478
6784821611,-0.00037199693866514776,2.0491239172717422e-5,-0.002984910932501567,-0.0008654803560837526,-2.6186409968423833e-5,-0.0002225502414050418,-0.0009788027821722909,-0.000621392688684772,-0.0008905316315846754,-0.00017526186728466584,-0.0003341539379904442,0.0009629319081178466,-0.0017473020626600415,0.001184374913867703,-0.0018548226658740682,-0.00021183693867368953,-0.0009751885968951757,-0.0016107053724142718,0.0006604633275967079,0.0014161881691609245,0.0002800127948441434,-0.0004888999839426727,0.006143275067142836,0.00248242927134876,-0.002714376808795551,0.0030490190186851117,-0.0010341942651478529,0.0009150098676030077,0.00014690156554332965,0.001189141581451135,0.0008642327512319273,-0.0006276276801628641,0.00416342244039392,0.00461253216666197,0.0027302865057217056,0.0023642623900998953,0.000797301537349976,-0.0009496142539453641,-0.0026206892723266138,0.003054583381817993,-0.00031724197060036723,-8.379334445510498e-6,0.0005259166337449302,0.001001251562591875,0.0006049032540461225,-0.0004296440216764725,0.0015870454564155814,-0.00210698971181611,0.0017150192162836407,-0.000238657034065277,0.003058733680938888,0.0002708723922554791,-0.0012818049518332948,0.0015919010307478773,0.00048291670407832994,-0.001471674265555419,0.0010580793900541606,-3.261602460114984e-6,-0.00011315579085891213,0.0002142754613707886,0.001882378165803877,0.0019234805354636725,3.3777079039617604e-6,-0.0013696407432787835,-0.0006706186549689767,0.0023200499084230814,-0.0005432800541000589,-2.13892999411673e-5,0.002058854114475561,0.0006472978712984316,0.0007772420388506206,0.0007646638851072101,-0.0009493846413989477,-0.0009312409924910824,0.0006246346717898026,0.002058491285870386,-0.0035527847626472455,0.00033566961151770445,-0.0019841343292748403,0.0031847107298774084,-0.0019893279813137164,0.00038620783357522573,0.0012180373007744153,-0.0007132656701880948,-0.0009309917255598434,-0.0010838891750669986,-0.0005407783805422664,-1.0571773331642237e-5,0.002801791411167163,-0.000
6579644426444493,0.0007738916239531171,-0.001071636884062243,-0.002202576366653914,-0.0022274201750426144,0.001438460330324162,0.0006302693683711359,-0.0008433664136932255,0.000678236164329376,-0.001075976524111195,0.0010435724722433009,-0.0005049818177228774,-0.0010366845334077895,-0.0018339613790994609,0.0008185203684755713],[0.0,0.08386970740551974,-0.0009704306216677576,-0.04592588008463851,0.0028685666346850015,0.011697220999670856,-0.011191614938387965,0.012759375798334523,0.0076885813293625805,-0.010441836843530634,-0.005853702178546021,-0.0038763203121232473,0.009362360450358098,-0.009967831820895345,-0.005260418881412472,-0.0029681372515559606,8.475607294543344e-5,-0.001989187371111607,-0.0009281661927614702,0.0001423988562177603,-0.003236642353563714,-0.002708530300814982,0.001669006292384854,0.0009584925945143888,-0.0011277269311636157,-0.0008964077731046439,0.0006807807387834202,-0.0008432977158589945,8.144121919870023e-5,0.0012541733191314473,0.0017023748281976617,0.08394985986960388,0.08089151516794894,0.048173035412213114,-0.031157233611531644,-0.007280896370582254,0.01535271010143602,0.01172848156913997,-0.008483364758833118,0.0040047205334004425,-0.009007200173387518,0.005639961608809443,0.00684173345333467,0.004341328180363217,0.0032977148482211086,-0.001626259062349758,-0.00011662957063659573,0.003799955976918868,0.0015279553421839642,-0.00039753522937658177,0.0017416131611761685,0.0002762383757838782,-0.002642735215462036,9.957383032887576e-5,-0.0011690958117569698,-0.0008643196858167544,0.0033429550029164534,0.0007656252491542861,0.0018589178554346917,-0.00023470163787052775,-4.5564349221181865e-5,-0.04708021499221474,-0.02461094734598668,-0.008965477341672684,-0.007238796642957464,-0.012613851043720021,-0.0036892716098293112,0.010458235730608256,-0.005975186899136275,0.014523669855506233,-0.00344622262700508,-0.003364473371098094,0.0009159395955967794,-0.006676497212212839,0.0019342230146693405,0.0004574917667275668,-0.001126378219379552,-0.000
10582558083096318,0.0019879815957745336,0.0012152978009537761,-0.001854326006266154,0.0008535902428968038,-0.0027617665414098468,0.001740128007730472,-0.0002144713668677936,-0.003187755930819152,0.0007492188428399304,-0.00021293309234032254,-0.00019522761989200044,-0.0007258970198273656,-0.014956631260894006,-0.014480930921992862,0.005641218962234258,-0.016808840708572068,-0.0027254874054495402,-0.00928719566835116,0.0074661220139690294,0.001354033247278211,0.004933222472434316,0.000987488797272827,-0.004374381275238683,0.0002622807611444714,0.004583898115007326,0.00015870122766491036,0.0010636754461488618,0.0037628891724539935,0.0001523011513075129,-0.002532603804604314,0.00018930804895944342,0.0012530453535384242,-0.002319025756432803,0.00330161666828842,-0.00035504507401571154,-0.0006148453459930125,-0.001898693630071422,-0.0001813342417193562,0.0015752583259168306,-0.002187683131135863,-0.006025965351132038,-0.0001260552501196851,-0.010788184092271438,-0.01774803408872944,-0.010878135830473519,0.0023293875211294907,0.003211729248191183,0.0008676369685897532,-0.0025474869227369735,0.0042323901403919255,-0.005934769198670274,-0.0005773801029342905,0.004644567522015575,-0.00046124697349072223,-0.0017707804336112441,0.00206902921673181,0.0015139748320817678,-0.0029059826751051897,-0.0020962922733804962,-0.0008349597503149107,0.0035615462406532183,0.0004176653669464563,-0.0003107523755837688,0.0006899792019573006,0.00010419563912716009,0.0012295531762552387,-0.0009410602877643348,0.019513975111578274,-0.01398339413236611,-0.0050144803279762425,0.0032568604999901295,-0.0056279839545375155,-0.00022393440317839425,0.0022137768427113232,0.0013132735840286312,-0.0007332576873981875,0.003497472897330408,-0.0027800740407611184,-0.0012935280003663551,0.0010061008694466967,-0.0004508868870308679,-0.0033446699911708587,-0.0006889293869881519,-0.001852824275169578,0.001007650874079352,-0.0013693880271834007,0.001536880444284109,-0.002694917146410994,0.0011672250527397746,-0.001
3247867941921139,0.0008738391224414858,-0.00015240992519973702,0.000930517877203611,-0.012112942480814192,-0.00491030075213,0.000911178940268876,0.004560419912531687,0.005811907896141531,-0.0013727069424115826,0.00045811054767689886,0.001673653858994744,0.0035808693978879717,0.0025178801041032994,0.002302246490716388,0.0050685925324770315,-0.0004221223125181713,-0.000648624705839004,-0.0012792170216180795,-0.001196919171934974,0.00044336963581870204,0.00065154515712236,0.0018339618059007633,0.0005134357838113822,-0.0004902313504978116,-5.72285690893406e-5,-0.0005687510353394116,-0.0025195276719083056,-0.0014450799819706599,-0.005393020128644126,-0.0006524545761657728,-0.0030716438559565675,-0.005321526498393636,-0.002714876438099475,0.007895331999310798,-0.003997047883807325,-0.0007582136620135889,-0.005355539027361355,0.0005753451144473372,-0.0006081422687335,0.001063966442888716,-0.003270757296706621,0.0006574548363340465,0.0015234035605945413,0.0008334992041817737,-0.0012248997152901717,0.0006968252411365405,0.0010735212052443376,-0.0022243402796182266,-0.0013276234743278847,-0.007905413262602525,-0.004283701426925043,-0.002943687529171942,-0.0010329011274543152,-0.0025809997575823975,-0.002561857953901345,-0.0024551518204565664,0.00048602991876391555,-0.00015879011746549327,-0.0019373293385056033,0.0015290705749877022,-0.0013502782511744598,-0.004327981073774667,0.004505639459318477,0.0008116708736342237,0.0007135010956583237,-0.0015816834097709634,-0.001198832377448572,-0.0005161310398885224,-0.0009804922047427346,0.0006744337214528709,-0.0005888197078642406,0.008197202656616784,0.0035149639205623567,-0.004284564573678459,0.0028658598650236344,-0.0012150575437991323,0.0015004292298309004,-0.002884749841116353,-0.003972969937024104,-0.0013145004493110043,-0.0010340417811295574,-0.0034854643798388233,-0.0015646718558845584,-0.00022542987125176702,0.0001580010509333612,-0.0008046643104319553,0.0008595865618353582,-0.002076309364368041,0.004503731623806539,9.000338
536135329e-5,0.0014719216541901005,-0.00023809721975739665,-0.001067825665623421,0.004834179169484606,-8.45209963501152e-5,-0.0020270461624701175,0.0010993332726971427,0.0019282846905620042,0.0008428080251332878,0.000548688975593386,-0.0017879236797723576,-0.0005074793605964643,-0.000820108446125938,0.0005038678104534605,0.0026446923957827865,0.0025042487004194387,0.00363096703343181,0.00025430171332200733,0.001067397218429686,0.0011887892775203177,-0.0014597409694551888,-0.0003873545732979744,0.00022294337065004153,0.001782709120448477,-0.0009657157168968443,-0.0023430486582372797,0.00283446920402062,0.0016249986097077351,-0.0004595559487800244,-0.004396999360243408,0.0022798959646779357,0.0007396997008099389,-0.0018600194770316447,0.0004488038002656552,-0.0004904332718951571,-0.0017400405192807304,0.002406710143100841,-0.0008618779699613526,-0.001193172129634847,0.0006317976673705298,0.00017966878857362205,4.8320864076113224e-5,-0.0008682564303648513,0.0005057191142357151,-0.0013995181609154559,-0.0009428957286025687,0.0009318933104878499,0.0018077042096996628,-0.0015292130806268092,0.001024262942456958,0.005212085913413507,-0.001933593782977388,0.0008320796960362342,0.0010255566339031693,-0.0024562652697863703,0.00027756092078103274,-0.0012496514255268569,0.005542501308233257,0.0011621541630538568,0.0013907263942731648,0.0015505810059378332,-0.0002642823414091117,-0.002013810485467708,-0.002121666676426142,-0.0013219769509399452,0.0015649422329395966,-0.002046314408365764,-9.74143649550335e-5,-0.0032759648441764507,0.0010095299905392905],[0.0,-0.029408724152558186,-0.07961485608575579,0.02429575161434835,0.012695165514308115,0.034068488481416366,0.004738260547612065,-0.014810662920320124,0.016466818076198042,-0.01599656483421443,0.0011903852857444446,0.006565802667208097,0.0042835390856049245,0.0019173407089454434,0.0006695561076745098,0.0009603816933978828,0.0015027008696173873,-0.00036629233732104577,-0.00013771907219151924,-0.0016196597311155367,-0.00081596055
39326211,0.0021723680038954114,-0.002202971410170692,0.0002067039000915997,-0.0013505835324627415,0.00016975100363364763,0.0009294012772919252,-0.0017563765489161222,7.043996297012429e-5,-0.0005201823221314795,-0.0016720491394588781,-0.05019820961521885,0.030601506208131325,-0.026487825613379132,-0.00694758312979253,0.018854068705853363,-0.018853861732462732,0.007725252848314819,0.0015886333072159325,0.00025206450947173404,0.005136216682454122,-0.002760569053680654,0.0026998815750210203,0.00043926049717113143,-0.01138589353137207,-0.00045346931438390846,-0.003970880561849017,0.0018195884621474837,-0.0008243499865267709,-0.00118079344414198,-0.003542703097294454,0.0012291106496810644,0.0029148895627155,0.0025033476044724503,0.0009441961879264032,0.0012477960242394262,-0.0008551776661664053,0.001203813520334674,0.0008883090443908005,0.00017518775429020524,-0.0015112861710620369,0.011133823347798675,0.023148877744152908,0.0001452976963559497,-0.030071051914032556,-0.01039970998333628,0.007089975319574339,-0.0009275967441523235,0.010190114734163959,0.0038978273242153898,-0.0011939769276209148,0.001479477436782824,0.010303052040057464,-0.002378504695930744,-0.0024474054107217133,-0.0015574102653112926,0.0003037315375750598,-0.005150203865968845,0.0018444588288032819,-0.0010183470886383607,-0.0010469442861759865,-0.002177309984464359,-0.0006847821561775338,0.0006051240195543805,0.0002787617778268732,0.0015175161580350505,-0.0019739281151331875,0.0004627477855440448,0.00040091359001521727,0.0006306187330422004,0.07040621307290723,0.007867536843038318,0.01535579495109694,0.00805449132171127,0.015169656785107797,0.00770683970861356,-0.0014789689777540828,-0.002155866634194098,-0.003998601038954876,0.003701042632020009,-0.00011269424268033158,-0.009525295150184385,0.002304384646366992,0.0029793110186764424,-0.0015338397656715645,-0.002507088053703426,-0.0013140568792070653,0.0018741206698632176,-0.00213422617179473,0.0006258227207557781,-0.0016322421574461112,-0.0030104996643
087823,0.0006628968099637554,0.0008337753278824599,0.0015701119337084581,0.00017972646297288597,9.115018399893803e-6,-0.0010693639994815225,-0.0036565907275016266,0.01048954721176878,0.009979825068175179,-0.0021501523876782747,-0.0002571947193638018,-0.0005231113734642835,-0.008381791053287072,0.0036110062181667984,0.01236811658984308,-0.002363760834755232,-0.002049340736992869,0.0034190399802136937,3.095013064701948e-5,0.0006881654687300348,-0.0006459685874635523,5.981300927238139e-5,0.0009929929669522516,-0.0007034139220815446,0.0025947507006204916,-0.0012465678930106438,0.0007298349856393977,-0.001311239016098477,-0.0002709496829527328,0.0003573570583260926,0.00017650325155420783,0.0004312778200459983,0.0002596828077437762,0.007668832630592485,0.0005668885227145245,-0.007445005445540795,-0.002180049158424807,0.004189576444394635,0.007336798795207091,-0.010760264137570938,0.004354057937301244,-0.0021417276148978774,0.00020343679966305575,-0.0011667986148276442,-0.001521548187795829,-0.004072542747398399,-0.0033692565513772188,0.0015418630815053921,0.00012751746427295814,-0.0015908183644575793,0.0014867070682234214,-0.00023476601173898662,2.1392727250567297e-5,0.0015439300440425665,0.0006302750553393902,-0.0011745832126825054,-0.0011917268128131756,0.00016698934751635427,-0.001246909903159716,0.010608565691786264,-0.004123101721741894,-0.003980937183441734,0.00052464458878169,0.004077728857357225,0.0050046230587418895,-0.0007116415348767724,-0.003524126603693882,0.004896585005985333,0.002278768126393503,0.0005344721996156696,0.003961012418138575,0.00025395796499933545,0.0011120121797384098,0.0005336754795107502,-0.0008495722132742464,-0.002129987392663864,-2.3393446854187832e-5,0.00037732769694367486,0.00010266718689826554,0.0001049838555660403,-0.00047712254635120404,-0.0006418516622543081,0.0009096422703404578,0.002196005396419578,0.007097008536745168,2.35839755921466e-6,0.0003232356109497796,0.004112241635083961,-0.002445571490645541,-0.005084662960157807,0.0012
746625108936813,-0.004128282947508196,0.00019133088889600327,0.0018843123826957659,-0.004069364823367721,7.853215200029877e-5,-0.0012148357243647121,0.0004326981650370287,0.0019524404007283242,0.00046665980679603397,0.0023328425160715513,0.001506668866074644,-0.0006260524576469025,0.000751630045850698,0.0009400209322299511,0.0023567058073696093,-0.0021051219255385964,-0.0024695747483556496,-0.0007619377632184466,0.005174178860544337,-0.004239008687786166,0.003322399382719688,-0.004274322055981807,-0.0012111416082931455,0.0026293109799787186,0.0007014015943274976,0.0014882587750042902,0.0022473986538776927,-0.0007260684797390137,-0.0018601255192169858,-0.0025226276740071005,-2.823778125084083e-5,-0.0007475326349349818,-0.0001664865480409446,0.0013555876672334247,0.004517889021508121,-0.003080990384298615,-0.002944226629132014,0.00039422473970278555,-0.0018302926283059334,-0.0009153153730956427,-0.0009196996627754447,0.001182214608830942,0.0010125654826699834,-0.002922331161292741,0.0009748063584593677,0.0013127571762226987,0.0006242878896491025,0.0019185374649353874,0.0003799275583028376,0.0015817012209455483,-7.06858787727385e-5,-0.0003773612494805833,0.0005187857570435263,-0.005941911086557807,0.005347178996690836,-0.0043680735836008665,0.0022371845811890775,0.0003764863510579694,0.0003505903455593586,-0.0022425222468880953,-0.001444535117892683,0.0013382419869145035,0.000965932589276706,-0.00023727747402093013,-3.0971534290143194e-5,-0.00010206701238134864,-0.0007092478170961671,-0.0008259484893786002,0.00015204223937540336,0.0050102217489241254,2.097848647156118e-6,0.0063745142053332916,0.004323809882398249,0.0010068062753886603,-0.0005253075282202657,-0.002365265947834636,-0.0008032393134943724,-7.157078372089952e-5,0.0014606474433231562,-0.0003292806653449654,-9.743060576889866e-5,-0.0009870641358958533,-0.0033757541402940606,0.003445804863635187,-0.004400916716619389,0.00034731670621810816,-0.0019689166585952333,-0.0021550350170804517,0.0024851583069887754,-0.
0013706845149323353,-0.0005462649105470952,0.0005302940174428687,-0.0013759041802922235,-0.0006509915292246132,-0.001729754701923505,0.000409807659427466,0.00014575905523010274,-0.0014846375025426633,0.0018183882442078555,-0.002363980622635043,-0.0010981201010358212,-0.0019384169361203842,3.7819805769491953e-5,-0.00039790593744764316,-0.0015708995366575157,-0.0007170786048079367,-0.000461402403979508,0.002020057264371874,-0.000500210646455958,0.00011327896585982094,-0.0006232406487397653,0.00016235786985906031,-0.00045079067914616624,-0.0020334490888162665,-0.00023768097308445344,0.0007160088437790237,-0.0012011295069093027,-0.0015064500353723694,-0.0016361780988608278,-0.002101848634816187,0.0007902590684578514,-0.001362987501192283,0.0016024013681653398,-0.00039855329053647943,0.0006006279265734773]],"refval":-39.00416165865923}] \ No newline at end of file diff --git a/test/ACESchrodingerRef/orthopolyweights.json b/test/ACESchrodingerRef/orthopolyweights.json deleted file mode 100644 index 7d619e8..0000000 --- a/test/ACESchrodingerRef/orthopolyweights.json +++ /dev/null @@ -1 +0,0 @@ 
-[{"ww":[0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.003333333
3333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333
333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.00333
33333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335,0.0033333333333333335],"tdf":[-1.5655603390389137,-1.5550883635269477,-1.5446163880149817,-1.5341444125030157,-1.5236724369910497,-1.5132004614790837,-1.5027284859671177,-1.4922565104551517,-1.481784534943186,-1.47131255943122,-1.460840583919254,-1.450368608407288,-1.439896632895322,-1.429424657383356,-1.41895268187139,-1.408480706359424,-1.398008730847458,-1.387536755335492,-1.377064779823526,-1.36659280431156,-1.356120828799594,-1.345648853287628,-1.335176877775662,-1.324704902263696,-1.3142329267517303,-1.3037609512397643,-1.2932889757277983,-1.2828170002158323,-1.2723450247038663,-1.2618730491919004,-1.2514010736799344,-1.2409290981679684,-1.2304571226560024,-1.2199851471440364,-1.2095131716320704,-1.1990411961201044,-1.1885692206081384,-1.1780972450961724,-1.1676252695842064,-1.1571532940722404,-1.1466813185602747,-1.1362093430483087,-1.1257373675363427,-1.1152653920243767,-1.1047934165124107,-1.0943214410004447,-1.0838494654884787,-1.0733774899765127,-1.0629055144645467,-1.0524335389525807,-1.0419615634406147,-1.0314895879286488,-1.0210176124166828,-1.0105456369047168,-1.0000736613927508,-0.9896016858807849,-0.9791297103688189,-0.9686577348568529,-0.9581857593448869,-0.9477137838329209,-0.937241808320955,-0.9267698328089891,-0.9162978572970231,-0.9058258817850571,-0.8953539062730911,-0.8848819307611251,-0.8744099552491591,-0.8639379797371932,-0.8534660042252272,-0.8429940287132612,-0.
8325220532012952,-0.8220500776893293,-0.8115781021773633,-0.8011061266653973,-0.7906341511534313,-0.7801621756414654,-0.7696902001294994,-0.7592182246175334,-0.7487462491055674,-0.7382742735936014,-0.7278022980816354,-0.7173303225696694,-0.7068583470577035,-0.6963863715457376,-0.6859143960337716,-0.6754424205218056,-0.6649704450098396,-0.6544984694978736,-0.6440264939859076,-0.6335545184739416,-0.6230825429619756,-0.6126105674500097,-0.6021385919380438,-0.5916666164260778,-0.5811946409141118,-0.5707226654021458,-0.5602506898901798,-0.5497787143782138,-0.5393067388662478,-0.5288347633542819,-0.5183627878423159,-0.5078908123303499,-0.49741883681838395,-0.48694686130641796,-0.47647488579445196,-0.466002910282486,-0.45553093477052004,-0.44505895925855404,-0.43458698374658805,-0.4241150082346221,-0.4136430327226561,-0.40317105721069013,-0.39269908169872414,-0.3822271061867582,-0.3717551306747922,-0.3612831551628262,-0.3508111796508602,-0.3403392041388943,-0.3298672286269283,-0.3193952531149623,-0.30892327760299637,-0.2984513020910304,-0.2879793265790644,-0.2775073510670984,-0.26703537555513246,-0.25656340004316647,-0.24609142453120048,-0.2356194490192345,-0.22514747350726852,-0.21467549799530256,-0.20420352248333656,-0.1937315469713706,-0.1832595714594046,-0.17278759594743864,-0.16231562043547265,-0.1518436449235067,-0.1413716694115407,-0.13089969389957473,-0.12042771838760874,-0.10995574287564276,-0.09948376736367678,-0.0890117918517108,-0.07853981633974483,-0.06806784082777885,-0.05759586531581288,-0.0471238898038469,-0.03665191429188092,-0.026179938779914945,-0.015707963267948967,-0.005235987755982989,0.005235987755982989,0.015707963267948967,0.026179938779914945,0.03665191429188092,0.0471238898038469,0.05759586531581288,0.06806784082777885,0.07853981633974483,0.0890117918517108,0.09948376736367678,0.10995574287564276,0.12042771838760874,0.13089969389957473,0.1413716694115407,0.1518436449235067,0.16231562043547265,0.17278759594743864,0.1832595714594046,0.1937315469713
706,0.20420352248333656,0.21467549799530256,0.22514747350726852,0.2356194490192345,0.24609142453120048,0.25656340004316647,0.26703537555513246,0.2775073510670984,0.2879793265790644,0.2984513020910304,0.30892327760299637,0.3193952531149623,0.3298672286269283,0.3403392041388943,0.3508111796508602,0.3612831551628262,0.3717551306747922,0.3822271061867582,0.39269908169872414,0.40317105721069013,0.4136430327226561,0.4241150082346221,0.43458698374658805,0.44505895925855404,0.45553093477052004,0.466002910282486,0.47647488579445196,0.48694686130641796,0.49741883681838395,0.5078908123303499,0.5183627878423159,0.5288347633542819,0.5393067388662478,0.5497787143782138,0.5602506898901798,0.5707226654021458,0.5811946409141118,0.5916666164260778,0.6021385919380438,0.6126105674500097,0.6230825429619756,0.6335545184739416,0.6440264939859076,0.6544984694978736,0.6649704450098396,0.6754424205218056,0.6859143960337716,0.6963863715457376,0.7068583470577035,0.7173303225696694,0.7278022980816354,0.7382742735936014,0.7487462491055674,0.7592182246175334,0.7696902001294994,0.7801621756414654,0.7906341511534313,0.8011061266653973,0.8115781021773633,0.8220500776893293,0.8325220532012952,0.8429940287132612,0.8534660042252272,0.8639379797371932,0.8744099552491591,0.8848819307611251,0.8953539062730911,0.9058258817850571,0.9162978572970231,0.9267698328089891,0.937241808320955,0.9477137838329209,0.9581857593448869,0.9686577348568529,0.9791297103688189,0.9896016858807849,1.0000736613927508,1.0105456369047168,1.0210176124166828,1.0314895879286488,1.0419615634406147,1.0524335389525807,1.0629055144645467,1.0733774899765127,1.0838494654884787,1.0943214410004447,1.1047934165124107,1.1152653920243767,1.1257373675363427,1.1362093430483087,1.1466813185602747,1.1571532940722404,1.1676252695842064,1.1780972450961724,1.1885692206081384,1.1990411961201044,1.2095131716320704,1.2199851471440364,1.2304571226560024,1.2409290981679684,1.2514010736799344,1.2618730491919004,1.2723450247038663,1.2828170002158323,1.29328
89757277983,1.3037609512397643,1.3142329267517303,1.324704902263696,1.335176877775662,1.345648853287628,1.356120828799594,1.36659280431156,1.377064779823526,1.387536755335492,1.398008730847458,1.408480706359424,1.41895268187139,1.429424657383356,1.439896632895322,1.450368608407288,1.460840583919254,1.47131255943122,1.481784534943186,1.4922565104551517,1.5027284859671177,1.5132004614790837,1.5236724369910497,1.5341444125030157,1.5446163880149817,1.5550883635269477,1.5655603390389137]}] \ No newline at end of file diff --git a/test/compare_bflow.jl b/test/compare_bflow.jl deleted file mode 100644 index 6fd5175..0000000 --- a/test/compare_bflow.jl +++ /dev/null @@ -1,34 +0,0 @@ -using Polynomials4ML, ACEcore, ACEbase -using ACEpsi: BFwf, gradient, evaluate, envelopefcn -using JSON -using Printf -using LinearAlgebra - -const ↑, ↓, ∅ = '↑','↓','∅' -using JSON - -# == test configs == -data_ortho = JSON.parse(open("test/ACESchrodingerRef/orthopolyweights.json")) # import weights from ACESchrodinger.jl -data = JSON.parse(open("test/ACESchrodingerRef/bftest.json")) # import input electron position and parameter of model -ww = data_ortho[1]["ww"] -xx = data_ortho[1]["tdf"] -X = data[1]["X"] -PP = data[1]["P"] -refval = data[1]["refval"] -ww = Float64.(ww) -xx = Float64.(xx) -Σ = [↑, ↑, ↓, ↓, ↓]; -Nel = 5 -WW = DiscreteWeights(xx, ww) -polys = orthpolybasis(10, WW) - -wf = BFwf(Nel, polys; ν=2, totdeg = 10, trans = atan, envelope = envelopefcn(x -> sqrt(x^2 + 1), 0.5)) -# == - -for i = 1:5 - wf.W[:, i] = PP[i][2:end] # the first entry of PP[i] is the extra constant -end -println("ACESchrodinger.BFwf - ACEpsi.BFwf: ", wf(X, Σ) - refval) -spec1p = [ (k, σ) for σ in [1, 2, 3] for k in 1:length(polys) ] # (1, 2, 3) = (∅, ↑, ↓); -spec1p = sort(spec1p, by = b->b[1]) -@show displayspec(wf.spec, spec1p) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 62daed3..84817bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,8 +2,8 @@ using ACEpsi 
using Test @testset "ACEpsi.jl" begin - - @testset "BFwf" begin include("test_bflow.jl") end - + @testset "BFwf_lux" begin include("test_bflow_lux.jl") end + @testset "AtomicOrbitals" begin include("test_atomicorbitals.jl") end + @testset "AtomicOrbitalsBasis" begin include("test_atorbbasis.jl") end end diff --git a/test/test_admissiible.jl b/test/test_admissiible.jl deleted file mode 100644 index 104865d..0000000 --- a/test/test_admissiible.jl +++ /dev/null @@ -1,20 +0,0 @@ -using Polynomials4ML, ACEcore, ACEpsi, ACEbase, Printf -using ACEpsi: BFwf, gradient, evaluate, laplacian, envelopefcn, displayspec -using LinearAlgebra -using BenchmarkTools -using JSON -const ↑, ↓, ∅ = '↑','↓','∅' -Σ = [↑, ↑, ↑, ↓, ↓]; -Nel = 5 -polys = legendre_basis(16) -MaxDeg = [16, 5] - - -test_ad = bb -> (@show bb; @show all([bb[i][1] < MaxDeg[i] for i = 1:length(bb)]); (length(bb) == 0 || all([bb[i][1] <= MaxDeg[length(bb)] for i = 1:length(bb)]))) -wf = BFwf(Nel, polys; ν=2, sd_admissible = test_ad) -@show length(wf.polys) -LL = displayspec(wf) -@show LL -for i = 1:length(LL) - @show LL[i] -end diff --git a/test/test_atomicorbitals.jl b/test/test_atomicorbitals.jl new file mode 100644 index 0000000..64f0f46 --- /dev/null +++ b/test/test_atomicorbitals.jl @@ -0,0 +1,52 @@ + + +using Polynomials4ML, ForwardDiff, Test, ACEpsi +using Polynomials4ML.Testing: println_slim, print_tf +using Polynomials4ML: evaluate, evaluate_ed, evaluate_ed2 + +# -------------- RnlExample ----------------- +@info("Testing RnlExample basis") +bRnl = ACEpsi.AtomicOrbitals.RnlExample(5) + +rr = 2 * rand(10) .- 1 +Rnl = evaluate(bRnl, rr) +Rnl1, dRnl1 = evaluate_ed(bRnl, rr) +Rnl2, dRnl2, ddRnl2 = evaluate_ed2(bRnl, rr) + +fdRnl = vcat([ ForwardDiff.derivative(r -> evaluate(bRnl, [r,]), r) + for r in rr ]...) +fddRnl = vcat([ ForwardDiff.derivative(r -> evaluate_ed(bRnl, [r,])[2], r) + for r in rr ]...) 
+ +println_slim(@test Rnl ≈ Rnl1 ≈ Rnl2) +println_slim(@test dRnl1 ≈ dRnl2 ≈ fdRnl) +println_slim(@test ddRnl2 ≈ fddRnl) + + +# -------------- **** ----------------- +using ACEbase.Testing: fdtest +using Zygote + +@info("Test rrule") +using LinearAlgebra: dot + +for ntest = 1:30 + local rr + local uu + local Rnl + local u + + rr = 2 .* randn(10) .- 1 + uu = 2 .* randn(10) .- 1 + _rr(t) = rr + t * uu + Rnl = evaluate(bRnl, rr) + u = randn(size(Rnl)) + F(t) = dot(u, evaluate(bRnl, _rr(t))) + dF(t) = begin + val, pb = Zygote.pullback(evaluate, bRnl, _rr(t)) + ∂BB = pb(u)[2] # pb(u)[1] returns NoTangent() for basis argument + return sum( dot(∂BB[i], uu[i]) for i = 1:length(uu) ) + end + print_tf(@test fdtest(F, dF, 0.0; verbose = false)) +end +println() \ No newline at end of file diff --git a/test/test_atorbbasis.jl b/test/test_atorbbasis.jl new file mode 100644 index 0000000..3bfa7d3 --- /dev/null +++ b/test/test_atorbbasis.jl @@ -0,0 +1,119 @@ +using ACEpsi, Polynomials4ML, StaticArrays, Test +using Polynomials4ML: natural_indices, degree, SparseProduct +using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate, AtomicOrbitalsBasisLayer +using ACEpsi: extspins +using ACEpsi: BackflowPooling +using ACEbase.Testing: print_tf, fdtest +using LuxCore +using Random +using Zygote + +# test configs +Rnldegree = 4 +Ylmdegree = 4 +totdegree = 8 +Nel = 5 +X = randn(SVector{3, Float64}, Nel) +Σ = rand(spins(), Nel) + +nuclei = [ Nuc(3 * rand(SVector{3, Float64}), 1.0) for _=1:3 ] +## + +# Defining AtomicOrbitalsBasis +n1 = 5 +n2 = 1 +Pn = Polynomials4ML.legendre_basis(n1+1) +spec = [(n1 = n1, n2 = n2, l = l) for n1 = 1:n1 for n2 = 1:n2 for l = 0:n1-1] +ζ = 10 * rand(length(spec)) +Dn = SlaterBasis(ζ) +bRnl = AtomicOrbitalsRadials(Pn, Dn, spec) +bYlm = RYlmBasis(Ylmdegree) +spec1p = make_nlms_spec(bRnl, bYlm; totaldegree = totdegree) + +# define basis and pooling operations +prodbasis_layer = ACEpsi.AtomicOrbitals.ProductBasisLayer(spec1p, bRnl, bYlm) +aobasis_layer = 
ACEpsi.AtomicOrbitals.AtomicOrbitalsBasisLayer(prodbasis_layer, nuclei) + +pooling = BackflowPooling(aobasis_layer) +pooling_layer = ACEpsi.lux(pooling) + +println() +@info("Test evaluate ProductBasisLayer") +ps1, st1 = LuxCore.setup(MersenneTwister(1234), prodbasis_layer) +bϕnlm, st1 = prodbasis_layer(X, ps1, st1) + +@info("Test evaluate AtomicOrbitalsBasis") +ps, st = LuxCore.setup(MersenneTwister(1234), aobasis_layer) +bϕnlm, st = aobasis_layer(X, ps, st) + +@info("Test BackflowPooling") +A = pooling(bϕnlm, Σ) + +println() + + +## +@info("Check get_spec is working") +spec = ACEpsi.AtomicOrbitals.get_spec(aobasis_layer, spec1p) + + +@info("Test evaluation by manual construction") +using LinearAlgebra: norm +bYlm_ = RYlmBasis(totdegree) +Nnlm = length(aobasis_layer.prodbasis.sparsebasis) +Nnuc = length(aobasis_layer.nuclei) + +for I = 1:Nnuc + XI = X .- Ref(aobasis_layer.nuclei[I].rr) + xI = norm.(XI) + Rnl = evaluate(bRnl, xI) + Ylm = evaluate(bYlm_, XI) + for k = 1:Nnlm + nlm = aobasis_layer.prodbasis.sparsebasis.spec[k] + iR = nlm[1] + iY = nlm[2] + + for i = 1:Nel + for (is, s) in enumerate(ACEpsi.extspins()) + a1 = A[i, is, I, k] + + if s in [↑, ↓] + a2 = sum( Rnl[j, iR] * Ylm[j, iY] * (Σ[j] == s) * (1 - (j == i)) for j = 1:Nel ) + else # s = ∅ + a2 = Rnl[i, iR] * Ylm[i, iY] + end + # println("(i=$i, σ=$s, I=$I, n=$(nlm.n), l=$(nlm.l), m=$(nlm.m)) -> ", abs(a1 - a2)) + print_tf(@test a1 ≈ a2) + end + end + end +end +println() + +# +@info("---------- rrule tests ----------") +using LinearAlgebra: dot + +@info("BackFlowPooling rrule") +for ntest = 1:30 + local testϕnlm + testϕnlm = randn(size(bϕnlm)) + bdd = randn(size(bϕnlm)) + _BB(t) = testϕnlm + t * bdd + bA2 = pooling(testϕnlm, Σ) + u = randn(size(bA2)) + F(t) = dot(u, pooling(_BB(t), Σ)) + dF(t) = begin + val, pb = ACEpsi._rrule_evaluate(pooling, _BB(t), Σ) + ∂BB = pb(u) + return dot(∂BB, bdd) + end + print_tf(@test fdtest(F, dF, 0.0; verbose=false)) +end +println() + +@info("Checking Zygote running 
correctly") +val, pb = Zygote.pullback(pooling, bϕnlm, Σ) +val1, pb1 = ACEpsi._rrule_evaluate(pooling, bϕnlm, Σ) +@assert val1 ≈ val1 +@assert pb1(val) ≈ pb(val)[1] # pb(val)[2] is for Σ with no pb \ No newline at end of file diff --git a/test/test_bflow.jl b/test/test_bflow.jl index 29511b3..f8d306f 100644 --- a/test/test_bflow.jl +++ b/test/test_bflow.jl @@ -1,10 +1,9 @@ -using Polynomials4ML, ACEcore, ACEbase, Printf, ACEpsi -using ACEpsi: BFwf, gradient, evaluate, laplacian +using Polynomials4ML, ACEbase, Printf, ACEpsi +using ACEpsi: BFwf1, gradient, laplacian using LinearAlgebra -#using Random -#Random.seed!(123) -## + + function lap_test(f, Δf, X) F = f(X) ΔF = Δf(X) @@ -63,7 +62,7 @@ function grad_test(f, df, X) end "This function should be removed later to test in a nicer way..." -function fdtest(F, Σ, dF, x::AbstractVector; h0 = 1.0, verbose=true) +function _fdtest(F, Σ, dF, x::AbstractVector; h0 = 1.0, verbose=true) errors = Float64[] E = F(x, Σ) dE = dF @@ -99,9 +98,8 @@ function fdtest(F, Σ, dF, x::AbstractVector; h0 = 1.0, verbose=true) const ↑, ↓, ∅ = '↑','↓','∅' Nel = 5 polys = legendre_basis(8) -wf = BFwf(Nel, polys; ν=3) - -X = 2 * rand(Nel) .- 1 +wf = BFwf1(Nel, polys; ν = 3) +X = randn(Nel) Σ = rand([↑, ↓], Nel) wf(X, Σ) @@ -111,10 +109,9 @@ g = gradient(wf, X, Σ) using LinearAlgebra using Printf -#using ACEbase.Testing: fdtest @info("Fd test of gradient w.r.t. 
X") -fdtest(wf, Σ, g, X) +@test _fdtest(wf, Σ, g, X) # ## @@ -211,8 +208,8 @@ grad_test3(Fp, dFp, ξ0) @info("Test getting/setting parameters") -wf1 = BFwf(Nel, polys; ν=3) -wf2 = BFwf(Nel, polys; ν=3) +wf1 = BFwf1(Nel, polys; ν=3) +wf2 = BFwf1(Nel, polys; ν=3) @printf(" wf1 - wf2: %f \n", abs(wf1(X, Σ) - wf2(X, Σ))) param1 = ACEpsi.get_params(wf1) wf2 = ACEpsi.set_params!(wf2, param1) @@ -220,6 +217,3 @@ wf2 = ACEpsi.set_params!(wf2, param1) ## -@warn("removed compac test since json file is missing") -# @info("Test compatibility with ACESchrodinger") # Jerry: Not sure if this should be kept in the same file -# include("compare_bflow.jl") diff --git a/test/test_bflow_lux.jl b/test/test_bflow_lux.jl new file mode 100644 index 0000000..e0d7d05 --- /dev/null +++ b/test/test_bflow_lux.jl @@ -0,0 +1,231 @@ +using ACEpsi, Polynomials4ML, StaticArrays, Test +using Polynomials4ML: natural_indices, degree, SparseProduct +using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate +using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow +using ACEpsi.vmc: gradient, laplacian, grad_params +using ACEbase.Testing: print_tf, fdtest +using LuxCore +using Lux +using Zygote +using Optimisers # mainly for the destrcuture(ps) function +using Random +using Printf +using LinearAlgebra +using BenchmarkTools + +using HyperDualNumbers: Hyper + +function grad_test2(f, df, X::AbstractVector) + F = f(X) + ∇F = df(X) + nX = length(X) + EE = Matrix(I, (nX, nX)) + + for h in 0.1.^(3:12) + gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] + @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) + end +end + +Rnldegree = n1 = 2 +Ylmdegree = 3 +totdegree = 3 +Nel = 2 +X = randn(SVector{3, Float64}, Nel) +Σ = rand(spins(), Nel) +nuclei = [ Nuc(3 * rand(SVector{3, Float64}), 1.0) for _=1:3 ] + +# wrap it as HyperDualNumbers +x2dualwrtj(x, j) = SVector{3}([Hyper(x[i], i == j, i == j, 0) for i = 1:3]) +hX = [x2dualwrtj(x, 0) for x in X] +hX[1] = x2dualwrtj(X[1], 1) # test eval for grad wrt x coord 
of first elec + +## + +# Defining AtomicOrbitalsBasis +n2 = 2 +Pn = Polynomials4ML.legendre_basis(n1+1) +spec = [(n1 = n1, n2 = n2, l = l) for n1 = 1:n1 for n2 = 1:n2 for l = 0:n1-1] +ζ = rand(length(spec)) +Dn = GaussianBasis(ζ) +bRnl = AtomicOrbitalsRadials(Pn, Dn, spec) +bYlm = RYlmBasis(Ylmdegree) + +# setup state +BFwf_chain, spec, spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = 2) +ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ) + +## + +@info("Test evaluate") +A1 = BFwf_chain(X, ps, st) +hA1 = BFwf_chain(hX, ps, st) + +print_tf(@test hA1[1].value ≈ A1[1]) + +println() + +## +F(X) = BFwf_chain(X, ps, st)[1] + +# @profview let F = F, X = X +# for i = 1:10_000 +# F(X) +# end +# end + +# @btime F(X) + + +## + +@info("Test ∇ψ w.r.t. X") +ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ) +y, st = Lux.apply(BFwf_chain, X, ps, st) + +F(X) = BFwf_chain(X, ps, st)[1] +dF(X) = Zygote.gradient(x -> BFwf_chain(x, ps, st)[1], X)[1] +fdtest(F, dF, X, verbose = true) + +## + +@info("Test consistency with HyperDualNumbers") +for _ = 1:30 + local X = randn(SVector{3, Float64}, Nel) + local Σ = rand(spins(), Nel) + local hdF = [zeros(3) for _ = 1:Nel] + local hX = [x2dualwrtj(x, 0) for x in X] + for i = 1:3 + for j = 1:Nel + hX[j] = x2dualwrtj(X[j], i) # ∂Ψ/∂xj_{i} + hdF[j][i] = BFwf_chain(hX, ps, st)[1].epsilon1 + hX[j] = x2dualwrtj(X[j], 0) + end + end + print_tf(@test dF(X) ≈ hdF) +end +println() + +## + +@info("Test ∇ψ w.r.t. 
parameters") +p = Zygote.gradient(p -> BFwf_chain(X, p, st)[1], ps)[1] +p, = destructure(p) + +W0, re = destructure(ps) +Fp = w -> BFwf_chain(X, re(w), st)[1] +dFp = w -> ( gl = Zygote.gradient(p -> BFwf_chain(X, p, st)[1], ps)[1]; destructure(gl)[1]) +grad_test2(Fp, dFp, W0) + +## + +@info("Test consistency when input isa HyperDualNumbers") +#hp = Zygote.gradient(p -> BFwf_chain(hX, p, st)[1], ps)[1] + +#hp, = destructure(hp) +#P = similar(p) +#for i = 1:length(P) +# P[i] = hp[i].value +#end + +#print_tf(@test P ≈ p) + +#println() + +## + +@info("Test Δψ w.r.t. X using HyperDualNumbers") +X = randn(SVector{3, Float64}, Nel) +XX = [Vector(x) for x in X] +hX = [x2dualwrtj(x, 0) for x in X] +Σ = rand(spins(), Nel) +F(x) = BFwf_chain(x, ps, st)[1] + + +function ΔF(x) + ΔΨ = 0.0 + hX = [x2dualwrtj(xx, 0) for xx in x] + for i = 1:3 + for j = 1:Nel + hX[j] = x2dualwrtj(x[j], i) # ∂Φ/∂xj_{i} + ΔΨ += BFwf_chain(hX, ps, st)[1].epsilon12 + hX[j] = x2dualwrtj(x[j], 0) + end + end + return ΔΨ +end + +Δ1 = ΔF(X) +f0 = F(X) + +for h in 0.1.^(1:8) + Δfh = 0.0 + for i = 1:Nel + for j = 1:3 + XΔX_add, XΔX_sub = deepcopy(XX), deepcopy(XX) + XΔX_add[i][j] += h + XΔX_sub[i][j] -= h + XΔX_add = [SVector{3, Float64}(x) for x in XΔX_add] + XΔX_sub = [SVector{3, Float64}(x) for x in XΔX_sub] + Δfh += (F(XΔX_add) - f0) / h^2 + Δfh += (F(XΔX_sub) - f0) / h^2 + end + end + @printf(" %.1e | %.2e \n", h, abs(Δfh - Δ1)) +end + +## + +@info("Test gradp Δψ using HyperDualNumbers") +g_bchain = xx -> Zygote.gradient(p -> BFwf_chain(xx, p, st)[1], ps)[1] +# g_bchain(hX) + +using ACEpsi: zero! 
+using HyperDualNumbers + +function grad_lap(g_bchain, x) + function _mapadd!(f, dest::NamedTuple, src::NamedTuple) + for k in keys(dest) + _mapadd!(f, dest[k], src[k]) + end + return nothing + end + _mapadd!(f, dest::Nothing, src) = nothing + _mapadd!(f, dest::AbstractArray, src::AbstractArray) = + map!((s, d) -> d + f(s), dest, src, dest) + + Δ = zero!(g_bchain(x)) + hX = [x2dualwrtj(xx, 0) for xx in x] + for i = 1:3 + for j = 1:length(x) + hX[j] = x2dualwrtj(x[j], i) + _mapadd!(ε₁ε₂part, Δ, g_bchain(hX)) + hX[j] = x2dualwrtj(x[j], 0) + end + end + return Δ +end + +function ΔF(x, ps) + ΔΨ = 0.0 + hX = [x2dualwrtj(xx, 0) for xx in x] + for i = 1:3 + for j = 1:Nel + hX[j] = x2dualwrtj(x[j], i) # ∂Φ/∂xj_{i} + ΔΨ += BFwf_chain(hX, ps, st)[1].epsilon12 + hX[j] = x2dualwrtj(x[j], 0) + end + end + return ΔΨ +end + +function ∇ΔF(x, ps) + g_bchain = xx -> Zygote.gradient(p -> BFwf_chain(xx, p, st)[1], ps)[1] + p, = destructure(grad_lap(g_bchain, x)) + return p +end + +#W0, re = destructure(ps) +#Fp = w -> ΔF(X, re(w)) +#dFp = w -> ∇ΔF(X, re(w)) +#fdtest(Fp, dFp, W0) \ No newline at end of file diff --git a/tmp/test_matmul.jl b/tmp/test_matmul.jl new file mode 100644 index 0000000..ffa3473 --- /dev/null +++ b/tmp/test_matmul.jl @@ -0,0 +1,92 @@ +using LinearAlgebra +using BenchmarkTools +using ACEpsi, Polynomials4ML +using ACEpsi: setupBFState, Nuc, BFwf_lux +using ACEpsi.AtomicOrbitals: make_nlms_spec +using Polynomials4ML.Utils: gensparse +using Lux +using Random +using StaticArrays + +function BFwf_lux_AA(Nel::Integer, bRnl, bYlm, nuclei; totdeg = 15, + ν = 3, T = Float64, + sd_admissible = bb -> (true), + envelope = x -> x) # enveolpe to be replaced by SJ-factor + + spec1p = make_nlms_spec(bRnl, bYlm; + totaldegree = totdeg) + + # size(X) = (nX, 3); length(Σ) = nX + # aobasis = AtomicOrbitalsBasis(bRnl, bYlm; totaldegree = totdeg, nuclei = nuclei, ) + + # define sparse for n-correlations + tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] + default_admissible = bb -> 
(length(bb) == 0) || (sum(b[1] - 1 for b in bb ) <= totdeg) + + specAA = gensparse(; NU = ν, tup2b = tup2b, admissible = default_admissible, + minvv = fill(0, ν), + maxvv = fill(length(spec1p), ν), + ordered = true) + spec = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))] + + # further restrict + spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])] + + # define n-correlation + corr1 = Polynomials4ML.SparseSymmProd(spec) + + # ----------- Lux connections --------- + # AtomicOrbitalsBasis: (X, Σ) -> (length(nuclei), nX, length(spec1)) + prodbasis_layer = ACEpsi.AtomicOrbitals.ProductBasisLayer(spec1p, bRnl, bYlm) + aobasis_layer = ACEpsi.AtomicOrbitals.AtomicOrbitalsBasisLayer(prodbasis_layer, nuclei) + + # BackFlowPooling: (length(nuclei), nX, length(spec1 from totaldegree)) -> (nX, 3, length(nuclei), length(spec1)) + pooling = ACEpsi.BackflowPooling(aobasis_layer) + pooling_layer = ACEpsi.lux(pooling) + + reshape_func = x -> reshape(x, (size(x, 1), prod(size(x)[2:end]))) + + # (nX, 3, length(nuclei), length(spec1 from totaldegree)) -> (nX, length(spec)) + corr_layer = Polynomials4ML.lux(corr1) + return Chain(; ϕnlm = aobasis_layer, bA = pooling_layer, reshape = WrappedFunction(reshape_func), + bAA = corr_layer, transpose_layer = WrappedFunction(transpose)) +end + +Rnldegree = 4 +Ylmdegree = 4 +totdegree = 8 +Nel = 5 +X = randn(SVector{3, Float64}, Nel) +Σ = rand(spins(), Nel) +nuclei = [ Nuc(3 * rand(SVector{3, Float64}), 1.0) for _=1:3 ] + +## + +# Defining AtomicOrbitalsBasis +bRnl = ACEpsi.AtomicOrbitals.RnlExample(Rnldegree) +bYlm = RYlmBasis(Ylmdegree) + +# setup state +BFwf_chain = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = 2) +ps, st = setupBFState(MersenneTwister(1234), BFwf_chain, Σ) +F(X) = BFwf_chain(X, ps, st)[1] + +# Define a shorter Chain up to AA +BFwf_AA_chain = BFwf_lux_AA(Nel, bRnl, bYlm, nuclei; totdeg = totdegree, ν = 2) +ps2, st2 = setupBFState(MersenneTwister(1234), BFwf_AA_chain, Σ) +F2(X) = 
BFwf_AA_chain(X, ps2, st2)[1] + +AA = F2(X) + +# getting weight from dense_layer +W = ps.branch.bf.hidden1.W +typeof(AA) # Transpose{Float64, ObjectPools.FlexCachedArray{Float64,....}} +typeof(W) +AA_mat = Matrix(AA) # Matrix{Float64} +@assert W * AA ≈ W * AA_mat + +@btime F(X) # 46.671 μs +@btime F2(X) # 22.180 μs +@btime W * AA # 37.171 μs +@btime W * AA_mat # 9.400 μs +@btime W * Matrix(AA) # 18.500 μs \ No newline at end of file diff --git a/tmp/try.jl b/tmp/try.jl new file mode 100644 index 0000000..5c53a49 --- /dev/null +++ b/tmp/try.jl @@ -0,0 +1,45 @@ +using ACEpsi, Polynomials4ML, StaticArrays, Test +using Polynomials4ML: natural_indices, degree, SparseProduct +using ACEpsi.AtomicOrbitals: Nuc, make_nlms_spec, evaluate +using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, JPauliNet +using ACEbase.Testing: print_tf, fdtest +using ACEpsi.vmc: gradient, laplacian, grad_params +using ACEbase.Testing: print_tf, fdtest +using LuxCore +using Lux +using Zygote +using Optimisers # mainly for the destrcuture(ps) function +using Random +using Printf +using LinearAlgebra +using BenchmarkTools + +using HyperDualNumbers: Hyper + + +Nel = 5 +X = randn(SVector{3, Float64}, Nel) +Σ = rand(spins(), Nel) +nuclei = [ Nuc(3 * rand(SVector{3, Float64}), 1.0) for _=1:3 ] + +# wrap it as HyperDualNumbers +x2dualwrtj(x, j) = SVector{3}([Hyper(x[i], i == j, i == j, 0) for i = 1:3]) +hX = [x2dualwrtj(x, 0) for x in X] +hX[1] = x2dualwrtj(X[1], 1) # test eval for grad wrt x coord of first elec + +js = JPauliNet(nuclei) +jastrow_layer = ACEpsi.lux(js) +ps, st = LuxCore.setup(MersenneTwister(1234), jastrow_layer) +st = (Σ = Σ,) + +A1 = jastrow_layer(X, ps, st) +hA1 = jastrow_layer(hX, ps, st) +print_tf(@test hA1[1].value ≈ A1[1]) + +@info("Test ∇ψ w.r.t. X") +y, st = Lux.apply(jastrow_layer, X, ps, st) + +F(X) = jastrow_layer(X, ps, st)[1] +dF(X) = Zygote.gradient(x -> jastrow_layer(x, ps, st)[1], X)[1] +fdtest(F, dF, X, verbose = true) +