Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant Linear Layer #18

Merged
merged 25 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
01bb9e0
Initialise ConstLinearLayer
zhanglw0521 Sep 27, 2023
9ce22da
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 1, 2023
0047583
keep it consistent to the Radial_Basis... branch
zhanglw0521 Oct 1, 2023
adf0519
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 2, 2023
8c5ed6c
WIP - far from complete
zhanglw0521 Oct 2, 2023
9199843
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 4, 2023
02e303d
partially works
zhanglw0521 Oct 4, 2023
8e7fcf0
Clean up
zhanglw0521 Oct 4, 2023
53a472d
Adapt Yangshuai's energy fitting code to the latest version
zhanglw0521 Oct 4, 2023
1b945d3
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 5, 2023
4a7a06d
introduce rpe_basis & a linear dependence test
zhanglw0521 Oct 5, 2023
7252202
Fix the linear dependence issue
zhanglw0521 Oct 5, 2023
7dc53bc
Update Project.toml
zhanglw0521 Oct 5, 2023
92b2b3d
Typo fix
zhanglw0521 Oct 5, 2023
43d1a4f
Add the corresponding tests
zhanglw0521 Oct 5, 2023
ba40276
typo fix
zhanglw0521 Oct 5, 2023
0e3b9c8
Merge pull request #19 from ACEsuit/Linear_dependence_issue
zhanglw0521 Oct 5, 2023
a115a63
A simple energy test that checks the "completeness" of the basis for …
zhanglw0521 Oct 5, 2023
b363596
Resolve most of the issues in comments
zhanglw0521 Oct 5, 2023
ae937d1
Renaming W to op to avoid ambiguity of learnable weigh and constant m…
zhanglw0521 Oct 5, 2023
149447d
get rid of the "position" projectin
zhanglw0521 Oct 5, 2023
42ad3a9
clean up
zhanglw0521 Oct 5, 2023
8e63f32
Minor revision
zhanglw0521 Oct 5, 2023
bdffa55
faster linear transformation construction
zhanglw0521 Oct 6, 2023
7223e6c
turning C * X[pos] to LO * X with LO a combined LinearOperator
zhanglw0521 Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.0.2"
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Expand Down
36 changes: 36 additions & 0 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import ChainRulesCore: rrule
using LuxCore
using LuxCore: AbstractExplicitLayer

# Layer type (subtype of LuxCore's `AbstractExplicitLayer`) holding a single field
# `op`. The call methods defined for this type multiply `op` with (a leading slice
# of) the input.
# NOTE(review): `op` carries no type annotation, so the field is stored untyped;
# the review comment embedded below proposes a parametric field (`op::TOP`).
struct ConstLinearLayer <: AbstractExplicitLayer
op
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this gives you a type instability. you still need to specify the type of the operator. Something like this:

struct ConstLinL{TOP} 
    op::TOP
end

end

# Apply the constant operator to a vector. Only the first `size(l.op, 2)` entries
# of `x` enter the product, so `x` may be longer than `op` has columns.
function (l::ConstLinearLayer)(x::AbstractVector)
    ncols = size(l.op, 2)
    return l.op * x[1:ncols]
end
cortner marked this conversation as resolved.
Show resolved Hide resolved

# Apply the constant operator to every row of `x` and stack the results row-wise,
# so that row `i` of the output equals `l(x[i, :])`.
#
# The per-row results are concatenated with a single `reduce(hcat, ...)` instead of
# growing a matrix one column at a time (which re-copies the accumulated result on
# every iteration). The final re-orientation uses `permutedims` rather than `'`:
# the adjoint would conjugate complex entries (and recursively adjoint
# array-valued entries), making the batched output disagree with the vector method.
#
# Inputs with zero rows are not supported.
function (l::ConstLinearLayer)(x::AbstractMatrix)
    columns = [l(x[i, :]) for i in axes(x, 1)]
    return permutedims(reduce(hcat, columns))
end

# Three-argument call form `(x, ps, st)`: the parameters `ps` are not used, and the
# state `st` is handed back unchanged next to the layer output.
function (l::ConstLinearLayer)(x::AbstractArray, ps, st)
    out = l(x)
    return (out, st)
end

