From ca26fbda11bd851f53747f6e814c52f51b7797ba Mon Sep 17 00:00:00 2001 From: ACEsuit Date: Tue, 14 Nov 2023 21:54:34 -0800 Subject: [PATCH] more cleanup, towards gradients --- src/UltraFastACE.jl | 6 +++ src/splines.jl | 28 ++++---------- src/uface.jl | 93 +++++++++++++++++++++++++++++++-------------- 3 files changed, 79 insertions(+), 48 deletions(-) diff --git a/src/UltraFastACE.jl b/src/UltraFastACE.jl index e12709f..ed84990 100644 --- a/src/UltraFastACE.jl +++ b/src/UltraFastACE.jl @@ -1,5 +1,9 @@ module UltraFastACE +import ACEbase +import ACEbase: evaluate, evaluate!, + evaluate_ed, evaluate_ed! + _i2z(obj, i::Integer) = obj._i2z[i] function _z2i(obj, Z) @@ -19,6 +23,8 @@ include("splines.jl") include("convert_c2r.jl") +include("auxiliary.jl") + include("uface.jl") end diff --git a/src/splines.jl b/src/splines.jl index fe56336..8abcc23 100644 --- a/src/splines.jl +++ b/src/splines.jl @@ -22,28 +22,16 @@ struct SplineRadials{SPL, N} end -function evaluate(basis::SplineRadials, rij, Zj) - i_Zj = _z2i(basis, Zj) - spl_j = basis.spl[i_Zj] - return spl_j(rij) +function evaluate(ace, basis::SplineRadialsZ, + Rs::AbstractVector{<: SVector}, Zs::AbstractVector) + TF = eltype(eltype(Rs)) + Rn = acquire!(ace.pool, :Rn, (length(Rs), length(basis)), TF) + evaluate!(Rn, basis, Rs, Zs) + return Rn end -function evaluate(basis::SplineRadialsZ, rij, Zj) - i_Zj = _z2i(basis, Zj) - spl_j = basis.spl[i_Zj] - return SparseStaticArray(basis.idx[i_Zj], spl_j(rij)) -end - -function evaluate!(out, basis::SplineRadials, Rs, Zs) - @inbounds for ij = 1:length(Rs) - rij = norm(Rs[ij]) - zj = Zs[ij] - i_zj = _z2i(basis, zj) - spl_ij = basis.spl[i_zj] - out[ij, :] .= spl_j(rij) - return out - end -end +# Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF) +# evaluate!(Rn, rbasis, Rs, Zs) function evaluate!(out, basis::SplineRadialsZ, Rs, Zs) nX = length(Rs) diff --git a/src/uface.jl b/src/uface.jl index edbf96c..fbefc04 100644 --- a/src/uface.jl +++ b/src/uface.jl @@ -1,15 +1,10 @@ import Polynomials4ML import ACEpotentials using Interpolations, ObjectPools -import SpheriCart -import SpheriCart: SphericalHarmonics, compute! import ACEpotentials.ACE1 import ACEpotentials.ACE1: AtomicNumber using LinearAlgebra: norm -import ACEbase -import ACEbase: evaluate - const C2R = ConvertC2R const P4ML = Polynomials4ML @@ -41,43 +36,86 @@ function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi) end -_get_L(ybasis::SphericalHarmonics{L}) where {L} = L -_len_ylm(ybasis) = (_get_L(ybasis) + 1)^2 - -function embed_z!(Ez, rbasis, Zs) - fill!(Ez, 0) - for (j, z) in enumerate(Zs) - iz = _z2i(rbasis, z) - Ez[j, iz] = 1 - end - return Ez -end function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs) TF = eltype(eltype(Rs)) rbasis = ace.rbasis NZ = length(rbasis._i2z) - - _spl(rbasis, z) = rbasis.spl[UltraFastACE._z2i(rbasis, z)] # embeddings # element embedding - Ez = acquire!(ace.pool, :Ez, (length(Zs), NZ), TF) - embed_z!(Ez, rbasis, Zs) + Ez = embed_z(ace, Rs, Zs) + # radial embedding + Rn = evaluate(ace, rbasis, Rs, Zs) + # angular embedding + Zlm = evaluate_ylm(ace, Rs) + + # pooling + A = ace.abasis((Ez, Rn, Zlm)) + + # n correlations + φ = ace.aadot(A) + + # release the borrowed arrays + release!(Zlm) + release!(Rn) + release!(Ez) + release!(A) + + return φ +end + + +function ACEbase.evaluate_ed!(∇φ, ace::UFACE_inner, Rs, Zs) + TF = eltype(eltype(Rs)) + rbasis = ace.rbasis + NZ = length(rbasis._i2z) + + # embeddings + # element embedding (there is no gradient) + Ez = embed_z(ace, Rs, Zs) # radial embedding - Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF) - evaluate!(Rn, rbasis, Rs, Zs) + Rn, dRn = evaluate_ed(ace, rbasis, Rs, Zs) # angular embedding - Zlm = acquire!(ace.pool, :Zlm, (length(Rs), _len_ylm(ace.ybasis)), TF) - compute!(Zlm, ace.ybasis, Rs) + Zlm, dZlm = evaluate_ylm_ed(ace, Rs) # pooling A = ace.abasis((Ez, Rn, Zlm)) # n correlations - φ = ace.aadot(A) + φ, ∂φ_∂A = ace.aadot(A) # compute with gradient + + # backprop through A + ∂φ_∂Ez = BlackHole(TF) + ∂φ_∂Rn = acquire!(ace.pool, :∂Rn, size(Rn), TF) + ∂φ_∂Zlm = acquire!(ace.pool, :∂Zlm, size(Zlm), TF) + P4ML._pullback_evaluate!((∂φ_∂Ez, ∂φ_∂Rn, ∂φ_∂Zlm), ∂φ_∂A, + ace.abasis, (Ez, Rn, Zlm)) + + # backprop through the embeddings + # depending on whether there is a bottleneck here, this can be + # potentially implemented more efficiently without needing writing/reading + # (to be investigated where the bottleneck is) + + # we just ignore Ez (hence the black hole) + + # backprop through Rn + # We already computed the gradients in the forward pass + fill!(∇φ, zero(SVector{3, TF})) + for n = 1:size(Rn, 2) + for j = 1:length(Rs) + ∇φ[j] += ∂φ_∂Rn[j, n] * dRn[j, n] + end + end + + # ... and Ylm + for i_lm = 1:size(Zlm, 2) + for j = 1:length(Rs) + ∇φ[j] += ∂φ_∂Zlm[j, i_lm] * dZlm[j, i_lm] + end + end # release the borrowed arrays release!(Zlm) @@ -85,12 +123,12 @@ function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs) release!(Ez) release!(A) - return φ + return φ, ∇φ end # ------------------------------------------------------ -# transformation code +# transformation code : ACE1 -> UF_ACE models function make_radial_splines(Rn_basis, zlist; npoints = 100) @assert Rn_basis.envelope isa ACEpotentials.ACE1.OrthPolys.OneEnvelope @@ -185,7 +223,6 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100) c_r_iz = AA_transform[:T]' * mbpot.coeffs[iz] aadot = generate_AA_dot(spec_AA_inds, c_r_iz) - return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot) end