diff --git a/examples/potential/forces.jl b/examples/potential/forces.jl index 028b710..450b326 100644 --- a/examples/potential/forces.jl +++ b/examples/potential/forces.jl @@ -1,38 +1,34 @@ -using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML using Polynomials4ML: LinearLayer, RYlmBasis, lux -using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis rng = Random.MersenneTwister() ## rcut = 5.5 maxL = 0 -Aspec, AAspec = degord2spec(; totaldegree = 6, - order = 3, +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, Lmax = maxL, ) -l_basis, ps_basis, st_basis = equivariant_model(AAspec, maxL) +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) X = [ @SVector(randn(3)) for i in 1:10 ] -B = l_basis(X, ps_basis, st_basis)[1][1] +B = l_basis(X, ps_basis, st_basis)[1] -# now build another model with a better transform -L = maximum(b.l for b in Aspec) +# now extend the above BB basis to a model len_BB = length(B) -get1 = WrappedFunction(t -> t[1]) -embed = Parallel(nothing; - Rn = Chain(trans = WrappedFunction(xx -> [1/(1+norm(x)) for x in xx]), - poly = l_basis.layers.embed.layers.Rn, ), - Ylm = Chain(Ylm = lux(RYlmBasis(L)), ) ) - -model = Chain( - embed = embed, - A = l_basis.layers.A, - AA = l_basis.layers.AA, - # AA_sort = l_basis.layers.AA_sort, - BB = l_basis.layers.BB, - get1 = WrappedFunction(t -> t[1]), - dot = LinearLayer(len_BB, 1), - get2 = WrappedFunction(t -> t[1]), ) + +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = 
append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) + ps, st = Lux.setup(rng, model) out, st = model(X, ps, st) @@ -158,7 +154,7 @@ end using JuLIP JuLIP.usethreads!(false) -ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) +ps.dot.W[:] .= 1e-2 * randn(length(ps.dot.W)) at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) calc = Pot.LuxCalc(model, ps, st, rcut) @@ -217,4 +213,4 @@ end loss(at, calc, p_vec) -ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) +# ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) diff --git a/examples/potential/forces_chho.jl b/examples/potential/forces_chho.jl new file mode 100644 index 0000000..ad23490 --- /dev/null +++ b/examples/potential/forces_chho.jl @@ -0,0 +1,256 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +rng = Random.MersenneTwister() + +## + +rcut = 5.5 +maxL = 0 +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? 
abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) + +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) +X = [ @SVector(randn(3)) for i in 1:10 ] +B = l_basis(X, ps_basis, st_basis)[1] + +# now extend the above BB basis to a model +len_BB = length(B) + +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) + +ps, st = Lux.setup(rng, model) +out, st = model(X, ps, st) + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + +## + +module Pot + import JuLIP, Zygote, StaticArrays + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + import StaticArrays: SVector, SMatrix + import ReverseDiff + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + 
Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + + function lux_efv(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + E = 0.0 + F = zeros(SVector{3, Float64}, length(at)) + V = zero(SMatrix{3, 3, Float64}) + for i = 1:length(at) + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + comp = Zygote.withgradient(_X -> calc.luxmodel(_X, ps, st)[1], Rs) + Ei = comp.val + _∇Ei = comp.grad[1] + ∇Ei = ReverseDiff.value.(_∇Ei) + # energy + E += Ei + + # Forces + for j = 1:length(Rs) + F[Js[j]] -= ∇Ei[j] + F[i] += ∇Ei[j] + end + + # Virial + if length(Rs) > 0 + V -= sum(∇Eij * Rij' for (∇Eij, Rij) in zip(∇Ei, Rs)) + end + end + + return E, F, V + end + +# site_virial(dV::AbstractVector{JVec{T1}}, R::AbstractVector{JVec{T2}} +# ) where {T1, T2} = ( +# length(R) > 0 ? (- sum( dVi * Ri' for (dVi, Ri) in zip(dV, R) )) +# : zero(JMat{fltype_intersect(T1, T2)}) +# ) + # function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + # E = lux_energy(at, calc, ps, st) + # function pb(Δ) + # nlist = JuLIP.neighbourlist(at, calc.rcut) + # @show Δ + # error("stop") + # function pb_inner(i) + # Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) + # gi = ReverseDiff.gradient() + # end + # for i = 1:length(at) + # Ei, st = calc.luxmodel(Rs, calc.ps, calc.st) + # E += Ei[1] + # end + # end + # end + +end + +## + +using JuLIP +JuLIP.usethreads!(false) +ps.dot.W[:] .= 1e-2 * randn(length(ps.dot.W)) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +calc = Pot.LuxCalc(model, ps, st, rcut) +JuLIP.energy(calc, at) +JuLIP.forces(calc, at) +JuLIP.virial(calc, at) +Pot.lux_energy(at, calc, ps, st) + +@time JuLIP.energy(calc, at) +@time Pot.lux_energy(at, calc, ps, st) +@time JuLIP.forces(calc, at) + +## + +using Optimisers, 
ReverseDiff + +p_vec, _rest = destructure(ps) +f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st) + +f(p_vec) +gz = Zygote.gradient(f, p_vec)[1] + +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] + +# We can use either Zygote or ReverseDiff for gradients. +gr = ReverseDiff.gradient(f, p_vec) +@show gr ≈ gz + +@info("Interestingly ReverseDiff is much faster here, almost optimal") +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] +@time ReverseDiff.gradient(f, p_vec) + +## + +@info("Compute Energies, Forces and Virials at the same time") +E, F, V = Pot.lux_efv(at, calc, ps, st) +@show E ≈ JuLIP.energy(calc, at) +@show F ≈ JuLIP.forces(calc, at) +@show V ≈ JuLIP.virial(calc, at) + +## + +# make up a baby loss function type thing. +function loss(at, calc, p_vec) + ps = _rest(p_vec) + st = calc.st + E, F, V = Pot.lux_efv(at, calc, ps, st) + Nat = length(at) + return (E / Nat)^2 + + sum( f -> sum(abs2, f), F ) / Nat + + sum(abs2, V) +end + +loss(at, calc, p_vec) + +# ==== +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + 
end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + + + +## +ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) diff --git a/examples/potential/test_potential.jl b/examples/potential/test_potential.jl index 53926af..a752ccc 100644 --- a/examples/potential/test_potential.jl +++ b/examples/potential/test_potential.jl @@ -1,11 +1,13 @@ -using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML using Polynomials4ML: LinearLayer, RYlmBasis, lux -using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis using JuLIP, Combinatorics, Test using ACEbase.Testing: println_slim, print_tf, fdtest using Optimisers: destructure using Printf +L = 0 + include("staticprod.jl") function grad_test2(f, df, X::AbstractVector; verbose = true) @@ -17,7 +19,7 @@ function grad_test2(f, df, X::AbstractVector; verbose = true) verbose && @printf("---------|----------- \n") verbose && @printf(" h | error \n") verbose && @printf("---------|----------- \n") - for h in 0.1.^(-3:9) + for h in 0.1.^(0:12) gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) end @@ -27,10 +29,16 @@ rng = Random.MersenneTwister() rcut = 5.5 maxL = 0 -L = 0 -Aspec, AAspec = degord2spec(; totaldegree = 6, - order = 3, - Lmax = 0, ) +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? 
abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) @@ -52,7 +60,7 @@ for bb in ori_AAspec push!(new_AAspec, newbb) end -luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats, islong = false) +luxchain, ps, st = equivariant_model(new_AAspec, radial, L; categories=allcats, islong = false) at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; @@ -67,7 +75,7 @@ get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] Z0S = get_Z0S(z0, Zs) # input of luxmodel -X = (Rs, Z0S) +X = [Rs, Z0S] out, st = luxchain(X, ps, st) @@ -83,14 +91,15 @@ model(X, ps, st) # testing derivative (forces) g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] +grad_model(X, ps, st) = Zygote.gradient(_X -> model(_X, ps, st)[1], X)[1] - -F(Rs) = model((Rs, Z0S), ps, st)[1] -dF(Rs) = Zygote.gradient(rs -> model((rs, Z0S), ps, st)[1], Rs)[1] +F(Rs) = model([Rs, Z0S], ps, st)[1] +dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] ## @info("test derivative w.r.t X") print_tf(@test fdtest(F, dF, Rs; verbose=true)) +println() @info("test derivative w.r.t parameter") @@ -103,3 +112,59 @@ dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(g grad_test2(Fp, dFp, W0) +# === define toy loss === +function loss(X, p) + ps = _rest(p) + g = grad_model(X, ps, st) + return sum(norm.(g)) +end + +p_vec, _rest = destructure(ps) + +# === override useful functions === +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, 
basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + +# === reverse over reverse === +using ReverseDiff +gg1 = ReverseDiff.gradient(_p -> loss(X, _p), p_vec) + +using ACEbase +ACEbase.Testing.fdtest( + _p -> loss(X, _p), + _p -> ReverseDiff.gradient(__p -> loss(X, __p), _p), + p_vec ) +## \ No newline at end of file diff --git a/examples/potential/test_potential_multi_chho.jl b/examples/potential/test_potential_multi_chho.jl new file mode 100644 index 0000000..2bbff30 --- /dev/null +++ b/examples/potential/test_potential_multi_chho.jl @@ -0,0 +1,174 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +using JuLIP, Combinatorics, Test +using ACEbase.Testing: println_slim, print_tf, fdtest +using Optimisers: destructure +using Printf + +L = 0 + +include("staticprod.jl") + +function grad_test2(f, df, X::AbstractVector; verbose = true) + F = f(X) + ∇F = df(X) + nX = length(X) + EE = 
Matrix(I, (nX, nX)) + + verbose && @printf("---------|----------- \n") + verbose && @printf(" h | error \n") + verbose && @printf("---------|----------- \n") + for h in 0.1.^(0:12) + gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] + verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) + end +end + +rng = Random.MersenneTwister() + +rcut = 5.5 +maxL = 0 +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) +cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) + +ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) +allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) + +for (i, cat) in enumerate(cats) + push!(ipairs, [i, i]) + push!(allcats, SVector{2}([cat, cat])) +end + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for (t, ip) in zip(bb, ipairs) + push!(newbb, (t..., s = cats[ip])) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, radial, L; categories=allcats, islong = false) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; +at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5]; +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = [Rs, Z0S] + +out, st = luxchain(X, ps, st) + + +# == lux chain eval and grad +B = out + +model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = 
LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1])) + +ps, st = Lux.setup(MersenneTwister(1234), model) +ps.dot.W[:] = ps.dot.W[:] / 1000 + +model(X, ps, st) + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1][1] +grad_model(X, ps, st) = Zygote.gradient(_X -> model(_X, ps, st)[1], X)[1] + +F(Rs) = model([Rs, Z0S], ps, st)[1] +dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] + +## +@info("test derivative w.r.t X") +print_tf(@test fdtest(F, dF, Rs; verbose=true)) +println() + + +@info("test derivative w.r.t parameter") +p = Zygote.gradient(p -> model(X, p, st)[1], ps)[1] +p, = destructure(p) + +W0, re = destructure(ps) +Fp = w -> model(X, re(w), st)[1] +dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(gl)[1]) +grad_test2(Fp, dFp, W0) + + +# === define toy loss === +function loss(X, p) + ps = _rest(p) + g = grad_model(X, ps, st)[1] + return sum(norm.(g)) +end + +p_vec, _rest = destructure(ps) + +# === override useful functions === +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + 
nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + +# === reverse over reverse === +using ReverseDiff +g1 = ReverseDiff.gradient(_p -> loss(X, _p), p_vec) + +Zygote.gradient(_p -> loss(X, _p), p_vec) + +using ACEbase +ACEbase.Testing.fdtest( + _p -> loss(X, _p), + _p -> Zygote.gradient(__p -> loss(X, __p), _p)[1], + p_vec ) +## \ No newline at end of file diff --git a/src/EquivariantModels.jl b/src/EquivariantModels.jl index d714d67..1accbb9 100644 --- a/src/EquivariantModels.jl +++ b/src/EquivariantModels.jl @@ -1,5 +1,6 @@ module EquivariantModels +include("radial.jl") include("utils.jl") include("lux_utils.jl") include("categorical.jl") diff --git a/src/builder.jl b/src/builder.jl index aa54760..1ecc941 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -1,6 +1,6 @@ using LinearAlgebra using SparseArrays: SparseMatrixCSC, sparse -using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector +using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree using Polynomials4ML.Utils: gensparse using Lux: WrappedFunction @@ -13,9 +13,11 @@ export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructo P4ML = Polynomials4ML -RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L)) +RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L)) && ( length(bb) == 1 && L == 0 ? bb[1].l == 0 : true ) RPE_filter_long(L) = bb -> (length(bb) == 0) || (abs(sum(b.m for b in bb)) <= L) +RPE_filter_real(L) = bb -> (length(bb) == 0) || mm_filter([b.m for b in bb],L) && iseven(sum(b.l for b in bb)+L) && ( length(bb) == 1 && L == 0 ? 
bb[1].l == 0 : true ) + """ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_long{L,T}}, spec::Vector{Vector{NamedTuple}}) @@ -36,16 +38,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro for i = 1:length(spec) # get the specification of the ith basis function, which is a tuple/vec of NamedTuples pib = spec[i] - # skip it unless all m are zero, because we want to consider each - # (nn, ll) block only once. - # if !all(b.m == 0 for b in pib) - # continue - # end - # But we can not do this anymore for L≥1, so I add an nnllset # get the rotation-coefficients for this basis group # the bs are the basis functions corresponding to the columns + # The nnlllist is created because we want to consider each + # (nn, ll) block only once. nn = SVector([onep.n for onep in pib]...) ll = SVector([onep.l for onep in pib]...) # get a SVector of ll index if haskey(pib[1],:s) @@ -102,7 +100,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if !isnothing(idxAA) push!(Irow, idxB) push!(Jcol, idxAA) - push!(vals, U[irow, icol]) + if norm(U[irow, icol] - real.(U[irow, icol]))<1e-12 + push!(vals, real.(U[irow, icol])) + else + push!(vals, U[irow, icol]) + end + # push!(vals, U[irow, icol]) end end end @@ -119,34 +122,33 @@ end # TODO: symmetry group O(d)? 
""" -xx2AA(spec_nlm, d=3, categories=[]; radial_basis=legendre_basis) -Construct a lux chain that maps a configuration to the corresponding the AA basis +xx2AA(spec_nlm, radial; d=3, categories=[]) +Construct a lux chain that maps a configuration to the corresponding AA basis spec_nlm: Specification of the AA bases +radial : specified radial basis, with both basis and its specification +=== +OptionalField: d: Input dimension categories : A list of categories -radial_basis : specified radial basis, default using P4ML.legendre_basis """ -function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Configuration to AA bases - this is what all chains have in common + +function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common # from spec_nlm to all possible spec1p spec1p, lmax, nmax = specnlm2spec1p(spec_nlm) + # An assertation whether all the radial specs are in spec1p + @assert issubset(nset(spec1p), radial.Radialspec) || issubset(nlset(spec1p), radial.Radialspec) + dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)]) - Ylm = CYlmBasis(lmax) - Rn = radial_basis(nmax) - # TODO: make it Rnl = radial_basis(nmax,lmax) + Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax) + # Rn = radial_basis(nmax) if !isempty(categories) - # Read categories from x - TODO: discuss which format we like it to be... - # For now we just give get_cat(x) a random value - #get_cat(x) = length(categories) > 1 ? (iseven(floor(norm(x))) ? categories[1] : categories[2]) : categories[1] - #_get_cat(x) = get_cat.(x) - # Define categorical bases δs = CategoricalBasis(categories) l_δs = P4ML.lux(δs) end - spec1pidx = isempty(categories) ? getspec1idx(spec1p, Rn, Ylm) : getspec1idx(spec1p, Rn, Ylm, δs) - # TODO: write getspec1idx for Rnl basis + spec1pidx = isempty(categories) ? 
getspec1idx(spec1p, radial.Radialspec, Ylm) : getspec1idx(spec1p, radial.Radialspec, Ylm, δs) bA = P4ML.PooledSparseProduct(spec1pidx) Spec = sort.([ [dict_spec1p[spec_nlm[k][j]] for j = 1:length(spec_nlm[k])] for k = 1:length(spec_nlm) ]) @@ -154,7 +156,7 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co bAA = P4ML.SparseSymmProd(Spec) # wrapping into lux layers - l_Rn = P4ML.lux(Rn) + l_Rnl = radial.Rnl l_Ylm = P4ML.lux(Ylm) l_bA = P4ML.lux(bA) l_bAA = P4ML.lux(bAA) @@ -169,13 +171,15 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co _norm(x) = norm.(x) if isempty(categories) - l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity)) - l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm) - luxchain = Chain(l_xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA) - else - l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2])) - l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm, δs = l_δs) - luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) + l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm) + luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) + else + l_Rnl = append_layer(Chain(get_pos = get_i(1), ), l_Rnl; l_name = :radial_poly) + l_Ylm = append_layer(Chain(get_pos = get_i(1), ), l_Ylm; l_name = :angle_poly) + l_δs = append_layer(Chain(get_cat = get_i(2), ), l_δs; l_name = :categorical) + + l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs) + luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) # Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) end # luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) @@ -192,15 +196,19 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ -function 
equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) +function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) + if rSH && L > 0 + error("rSH is only implemented (for now) for L = 0") + end + # first filt out those unfeasible spec_nlm - filter_init = islong ? RPE_filter_long(L) : RPE_filter(L) + filter_init = rSH ? RPE_filter_real(L) : (islong ? RPE_filter_long(L) : RPE_filter(L)) spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)] # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm,filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d, rSH = rSH) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] if islong @@ -209,15 +217,15 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= pos = Vector{Any}(undef, L+1) for l = 0:L - filter = RPE_filter(l) - cgen = Rot3DCoeffs(l) # TODO: this should be made group related + filter = rSH ? RPE_filter_real(l) : RPE_filter(l) + cgen = rSH ? Rot3DCoeffs_real(l) : Rot3DCoeffs(l) # TODO: this should be made group related tmp = spec_nlm[findall(x -> filter(x) == 1, spec_nlm)] C[l+1] = _rpi_A2B_matrix(cgen, tmp) pos[l+1] = findall(x -> filter(x) == 1, spec_nlm) # [ dict[tmp[j]] for j = 1:length(tmp)] end else - cgen = Rot3DCoeffs(L) # TODO: this should be made group related + cgen = rSH ? 
Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) end @@ -233,13 +241,13 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= end # more constructors equivariant_model -equivariant_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) = - equivariant_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax=L, radial_basis = radial_basis, islong = islong)[2], L; categories, d, radial_basis, group, islong) +equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) = + equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH) # With the _close function, the input could simply be an nnlllist (nlist,llist) -equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64; categories=[], d=3, radial_basis = legendre_basis, group = "O3", islong = true) = begin +equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true, rSH = false) = begin filter = islong ? 
RPE_filter_long(L) : RPE_filter(L) - equivariant_model(_close(nn, ll; filter = filter), L; categories, d, radial_basis, group, islong) + equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong, rSH) end # ===== Codes that we might remove later ===== @@ -251,14 +259,14 @@ end # What can be adjusted in its input are: (1) total polynomial degree; (2) correlation order; (3) largest L # (4) weight of the order of spherical harmonics; (5) specified radial basis -function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3") +function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") filter_init = RPE_filter_long(L) spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)] # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm, filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related @@ -274,11 +282,11 @@ function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_ba return luxchain, ps, st end -equivariant_SYY_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis = legendre_basis,group = "O3") = - equivariant_SYY_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax = L, radial_basis = radial_basis, islong=true)[2], L; categories, d, radial_basis, group) +equivariant_SYY_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3,group = "O3") = + equivariant_SYY_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax = L, islong=true)[2], radial, L; categories, d, group) -equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, 
L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3") = - equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), L; categories, d, radial_basis, group) +equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") = + equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), radial, L; categories, d, group) ## TODO: The following should eventually go into ACEhamiltonians.jl rather than this package @@ -312,9 +320,9 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg) + spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx(spec1p, simple_radial_basis(Rn).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] @@ -370,9 +378,9 @@ end function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legendre_basis(totdeg)) Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg) + spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx(spec1p, simple_radial_basis(Rn).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] diff --git a/src/radial.jl b/src/radial.jl new file mode 100644 index 0000000..ecee4db --- /dev/null +++ b/src/radial.jl @@ -0,0 +1,37 @@ +using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux +using LuxCore: 
AbstractExplicitContainerLayer, AbstractExplicitLayer + + struct Radial_basis{T <: AbstractExplicitLayer} <:AbstractExplicitContainerLayer{(:Rnl, )} + Rnl::T + # make it meta or just leave it as a NameTuple ? + Radialspec::Vector #{NamedTuple} #TODO: double check this... + end + +Radial_basis(Rnl::AbstractExplicitLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = + Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) + +Radial_basis(Rnl::AbstractExplicitLayer) = + try + Radial_basis(Rnl,natural_indices(Rnl.basis)) + catch + try + Radial_basis(Rnl,natural_indices(Rnl.layers.poly.basis)) + catch + error("The specification of this Radial_basis should be given explicitly!") + end + end + +# it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing) + if isnothing(spec) + try + spec = natural_indices(basis) + catch + error("The specification of this Radial_basis should be given explicitly!") + end + end + + f(r) = f_trans(r) * f_cut(r) + + return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 73ae1c3..d319690 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,5 @@ using Polynomials4ML: natural_indices - """ _invmap(a::AbstractVector) Return a dictionary that maps the elements of a to their indices @@ -23,29 +22,34 @@ function dropnames(namedtuple::NamedTuple, names::Tuple{Vararg{Symbol}}) end """ -getspec1idx(spec1, bRnl, bYlm) +getspec1idx(spec1, spec_Rnl, bYlm) Return a vector of tuples of indices of spec1 w.r.t actual indices (i.e. 1, 2, 3, ...) 
of bRnl and bYlm """ -function getspec1idx(spec1, bRnl, bYlm) +function getspec1idx(spec1, spec_Rnl, bYlm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) - spec_Rnl = natural_indices(bRnl); - spec_Rnl = [(n = i, ) for i in spec_Rnl] + # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end inv_Rnl = _invmap(spec_Rnl) spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) + + if length(spec_Rnl[1]) > 1 && haskey(spec_Rnl[1],:l) + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, ))], inv_Ylm[(l=b.l, m=b.m)]) + end + else + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) + end end + return spec1idx end -function getspec1idx(spec1, bRnl, bYlm, bδs) +function getspec1idx(spec1, spec_Rnl, bYlm, bδs) spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) - - spec_Rnl = natural_indices(bRnl) - spec_Rnl = [(n = i, ) for i in spec_Rnl] + # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end inv_Rnl = _invmap(spec_Rnl) spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) @@ -53,9 +57,17 @@ function getspec1idx(spec1, bRnl, bYlm, bδs) slist = bδs.categories spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) + + if length(spec_Rnl[1]) > 1 && haskey(spec_Rnl[1],:l) + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) + end + else + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) + end end + return spec1idx end @@ -63,18 +75,24 @@ end make_nlms_spec(bRnl, bYlm) Return a vector of tuples of 
indices of spec1 w.r.t naural indices (i.e. (n = ..., l = ..., m = ...) ) of bRnl and bYlm """ -function make_nlms_spec(bRn, bYlm; +function make_nlms_spec(radial::Radial_basis, bYlm; totaldegree::Int64 = -1, admissible = nothing, nnuc = 0) - spec_Rn = natural_indices(bRn) + spec_Rn = radial.Radialspec spec_Ylm = natural_indices(bYlm) spec1 = [] for (iR, br) in enumerate(spec_Rn), (iY, by) in enumerate(spec_Ylm) if admissible(br, by) - push!(spec1, (n = br, l = by.l, m = by.m)) + if haskey(br,:l) + if br.l == by.l + push!(spec1, (n = br.n, l = by.l, m = by.m)) + end + else + push!(spec1, (n = br.n, l = by.l, m = by.m)) + end end end return spec1 @@ -142,6 +160,9 @@ function specnlm2spec1p(spec_nlm) return spec1p, lmax, nmax + 1 end +nset(spec1p) = [ (n=spec.n,) for spec in spec1p] +nlset(spec1p) = [ (n=spec.n, l=spec.l,) for spec in spec1p] + """ closure(spec_nlm,filter) Make a spec_nlm to be a "complete" set to be symmetrised w.r.t to the filter @@ -196,20 +217,24 @@ end degord2spec(;totaldegree, order, Lmax, radial_basis = legendre_basis, wL = 1, islong = true) Return a list of AA specifications and A specifications """ -function degord2spec(;totaldegree, order, Lmax, catagories = [], radial_basis = legendre_basis, wL = 1, islong = true) - Rn = radial_basis(totaldegree) +function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories = [], wL = 1, islong = true, rSH = false) + # Rn = radial.radial_basis(totaldegree) Ylm = CYlmBasis(totaldegree) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br + wL * by.l <= totaldegree) + spec1p = make_nlms_spec(radial, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br.n + wL * by.l <= totaldegree) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx(spec1p, radial.Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] default_admissible = bb -> 
length(bb) == 0 || sum(b.n for b in bb) + wL * sum(b.l for b in bb) <= totaldegree # to construct SS, SD blocks - filter_ = islong ? RPE_filter_long(Lmax) : RPE_filter(Lmax) + if rSH + filter_ = RPE_filter_real(Lmax) + else + filter_ = islong ? RPE_filter_long(Lmax) : RPE_filter(Lmax) + end specAA = gensparse(; NU = order, tup2b = tup2b, filter = filter_, admissible = default_admissible, @@ -226,3 +251,5 @@ function degord2spec(;totaldegree, order, Lmax, catagories = [], radial_basis = Aspec = specnlm2spec1p(AAspec)[1] return Aspec, AAspec # Aspecgetspecnlm(spec1p, spec) end + +get_i(i) = WrappedFunction(t -> t[i]) diff --git a/test/runtests.jl b/test/runtests.jl index 729f071..6d78f14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,5 +3,9 @@ using Test @testset "EquivariantModels.jl" begin @testset "CategoricalBasis" begin include("test_categorial.jl") end - @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end + @testset "Equivariance" begin + include("test_equivariance.jl") + include("test_equiv_with_cate.jl") + include("test_rSH_equivariance.jl") + end end diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index fb49796..02814bb 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -1,13 +1,17 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAlgebra using ACEbase.Testing: print_tf -using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec +using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec, simple_radial_basis +using Polynomials4ML: lux include("wigner.jl") L = 4 - -Aspec, AAspec = degord2spec(; totaldegree = 4, - order = 2, +totdeg = 4 +ord = 2 +radial = simple_radial_basis(legendre_basis(totdeg)) +# radial = Radial_basis(legendre_basis(totdeg) |> lux) +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, Lmax = 0, ) cats = [:O,:C] 
cats_ext = [(:O,:C),(:C,:O),(:O,:O),(:C,:C)] |> unique @@ -22,7 +26,7 @@ _AAspec_tmp2 = [ [(AAspec[i][1]..., s = cats_ext[2]), (AAspec[i][2]..., s = cats append!(AAspec_tmp,_AAspec_tmp) append!(AAspec_tmp,_AAspec_tmp2) -luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats_ext) +luxchain, ps, st = equivariant_model(AAspec_tmp, radial, L; categories=cats_ext) F(X) = luxchain(X, ps, st)[1] species = [ rand(cats) for i = 1:10 ] Species = [ (species[1], species[i]) for i = 1:10 ] @@ -31,13 +35,13 @@ Species = [ (species[1], species[i]) for i = 1:10 ] for ntest = 1:10 local X, θ1, θ2, θ3, Q, QX X = [ @SVector(rand(3)) for i in 1:10 ] - XX = (X, Species) + XX = [X, Species] θ1 = rand() * 2pi θ2 = rand() * 2pi θ3 = rand() * 2pi Q = RotXYZ(θ1, θ2, θ3) QX = [SVector{3}(x) for x in Ref(Q) .* X] - QXX = (QX, Species) + QXX = [QX, Species] print_tf(@test F(XX)[1] ≈ F(QXX)[1]) diff --git a/test/test_equivariance.jl b/test/test_equivariance.jl index 50a1ab6..71c6cf5 100644 --- a/test/test_equivariance.jl +++ b/test/test_equivariance.jl @@ -1,9 +1,8 @@ -using EquivariantModels -using StaticArrays -using Test +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra using ACEbase.Testing: print_tf using Rotations, WignerD, BlockDiagonals -using LinearAlgebra +using EquivariantModels: Radial_basis +using Polynomials4ML:lux include("wigner.jl") @@ -11,13 +10,15 @@ include("wigner.jl") totdeg = 6 ν = 2 Lmax = 2 +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) for L = 0:Lmax local F, luxchain, ps, st, F2, luxchain2, ps2, st2 - luxchain, ps, st = equivariant_model(totdeg, ν, L;islong = false) + luxchain, ps, st = equivariant_model(totdeg, ν, radial, L;islong = false) F(X) = luxchain(X, ps, st)[1] - luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L;islong = false) + luxchain2, ps2, st2 = 
equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = false) F2(X) = luxchain(X, ps2, st2)[1] @info("Tesing L = $L O(3) equivariance") @@ -54,9 +55,11 @@ end totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_model(totdeg,ν,L;islong = true) +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true) F(X) = luxchain(X, ps, st)[1] -luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L;islong = true) +luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = true) F2(X) = luxchain(X, ps2, st2)[1] for ntest = 1:10 @@ -73,7 +76,7 @@ for ntest = 1:10 for l = 2:L D = wigner_D(l-1,Matrix(Q))' # D = wignerD(l-1, 0, 0, θ) - print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-12) + print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-11) end end println() @@ -82,13 +85,15 @@ println() totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_model(totdeg,ν,L;islong = true); +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true); F(X) = luxchain(X, ps, st)[1] for l = 0:Lmax @info("Consistency check for L = $l") local FF, luxchain, ps, st - luxchain, ps, st = equivariant_model(totdeg,ν,l;islong = false) + luxchain, ps, st = equivariant_model(totdeg,ν,radial,l;islong = false) FF(X) = luxchain(X, ps, st)[1] for ntest = 1:20 @@ -116,7 +121,7 @@ for L = 0:Lmax while iseven(L) != iseven(sum(ll)) ll = rand(0:2,4) end - luxchain, ps, st = equivariant_model(nn,ll,L;islong = false) + luxchain, ps, st = equivariant_model(nn,ll,radial,L;islong = false) F(X) = luxchain(X, ps, st)[1] @info("Tesing L = $L O(3) 
equivariance") @@ -145,9 +150,11 @@ end totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_SYY_model(totdeg,ν,L); +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_SYY_model(totdeg,ν,radial,L); F(X) = luxchain(X, ps, st)[1] -luxchain2, ps2, st2 = equivariant_SYY_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L) +luxchain2, ps2, st2 = equivariant_SYY_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L) F2(X) = luxchain(X, ps2, st2)[1] @info("Tesing L = $L O(3) full equivariance") @@ -182,7 +189,7 @@ while iseven(Lmax) != iseven(sum(ll)) global ll = rand(0:2,4) end -luxchain, ps, st = equivariant_SYY_model(nn, ll, L) +luxchain, ps, st = equivariant_SYY_model(nn, ll, radial, L) F(X) = luxchain(X, ps, st)[1] @info("Tesing L = $L O(3) full equivariance") @@ -211,7 +218,7 @@ L = Lmax luxchain, ps, st = equivariant_luxchain_constructor(totdeg,ν,L) F(X) = luxchain(X, ps, st)[1] -# A small comparison - long vector does give us some redundent basis... +# A small comparison - long vector does give us some redundant basis... 
@info("Equivariance test") l1l2set = [(l1,l2) for l1 = 0:L for l2 = 0:L-l1] diff --git a/test/test_rSH_equivariance.jl b/test/test_rSH_equivariance.jl new file mode 100644 index 0000000..f03dac1 --- /dev/null +++ b/test/test_rSH_equivariance.jl @@ -0,0 +1,139 @@ +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra +using ACEbase.Testing: print_tf +using Rotations, WignerD, BlockDiagonals +# using EquivariantModels: Radial_basis +# using Polynomials4ML:lux + +include("wigner.jl") + +@info("Testing the chain that generates a single B basis") +totdeg = 6 +ν = 2 +Lmax = 0 +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) + +luxchain, ps, st = equivariant_model(totdeg, ν, radial, 0;islong = false, rSH = true) +F(X) = luxchain(X, ps, st)[1] + +for L = 0:Lmax + local F, luxchain, ps, st, F2, luxchain2, ps2, st2 + luxchain, ps, st = equivariant_model(totdeg, ν, radial, L;islong = false, rSH = true) + F(X) = luxchain(X, ps, st)[1] + + luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true,rSH = true)[2][1:end-1],radial,L;islong = false,rSH = true) + F2(X) = luxchain(X, ps2, st2)[1] + + @info("Tesing L = $L O(3) equivariance") + for _ = 1:30 + local X, θ1, θ2, θ3, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ1 = rand() * 2pi + θ2 = rand() * 2pi + θ3 = rand() * 2pi + Q = RotXYZ(θ1, θ2, θ3) + # Q = rand_rot() + QX = [SVector{3}(x) for x in Ref(Q) .* X] + D = wigner_D(L,Matrix(Q))' + # D = wignerD(L, θ, θ, θ) + + print_tf(@test F(X) ≈ F(QX)) + + end + println() + + @info("Tesing consistency between the two ways of input - in particular the ``closure'' of specifications") + for _ = 1:30 + local X + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test F(X) ≈ F2(X)) + end + println() + +end + +@info("Testing the chain that generates all B bases") +totdeg = 6 +ν = 2 +L = Lmax +basis = legendre_basis(totdeg) +radial = 
EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true) +F(X) = luxchain(X, ps, st)[1] +luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = true,rSH = true) +F2(X) = luxchain(X, ps2, st2)[1] + +for ntest = 1:10 + local X, θ1, θ2, θ3, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ1 = rand() * 2pi + θ2 = rand() * 2pi + θ3 = rand() * 2pi + Q = RotXYZ(θ1, θ2, θ3) + QX = [SVector{3}(x) for x in Ref(Q) .* X] + + print_tf(@test F(X)[1] ≈ F(QX)[1]) +end +println() + +@info("Consistency check") +totdeg = 6 +ν = 2 +L = Lmax +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true, rSH = true); +F(X) = luxchain(X, ps, st)[1] + +for l = 0:Lmax + @info("Consistency check for L = $l") + local FF, luxchain, ps, st + luxchain, ps, st = equivariant_model(totdeg,ν,radial,l;islong = false, rSH = true) + FF(X) = luxchain(X, ps, st)[1] + + for ntest = 1:20 + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test F(X)[l+1] == FF(X)) + end + println() +end + +@info("Tesing consistency between the two ways of input - in particular the ``closure'' of specifications") +for _ = 1:10 + local X + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] ≈ F2(X)[i] for i = 1:length(F(X))])) +end +println() + +@info("Tesing the last way of input - given n_list and l_list") + +for L = 0:Lmax + local F, luxchain, ps, st, nn, ll + + nn = rand(0:2,4) + ll = rand(0:2,4) + while iseven(L) != iseven(sum(ll)) + ll = rand(0:2,4) + end + luxchain, ps, st = equivariant_model(nn,ll,radial,L;islong = false, rSH = true) + F(X) = luxchain(X, ps, st)[1] + + @info("Tesing L = $L O(3) equivariance") + for _ = 1:30 + local X, θ, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ = rand() * 2pi + Q = 
RotXYZ(0, 0, θ) + # Q = rand_rot() + QX = [SVector{3}(x) for x in Ref(Q) .* X] + D = wignerD(L, 0, 0, θ) + if length(F(X)) == 0 + continue + end + print_tf(@test F(X) ≈ F(QX)) + + end + println() +end +