Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant Linear Layer #18

Merged
merged 25 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
01bb9e0
Initialise ConstLinearLayer
zhanglw0521 Sep 27, 2023
9ce22da
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 1, 2023
0047583
keep it consistent to the Radial_Basis... branch
zhanglw0521 Oct 1, 2023
adf0519
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 2, 2023
8c5ed6c
WIP - far from complete
zhanglw0521 Oct 2, 2023
9199843
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 4, 2023
02e303d
partially works
zhanglw0521 Oct 4, 2023
8e7fcf0
Clean up
zhanglw0521 Oct 4, 2023
53a472d
Adapt Yangshuai's energy fitting code to the latest version
zhanglw0521 Oct 4, 2023
1b945d3
Merge branch 'Rnl_Basis' into Const_LinearLayer
zhanglw0521 Oct 5, 2023
4a7a06d
introduce rpe_basis & a linear dependence test
zhanglw0521 Oct 5, 2023
7252202
Fix the linear dependence issue
zhanglw0521 Oct 5, 2023
7dc53bc
Update Project.toml
zhanglw0521 Oct 5, 2023
92b2b3d
Typo fix
zhanglw0521 Oct 5, 2023
43d1a4f
Add the corresponding tests
zhanglw0521 Oct 5, 2023
ba40276
typo fix
zhanglw0521 Oct 5, 2023
0e3b9c8
Merge pull request #19 from ACEsuit/Linear_dependence_issue
zhanglw0521 Oct 5, 2023
a115a63
A simple energy test that checks the "completeness" of the basis for …
zhanglw0521 Oct 5, 2023
b363596
Resolve most of the issues in comments
zhanglw0521 Oct 5, 2023
ae937d1
Renaming W to op to avoid ambiguity of learnable weigh and constant m…
zhanglw0521 Oct 5, 2023
149447d
get rid of the "position" projectin
zhanglw0521 Oct 5, 2023
42ad3a9
clean up
zhanglw0521 Oct 5, 2023
8e63f32
Minor revision
zhanglw0521 Oct 5, 2023
bdffa55
faster linear transformation construction
zhanglw0521 Oct 6, 2023
7223e6c
turning C * X[pos] to LO * X with LO a combined LinearOperator
zhanglw0521 Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 9 additions & 22 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@ import ChainRulesCore: rrule
using LuxCore
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer{T} <: AbstractExplicitLayer # where {in_dim,out_dim,T}
W::AbstractMatrix{T}
position::Union{Vector{Int64}, UnitRange{Int64}}
in_dim::Integer
out_dim::Integer
struct ConstLinearLayer <: AbstractExplicitLayer
op
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives you a type instability. You still need to specify the type of the operator, something like this:

struct ConstLinL{TOP} 
    op::TOP
end

end

ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,1:size(W,2),size(W,2),size(W,1))
ConstLinearLayer(W::AbstractMatrix{T}, pos::Union{Vector{Int64}, UnitRange{Int64}}) where T = ConstLinearLayer(W,pos,size(W,2),size(W,1))

(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x[l.position]) ? l.W * x[l.position] : error("x (or the position index) has a wrong length!")
(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)]
cortner marked this conversation as resolved.
Show resolved Hide resolved

(l::ConstLinearLayer)(x::AbstractMatrix) = begin
Tmp = l(x[1,:])
Expand All @@ -22,28 +16,21 @@ ConstLinearLayer(W::AbstractMatrix{T}, pos::Union{Vector{Int64}, UnitRange{Int64
return Tmp'
end

(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st)

# NOTE: the following rrule is kept because there is an issue with SparseArray
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector)
zhanglw0521 marked this conversation as resolved.
Show resolved Hide resolved
val = l(x)
function pb(A)
return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent()
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st)

# Reverse-mode rule (`rrule`, imported from ChainRulesCore) for `LuxCore.apply`
# on a `ConstLinearLayer` called with input `x`, parameters `ps` and state `st`.
# Returns the forward value `l(x,ps,st)` together with the pullback closure `pb`.
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st)
val = l(x,ps,st)
# Pullback: `A` is the incoming cotangent; only `A[1]` is used
# (presumably the cotangent of the layer output, the second entry belonging to
# the returned state — TODO confirm).
function pb(A)
# NOTE(review): the two `return` statements below look like the pre-rename (`W`)
# and post-rename (`op`) versions of the same line, shown together by the diff
# view. As written only the first one would ever execute — confirm which one
# belongs in the file.
return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent()
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

# function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st)
# val = l(x, ps, st)
# function pb(A)
# return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent()
# end
# return val, pb
# end
end
11 changes: 10 additions & 1 deletion src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C)
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand All @@ -271,6 +271,15 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
return luxchain, ps, st
end

"""
    new_sparse_matrix(C, pos)

Embed the columns of `C` into a wider sparse matrix.

Return a sparse matrix `S` with `size(C, 1)` rows and `maximum(pos)` columns such
that `S[:, pos] == C` and every other column is zero. For index collections `pos`
without repeated entries this means `S * x[1:maximum(pos)] == C * x[pos]`, i.e. the
column selection is folded into the operator itself.

# Arguments
- `C`: matrix (dense or sparse) whose columns are to be scattered.
- `pos`: column indices in the result, one per column of `C`
  (e.g. a `Vector{Int}` or a `UnitRange{Int}`).

# Throws
- An error from `maximum` if `pos` is empty.
- A dimension error from the assignment if `length(pos) != size(C, 2)`.
"""
function new_sparse_matrix(C, pos)
    ncols = maximum(pos)
    # `eltype(C)` instead of `typeof(C[1])`: it does not index into `C`, so a
    # matrix with zero rows no longer raises a `BoundsError`.
    C_new = zeros(eltype(C), size(C, 1), ncols)
    # One block assignment into the dense buffer replaces the row-by-row
    # insertion into a sparse matrix, which is slow for column-compressed storage.
    C_new[:, pos] = C
    # Convert once at the end; numerically zero entries are not stored.
    return sparse(C_new)
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)
Expand Down