Skip to content

Commit

Permalink
removed all allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Nov 15, 2023
1 parent ea09a06 commit f35ce05
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
8 changes: 7 additions & 1 deletion src/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@ struct SparseStaticArray{N, T}
end


struct SplineRadialsZ{SPL, N}
struct SplineRadialsZ{SPL, N, LEN}
_i2z::NTuple{N, Int}
spl::NTuple{N, SPL}
end

SplineRadialsZ(_i2z::NTuple{N, Int}, spl::NTuple{N, SPL}, LEN
) where {N, SPL} =
SplineRadialsZ{SPL, N, LEN}(_i2z, spl)

Base.length(basis::SplineRadialsZ{SPL, N, LEN}) where {SPL, N, LEN} = LEN

struct SplineRadials{SPL, N}
_i2z::NTuple{N, Int}
spl::NTuple{N, SPL}
Expand Down
39 changes: 34 additions & 5 deletions src/uface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Polynomials4ML
import ACEpotentials
using Interpolations, ObjectPools
import SpheriCart
import SpheriCart: SphericalHarmonics, compute!
import ACEpotentials.ACE1
import ACEpotentials.ACE1: AtomicNumber
using LinearAlgebra: norm
Expand Down Expand Up @@ -40,22 +41,48 @@ function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi)
end


_get_L(ybasis::SphericalHarmonics{L}) where {L} = L
_len_ylm(ybasis) = (_get_L(ybasis) + 1)^2

function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs)

TF = eltype(eltype(Rs))
rbasis = ace.rbasis
NZ = length(rbasis._i2z)

_spl(rbasis, z) = rbasis.spl[UltraFastACE._z2i(rbasis, z)]

# embeddings
Ez = reduce(vcat, [ SVector((z .== rbasis._i2z)...)' for z in Zs ])
Rn = reduce(vcat, [ _spl(rbasis, Zs[j])(norm(Rs[j]))' for j = 1:length(Rs) ])
Zlm = ace.ybasis(Rs)
# Ez = reduce(vcat, [ SVector((z .== rbasis._i2z)...)' for z in Zs ])
Ez = acquire!(ace.pool, :Ez, (length(Zs), NZ), UInt8)
fill!(Ez, 0)
for (j, z) in enumerate(Zs)
iz = _z2i(rbasis, z)
Ez[j, iz] = 1
end

# Rn = reduce(vcat, [ _spl(rbasis, Zs[j])(norm(Rs[j]))' for j = 1:length(Rs) ])
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
for (j, z) in enumerate(Zs)
spl_j = _spl(rbasis, z)
Rn[j, :] .= spl_j(norm(Rs[j]))
end

# Zlm = ace.ybasis(Rs)
Zlm = acquire!(ace.pool, :Zlm, (length(Rs), _len_ylm(ace.ybasis)), TF)
compute!(Zlm, ace.ybasis, Rs)

# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations
φ = ace.aadot(A)

# release the borrowed arrays
release!(Zlm)
release!(Rn)
release!(Ez)
release!(A)

return φ
end

Expand Down Expand Up @@ -100,9 +127,11 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)

# radial embedding
Rn_basis = mbpot.pibasis.basis1p.J
LEN_Rn = length(Rn_basis.J)
spl = make_radial_splines(Rn_basis, zlist; npoints = n_spl_points)
rbasis_new = SplineRadialsZ(Int.(t_zlist),
ntuple(iz1 -> spl[(zlist[iz1], z0)], length(zlist)))
ntuple(iz1 -> spl[(zlist[iz1], z0)], length(zlist)),
LEN_Rn)
# P4ML style spec of radial embedding
spec2i_Rn = 1:length(Rn_basis.J)

Expand Down

0 comments on commit f35ce05

Please sign in to comment.