Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible radial basis embedding #17

Merged
merged 18 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions examples/potential/forces.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML
using Polynomials4ML: LinearLayer, RYlmBasis, lux
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis
rng = Random.MersenneTwister()

##

rcut = 5.5
maxL = 0
Aspec, AAspec = degord2spec(; totaldegree = 6,
order = 3,
totdeg = 6
ord = 3

fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0)
ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p
radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans())

Aspec, AAspec = degord2spec(radial; totaldegree = totdeg,
order = ord,
Lmax = maxL, )

l_basis, ps_basis, st_basis = equivariant_model(AAspec, maxL)
l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false)
X = [ @SVector(randn(3)) for i in 1:10 ]
B = l_basis(X, ps_basis, st_basis)[1][1]
B = l_basis(X, ps_basis, st_basis)[1]

# now build another model with a better transform
L = maximum(b.l for b in Aspec)
# now extend the above BB basis to a model
len_BB = length(B)
get1 = WrappedFunction(t -> t[1])
embed = Parallel(nothing;
Rn = Chain(trans = WrappedFunction(xx -> [1/(1+norm(x)) for x in xx]),
poly = l_basis.layers.embed.layers.Rn, ),
Ylm = Chain(Ylm = lux(RYlmBasis(L)), ) )

model = Chain(
embed = embed,
A = l_basis.layers.A,
AA = l_basis.layers.AA,
# AA_sort = l_basis.layers.AA_sort,
BB = l_basis.layers.BB,
get1 = WrappedFunction(t -> t[1]),
dot = LinearLayer(len_BB, 1),
get2 = WrappedFunction(t -> t[1]), )

model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real)
model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot)
model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1)

ps, st = Lux.setup(rng, model)
out, st = model(X, ps, st)

Expand Down Expand Up @@ -158,7 +154,7 @@ end

using JuLIP
JuLIP.usethreads!(false)
ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W))
ps.dot.W[:] .= 1e-2 * randn(length(ps.dot.W))

at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1)
calc = Pot.LuxCalc(model, ps, st, rcut)
Expand Down Expand Up @@ -217,4 +213,4 @@ end
loss(at, calc, p_vec)


ReverseDiff.gradient(p -> loss(at, calc, p), p_vec)
# ReverseDiff.gradient(p -> loss(at, calc, p), p_vec)
105 changes: 55 additions & 50 deletions src/builder.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra
using SparseArrays: SparseMatrixCSC, sparse
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter
using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree
using Polynomials4ML.Utils: gensparse
using Lux: WrappedFunction
Expand All @@ -16,6 +16,8 @@ P4ML = Polynomials4ML
RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L))
RPE_filter_long(L) = bb -> (length(bb) == 0) || (abs(sum(b.m for b in bb)) <= L)

RPE_filter_real(L) = bb -> (length(bb) == 0) || mm_filter([b.m for b in bb],L) && iseven(sum(b.l for b in bb)+L)

"""
_rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_long{L,T}},
spec::Vector{Vector{NamedTuple}})
Expand All @@ -36,16 +38,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
for i = 1:length(spec)
# get the specification of the ith basis function, which is a tuple/vec of NamedTuples
pib = spec[i]
# skip it unless all m are zero, because we want to consider each
# (nn, ll) block only once.
# if !all(b.m == 0 for b in pib)
# continue
# end
# But we can not do this anymore for L≥1, so I add an nnllset

# get the rotation-coefficients for this basis group
# the bs are the basis functions corresponding to the columns

# The nnlllist is created because we want to consider each
# (nn, ll) block only once.
nn = SVector([onep.n for onep in pib]...)
ll = SVector([onep.l for onep in pib]...) # get a SVector of ll index
if haskey(pib[1],:s)
Expand Down Expand Up @@ -102,7 +100,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
if !isnothing(idxAA)
push!(Irow, idxB)
push!(Jcol, idxAA)
push!(vals, U[irow, icol])
if norm(U[irow, icol] - real.(U[irow, icol]))<1e-12
push!(vals, real.(U[irow, icol]))
else
push!(vals, U[irow, icol])
end
# push!(vals, U[irow, icol])
end
end
end
Expand All @@ -119,42 +122,38 @@ end

# TODO: symmetry group O(d)?
"""
xx2AA(spec_nlm, d=3, categories=[]; radial_basis=legendre_basis)
Construct a lux chain that maps a configuration to the corresponding the AA basis
xx2AA(spec_nlm, radial; d=3, categories=[])
Construct a lux chain that maps a configuration to the corresponding AA basis
spec_nlm: Specification of the AA bases
radial : specified radial basis, with both basis and its specification
===
OptionalField:
d: Input dimension
categories : A list of categories
radial_basis : specified radial basis, default using P4ML.legendre_basis
"""
function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Configuration to AA bases - this is what all chains have in common

function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common
# from spec_nlm to all possible spec1p
spec1p, lmax, nmax = specnlm2spec1p(spec_nlm)
dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)])
Ylm = CYlmBasis(lmax)
Rn = radial_basis(nmax)
# TODO: make it Rnl = radial_basis(nmax,lmax)
Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax)
# Rn = radial_basis(nmax)

if !isempty(categories)
# Read categories from x - TODO: discuss which format we like it to be...
# For now we just give get_cat(x) a random value
#get_cat(x) = length(categories) > 1 ? (iseven(floor(norm(x))) ? categories[1] : categories[2]) : categories[1]
#_get_cat(x) = get_cat.(x)

# Define categorical bases
δs = CategoricalBasis(categories)
l_δs = P4ML.lux(δs)
end

spec1pidx = isempty(categories) ? getspec1idx(spec1p, Rn, Ylm) : getspec1idx(spec1p, Rn, Ylm, δs)
# TODO: write getspec1idx for Rnl basis
spec1pidx = isempty(categories) ? getspec1idx(spec1p, radial.Radialspec, Ylm) : getspec1idx(spec1p, radial.Radialspec, Ylm, δs)
bA = P4ML.PooledSparseProduct(spec1pidx)

Spec = sort.([ [dict_spec1p[spec_nlm[k][j]] for j = 1:length(spec_nlm[k])] for k = 1:length(spec_nlm) ])
Spec = sort(Spec, by=length)
bAA = P4ML.SparseSymmProd(Spec)

# wrapping into lux layers
l_Rn = P4ML.lux(Rn)
l_Rnl = radial.Rnl
l_Ylm = P4ML.lux(Ylm)
l_bA = P4ML.lux(bA)
l_bAA = P4ML.lux(bAA)
Expand All @@ -169,13 +168,15 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co
_norm(x) = norm.(x)

if isempty(categories)
l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity))
l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm)
luxchain = Chain(l_xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA)
else
l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2]))
l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm, δs = l_δs)
luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA)
l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm)
luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA)
else
l_Rnl = append_layer(Chain(get_pos = get_i(1), ), l_Rnl; l_name = :radial_poly)
l_Ylm = append_layer(Chain(get_pos = get_i(1), ), l_Ylm; l_name = :angle_poly)
l_δs = append_layer(Chain(get_cat = get_i(2), ), l_δs; l_name = :categorical)

l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs)
luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) # Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA)
end

# luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA)
Expand All @@ -192,15 +193,19 @@ L : Largest equivariance level
categories : A list of categories
radial_basis : specified radial basis, default using P4ML.legendre_basis
"""
function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true)
function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false)
if rSH && L > 0
error("rSH is only implemented (for now) for L = 0")
end

