Skip to content

Commit

Permalink
finish import and test of inner
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Nov 14, 2023
1 parent a4a196a commit c1c878b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 182 deletions.
36 changes: 30 additions & 6 deletions src/uface.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import Polynomials4ML
import ACEpotentials
using Interpolations
using Interpolations, ObjectPools
import SpheriCart
import ACEpotentials.ACE1
import ACEpotentials.ACE1: AtomicNumber
# using SpheriCart: SphericalHarmonics, compute
using LinearAlgebra: norm

C2R = ConvertC2R
P4ML = Polynomials4ML
import ACEbase
import ACEbase: evaluate


const C2R = ConvertC2R
const P4ML = Polynomials4ML

struct UFACE_inner{TR, TY, TA, TAA}
rbasis::TR
ybasis::TY
abasis::TA
aadot::TAA
pool::TSafe{ArrayPool{FlexArrayCache}}
meta::Dict
end

UFACE_inner(rbasis, ybasis, abasis, aadot) =
UFACE_inner(rbasis, ybasis, abasis, aadot,
TSafe(ArrayPool(FlexArrayCache)),
Dict())

struct UFACE{N, TR, TY, TA, TAA}
_i2z::NTuple{N, Int}
ace_inner::NTuple{N, UFACE_inner{TR, TY, TA, TAA}}
Expand All @@ -30,9 +40,23 @@ function evaluate(ace::UFACE, Rs, Zs, zi)
end


function evaluate(ace::UFACE_inner, Rs, Zs)
function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs)

rbasis = ace.rbasis
_spl(rbasis, z) = rbasis.spl[UltraFastACE._z2i(rbasis, z)]

# embeddings
Ez = reduce(vcat, [ SVector((z .== rbasis._i2z)...)' for z in Zs ])
Rn = reduce(vcat, [ _spl(rbasis, Zs[j])(norm(Rs[j]))' for j = 1:length(Rs) ])
Zlm = ace.ybasis(Rs)

# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations
φ = ace.aadot(A)

return φ
end


Expand Down Expand Up @@ -131,5 +155,5 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)
aadot = generate_AA_dot(spec_AA_inds, c_r_iz)


return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot, Dict()), AA_transform
return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot)
end
176 changes: 0 additions & 176 deletions test/test_import.jl

This file was deleted.

37 changes: 37 additions & 0 deletions test/test_import_ace1.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

using ACEpotentials, StaticArrays, BenchmarkTools,
LinearAlgebra, UltraFastACE, Test
using ACEbase: evaluate
using ACEbase.Testing: print_tf

##

elements = [:Si,:O]

model = acemodel(; elements = elements, order = 3, totaldegree = 10)
pot = model.potential
mbpot = pot.components[2]

##

ace1 = UltraFastACE.uface_from_ace1_inner(mbpot, 1; n_spl_points = 10_000)


## ------------------------------------

@info("Test Consistency of ACE1 with UFACE")
for ntest = 1:30
Nat = 12; r0 = 0.9 * rnn(:Si); r1 = 1.3 * rnn(:Si)
Rs = [ (r0 + (r1 - r0) * rand()) * ACE1.Random.rand_sphere() for _=1:Nat ]
iz0 = 1
z0 = JuLIP.Potentials.i2z(mbpot, 1)
Zs = [ rand(AtomicNumber.(elements)) for _ = 1:Nat ]

v1 = ACEbase.evaluate(mbpot, Rs, Zs, z0)
v2 = ACEbase.evaluate(ace1, Rs, Zs)

print_tf(
@test abs(v1 - v2) / (abs(v1) + abs(v2)) < 1e-10
)
end
println()

0 comments on commit c1c878b

Please sign in to comment.