# NOTE: the following rrule is kept because there is an issue with SparseArray
# Reverse rule for `LuxCore.apply(l, x)` with a vector input `x`.
# The forward value is `l(x)`. The pullback reads `A[1]` from the incoming
# cotangent and returns `l.op' * A[1]` in the third slot and the outer product
# `A[1] * x'` (wrapped as `(op = ...,)`) in the fourth slot.
# NOTE(review): the pullback returns five tangents although this signature has
# only three arguments (`apply`, `l`, `x`) — confirm this method is still needed
# and that the tangent count is intended.
# NOTE(review): only `rrule` is imported from ChainRulesCore in the lines visible
# here; `NoTangent` is presumably brought into scope elsewhere — verify.
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector)
zhanglw0521 marked this conversation as resolved.
Show resolved Hide resolved
val = l(x)
function pb(A)
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

# Reverse rule for the call `LuxCore.apply(l, x, ps, st)`.
# The forward value is `l(x, ps, st)`, i.e. the tuple `(l(x), st)`.
# The pullback returns one tangent per argument, in the order
# `apply`, `l`, `x`, `ps`, `st`.
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st)
val = l(x,ps,st)
function pb(A)
# `A[1]` is the cotangent of the first element of the returned tuple (the layer output).
# NOTE(review): `l.op' * A[1]` has `size(l.op, 2)` rows, whereas the forward pass
# reads only `x[1:size(l.op,2)]` for vectors and works row-by-row for matrices;
# if `x` can be longer than that, or is a matrix, the shape of this cotangent
# (and of `A[1] * x'`) should be checked against `x` — TODO confirm.
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end
52 changes: 46 additions & 6 deletions src/builder.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using LinearAlgebra
using SparseArrays: SparseMatrixCSC, sparse
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter, coco_dot
using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree
using Polynomials4ML.Utils: gensparse
using Lux: WrappedFunction
using Lux
using Random
using Polynomials4ML
using StaticArrays
using Combinatorics

export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new

Expand All @@ -23,6 +24,33 @@ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_l
spec::Vector{Vector{NamedTuple}})
Return a sparse matrix for symmetrisation of AA basis of spec with equivariance specified by cgen
"""
function rpe_basis(A::Union{Rot3DCoeffs,Rot3DCoeffs_long,Rot3DCoeffs_real}, nn::SVector{N, TN}, ll::SVector{N, Int}; rtol = 1e-7) where {N, TN}
    # Purpose: start from the coefficients `Ure` and index tuples `Mre` returned by
    # `re_basis(A, ll)`, form the Gramian of the rows of `Ure` via `_gramian`, and
    # keep only as many combinations as the Gramian's numerical rank.
    #
    # Arguments:
    #   A      - coupling-coefficient generator passed through to `re_basis`
    #   nn, ll - index tuples of equal length `N`; both are passed to `_gramian`
    #   rtol   - relative tolerance for the numerical rank of the singular values
    #            (keyword, default 1e-7, the value that was previously hard-coded)
    #
    # Returns `(U, Mre)`, where `U` has `rk` rows (the numerical rank) and `Mre` is
    # returned unchanged from `re_basis`.
    #
    # NOTE(review): the docstring that closes directly above this function talks
    # about `_rpi_A2B_matrix` (a sparse symmetrisation matrix), not about this
    # function; it presumably got separated from its target — confirm and move it.
    Ure, Mre = re_basis(A, ll)
    gram = _gramian(nn, ll, Ure, Mre)
    F = svd(gram)
    # numerical rank of the Gramian, judged on its singular values
    rk = rank(Diagonal(F.S); rtol = rtol)
    # leading `rk` left singular vectors, as rows
    Urpe = F.U[:, 1:rk]'
    # scale each retained combination by the square root of its singular value
    return Diagonal(sqrt.(F.S[1:rk])) * Urpe * Ure, Mre
end


"""
    _gramian(nn, ll, Ure, Mre)

Accumulate the square complex matrix `G` of size `size(Ure, 1)` that `rpe_basis`
decomposes.