# first filt out those unfeasible spec_nlm
filter_init = islong ? RPE_filter_long(L) : RPE_filter(L)
filter_init = rSH ? RPE_filter_real(L) : (islong ? RPE_filter_long(L) : RPE_filter(L))
spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)]

# sort!(spec_nlm, by = x -> length(x))
spec_nlm = closure(spec_nlm,filter_init; categories = categories)

luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis)
luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d, rSH = rSH)
F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1]

if islong
Expand All @@ -209,15 +214,15 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=
pos = Vector{Any}(undef, L+1)

for l = 0:L
filter = RPE_filter(l)
cgen = Rot3DCoeffs(l) # TODO: this should be made group related
filter = rSH ? RPE_filter_real(L) : RPE_filter(l)
cgen = rSH ? Rot3DCoeffs_real(l) : Rot3DCoeffs(l) # TODO: this should be made group related

tmp = spec_nlm[findall(x -> filter(x) == 1, spec_nlm)]
C[l+1] = _rpi_A2B_matrix(cgen, tmp)
pos[l+1] = findall(x -> filter(x) == 1, spec_nlm) # [ dict[tmp[j]] for j = 1:length(tmp)]
end
else
cgen = Rot3DCoeffs(L) # TODO: this should be made group related
cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

Expand All @@ -233,13 +238,13 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) =
equivariant_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax=L, radial_basis = radial_basis, islong = islong)[2], L; categories, d, radial_basis, group, islong)
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)

# With the _close function, the input could simply be an nnlllist (nlist,llist)
equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64; categories=[], d=3, radial_basis = legendre_basis, group = "O3", islong = true) = begin
equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true, rSH = false) = begin
filter = islong ? RPE_filter_long(L) : RPE_filter(L)
equivariant_model(_close(nn, ll; filter = filter), L; categories, d, radial_basis, group, islong)
equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong, rSH)
end

# ===== Codes that we might remove later =====
Expand All @@ -251,14 +256,14 @@ end
# What can be adjusted in its input are: (1) total polynomial degree; (2) correlation order; (3) largest L
# (4) weight of the order of spherical harmonics; (5) specified radial basis

function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3")
function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3")
filter_init = RPE_filter_long(L)
spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)]

# sort!(spec_nlm, by = x -> length(x))
spec_nlm = closure(spec_nlm, filter_init; categories = categories)

luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis)
luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d)
F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1]

cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related
Expand All @@ -274,11 +279,11 @@ function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_ba
return luxchain, ps, st
end

equivariant_SYY_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis = legendre_basis,group = "O3") =
equivariant_SYY_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax = L, radial_basis = radial_basis, islong=true)[2], L; categories, d, radial_basis, group)
equivariant_SYY_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3,group = "O3") =
equivariant_SYY_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax = L, islong=true)[2], radial, L; categories, d, group)

equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3") =
equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), L; categories, d, radial_basis, group)
equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") =
equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), radial, L; categories, d, group)

## TODO: The following should eventually go into ACEhamiltonians.jl rather than this package

Expand Down Expand Up @@ -312,9 +317,9 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b

Ylm = CYlmBasis(totdeg)

spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg)
spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg)
spec1p = sort(spec1p, by = (x -> x.n + x.l * wL))
spec1pidx = getspec1idx(spec1p, Rn, Ylm)
spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm)

# define sparse for n-correlations
tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
Expand Down Expand Up @@ -370,9 +375,9 @@ end
function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legendre_basis(totdeg))
Ylm = CYlmBasis(totdeg)

spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg)
spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg)
spec1p = sort(spec1p, by = (x -> x.n + x.l * wL))
spec1pidx = getspec1idx(spec1p, Rn, Ylm)
spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm)

# define sparse for n-correlations
tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
Expand Down
Loading
Loading