Skip to content

Commit

Permalink
more cleanup, towards gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Nov 15, 2023
1 parent a5a61d3 commit ca26fbd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 48 deletions.
6 changes: 6 additions & 0 deletions src/UltraFastACE.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
module UltraFastACE

import ACEbase
import ACEbase: evaluate, evaluate!,
evaluate_ed, evaluate_ed!

_i2z(obj, i::Integer) = obj._i2z[i]

function _z2i(obj, Z)
Expand All @@ -19,6 +23,8 @@ include("splines.jl")

include("convert_c2r.jl")

include("auxiliary.jl")

include("uface.jl")

end
28 changes: 8 additions & 20 deletions src/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,16 @@ struct SplineRadials{SPL, N}
end


function evaluate(basis::SplineRadials, rij, Zj)
i_Zj = _z2i(basis, Zj)
spl_j = basis.spl[i_Zj]
return spl_j(rij)
function evaluate(ace, basis::SplineRadialsZ,
Rs::AbstractVector{<: SVector}, Zs::AbstractVector)
TF = eltype(eltype(Rs))
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(basis)), TF)
evaluate!(Rn, basis, Rs, Zs)
return Rn
end

function evaluate(basis::SplineRadialsZ, rij, Zj)
i_Zj = _z2i(basis, Zj)
spl_j = basis.spl[i_Zj]
return SparseStaticArray(basis.idx[i_Zj], spl_j(rij))
end

function evaluate!(out, basis::SplineRadials, Rs, Zs)
@inbounds for ij = 1:length(Rs)
rij = norm(Rs[ij])
zj = Zs[ij]
i_zj = _z2i(basis, zj)
spl_ij = basis.spl[i_zj]
out[ij, :] .= spl_j(rij)
return out
end
end
# Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
# evaluate!(Rn, rbasis, Rs, Zs)

function evaluate!(out, basis::SplineRadialsZ, Rs, Zs)
nX = length(Rs)
Expand Down
93 changes: 65 additions & 28 deletions src/uface.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import Polynomials4ML
import ACEpotentials
using Interpolations, ObjectPools
import SpheriCart
import SpheriCart: SphericalHarmonics, compute!
import ACEpotentials.ACE1
import ACEpotentials.ACE1: AtomicNumber
using LinearAlgebra: norm

import ACEbase
import ACEbase: evaluate


const C2R = ConvertC2R
const P4ML = Polynomials4ML
Expand Down Expand Up @@ -41,56 +36,99 @@ function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi)
end


_get_L(ybasis::SphericalHarmonics{L}) where {L} = L
_len_ylm(ybasis) = (_get_L(ybasis) + 1)^2

function embed_z!(Ez, rbasis, Zs)
fill!(Ez, 0)
for (j, z) in enumerate(Zs)
iz = _z2i(rbasis, z)
Ez[j, iz] = 1
end
return Ez
end

function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs)
TF = eltype(eltype(Rs))
rbasis = ace.rbasis
NZ = length(rbasis._i2z)

_spl(rbasis, z) = rbasis.spl[UltraFastACE._z2i(rbasis, z)]

# embeddings
# element embedding
Ez = acquire!(ace.pool, :Ez, (length(Zs), NZ), TF)
embed_z!(Ez, rbasis, Zs)
Ez = embed_z(ace, Rs, Zs)
# radial embedding
Rn = evaluate(ace, rbasis, Rs, Zs)
# angular embedding
Zlm = evaluate_ylm(ace, Rs)

# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations
φ = ace.aadot(A)

# release the borrowed arrays
release!(Zlm)
release!(Rn)
release!(Ez)
release!(A)

return φ
end


function ACEbase.evaluate_ed!(∇φ, ace::UFACE_inner, Rs, Zs)
TF = eltype(eltype(Rs))
rbasis = ace.rbasis
NZ = length(rbasis._i2z)

# embeddings
# element embedding (there is no gradient)
Ez = embed_z(ace, Rs, Zs)

# radial embedding
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
evaluate!(Rn, rbasis, Rs, Zs)
Rn, dRn = evaluate_ed(ace, rbasis, Rs, Zs)

# angular embedding
Zlm = acquire!(ace.pool, :Zlm, (length(Rs), _len_ylm(ace.ybasis)), TF)
compute!(Zlm, ace.ybasis, Rs)
Zlm, dZlm = evaluate_ylm_ed(ace, Rs)

# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations
φ = ace.aadot(A)
φ, ∂φ_∂A = ace.aadot(A) # compute with gradient

# backprop through A
∂φ_∂Ez = BlackHole(TF)
∂φ_∂Rn = acquire!(ace.pool, :∂Rn, size(Rn), TF)
∂φ_∂Zlm = acquire!(ace.pool, :∂Zlm, size(Zlm), TF)
P4ML._pullback_evaluate!((∂φ_∂Ez, ∂φ_∂Rn, ∂φ_∂Zlm), ∂φ_∂A,
ace.abasis, (Ez, Rn, Zlm))

# backprop through the embeddings
# depending on whether there is a bottleneck here, this can be
# potentially implemented more efficiently without needing writing/reading
# (to be investigated where the bottleneck is)

# we just ignore Ez (hence the black hole)

# backprop through Rn
# We already computed the gradients in the forward pass
fill!(∇φ, zero(SVector{3, TF}))
for n = 1:size(Rn, 2)
for j = 1:length(Rs)
∇φ[j] += ∂φ_∂Rn[j, n] * dRn[j, n]
end
end

# ... and Ylm
for i_lm = 1:size(Zlm, 2)
for j = 1:length(Rs)
∇φ[j] += ∂φ_∂Zlm[j, i_lm] * dZlm[j, i_lm]
end
end

# release the borrowed arrays
release!(Zlm)
release!(Rn)
release!(Ez)
release!(A)

return φ
return φ, ∇φ
end


# ------------------------------------------------------
# transformation code
# transformation code : ACE1 -> UF_ACE models

function make_radial_splines(Rn_basis, zlist; npoints = 100)
@assert Rn_basis.envelope isa ACEpotentials.ACE1.OrthPolys.OneEnvelope
Expand Down Expand Up @@ -185,7 +223,6 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)
c_r_iz = AA_transform[:T]' * mbpot.coeffs[iz]
aadot = generate_AA_dot(spec_AA_inds, c_r_iz)


return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot)
end

Expand Down

0 comments on commit ca26fbd

Please sign in to comment.