For every permutation `σ` of `1:length(nn)` that leaves both `nn` and `ll`
unchanged, and every pair `(mm1, mm2)` of entries of `Mre` with `mm1[σ] == mm2`,
the value `coco_dot(Ure[i1, iU1], Ure[i2, iU2])` is added to `G[i1, i2]` for all
row pairs `(i1, i2)`, where `iU1`, `iU2` are the positions of `mm1`, `mm2` in `Mre`.
"""
function _gramian(nn, ll, Ure, Mre)
    order = length(nn)
    nrows = size(Ure, 1)
    gram = zeros(Complex{Float64}, nrows, nrows)
    for σ in permutations(1:order)
        # only permutations that fix both index tuples contribute
        (nn[σ] == nn && ll[σ] == ll) || continue
        for (j1, mm1) in enumerate(Mre), (j2, mm2) in enumerate(Mre)
            # the permuted first tuple has to coincide with the second one
            mm1[σ] == mm2 || continue
            for r1 in 1:nrows, r2 in 1:nrows
                gram[r1, r2] += coco_dot(Ure[r1, j1], Ure[r2, j2])
            end
        end
    end
    return gram
end

function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T}
# allocate triplet format
Irow, Jcol = Int[], Int[]
Expand Down Expand Up @@ -55,7 +83,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
if (nn,ll,ss) in nnllset; continue; end

# get the Mll indices and coeffs
U, Mll = re_basis(cgen, ll)
U, Mll = rpe_basis(cgen, nn, ll)
# convert the Mlls into basis functions (NamedTuples)

rpibs = [_nlms2b(nn, ll, mm, ss) for mm in Mll]
Expand Down Expand Up @@ -83,7 +111,8 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
if (nn,ll) in nnllset; continue; end

# get the Mll indices and coeffs
U, Mll = re_basis(cgen, ll)
# U, Mll = re_basis(cgen, ll)
U, Mll = rpe_basis(cgen, nn, ll)
# convert the Mlls into basis functions (NamedTuples)

rpibs = [_nlms2b(nn, ll, mm) for mm in Mll]
Expand Down Expand Up @@ -196,6 +225,9 @@ L : Largest equivariance level
categories : A list of categories
radial_basis : specified radial basis, default using P4ML.legendre_basis
"""

# NOTE(review): this `include` sits between the docstring that closes just above
# and `function equivariant_model` below. The docstring's entries (`L`,
# `categories`, `radial_basis`) appear to describe that function, so it presumably
# no longer attaches to it — consider moving the include above the docstring.
include("ConstLinearLayer.jl")

function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false)
if rSH && L > 0
error("rSH is only implemented (for now) for L = 0")
Expand Down Expand Up @@ -229,8 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x)
# TODO: make it a Const_LinearLayer instead
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand All @@ -240,6 +271,15 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
return luxchain, ps, st
end

"""
    new_sparse_matrix(C, pos)

Embed the columns of `C` into a wider sparse matrix.

Column `j` of `C` becomes column `pos[j]` of the result; every other column is
empty. The result has `size(C, 1)` rows and `maximum(pos)` columns and the same
element type as `C`.

# Arguments
- `C`: matrix (sparse or dense) whose columns are to be scattered.
- `pos`: target column index for each column of `C`; must have `size(C, 2)` entries.

# Throws
- `DimensionMismatch` if `length(pos) != size(C, 2)`.
"""
function new_sparse_matrix(C, pos)
    length(pos) == size(C, 2) || throw(DimensionMismatch(
        "pos has $(length(pos)) entries but C has $(size(C, 2)) columns"))
    ncols = maximum(pos)
    # Start from an empty sparse matrix: no dense intermediate is allocated, and
    # `eltype(C)` is used so the element type does not depend on `C[1]` existing.
    C_new = sparse(Int[], Int[], eltype(C)[], size(C, 1), ncols)
    # Column-wise assignment matches the CSC storage layout.
    for (j, p) in enumerate(pos)
        C_new[:, p] = C[:, j]
    end
    return C_new
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)
Expand Down Expand Up @@ -271,7 +311,7 @@ function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categor

cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related
C = _rpi_A2B_matrix(cgen, spec_nlm)
l_sym = WrappedFunction(x -> C * x)
l_sym = ConstLinearLayer(C)

# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
Expand Down
6 changes: 2 additions & 4 deletions src/radial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@ Radial_basis(Rnl::AbstractExplicitLayer) =
end

# it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl
function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing)
function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing)
if isnothing(spec)
try
spec = natural_indices(basis)
catch
error("The specification of this Radial_basis should be given explicitly!")
end
end

f(r) = f_trans(r) * f_cut(r)

return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec)
return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evaluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec)
end
Loading