Skip to content

Commit

Permalink
finish outer ace evaluation too
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Nov 14, 2023
1 parent c1c878b commit 6f65cc0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
20 changes: 15 additions & 5 deletions src/uface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ UFACE_inner(rbasis, ybasis, abasis, aadot) =
TSafe(ArrayPool(FlexArrayCache)),
Dict())

struct UFACE{N, TR, TY, TA, TAA}
_i2z::NTuple{N, Int}
ace_inner::NTuple{N, UFACE_inner{TR, TY, TA, TAA}}
struct UFACE{NZ, INNER}
_i2z::NTuple{NZ, Int}
ace_inner::INNER
end


function evaluate(ace::UFACE, Rs, Zs, zi)
function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi)
i_zi = _z2i(ace, zi)
ace_inner = ace.ace_inner[i_zi]
return evaluate(ace_inner, Rs, Zs)
return ACEbase.evaluate(ace_inner, Rs, Zs)
end


Expand Down Expand Up @@ -157,3 +157,13 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)

return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot)
end


function uface_from_ace1(mbpot; n_spl_points = 100)
NZ = length(mbpot.pibasis.zlist)
_i2z = tuple(Int.(mbpot.pibasis.zlist.list)...)
ace_inner = tuple(
[ uface_from_ace1_inner(mbpot, iz; n_spl_points = n_spl_points)
for iz = 1:NZ ]... )
return UFACE(_i2z, ace_inner)
end
11 changes: 6 additions & 5 deletions test/test_import_ace1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ model = acemodel(; elements = elements, order = 3, totaldegree = 10)
pot = model.potential
mbpot = pot.components[2]

##

# conver to UFACE format
ace1 = UltraFastACE.uface_from_ace1_inner(mbpot, 1; n_spl_points = 10_000)

uf_ace = UltraFastACE.uface_from_ace1(mbpot; n_spl_points = 10_000)

## ------------------------------------

Expand All @@ -24,14 +24,15 @@ for ntest = 1:30
Nat = 12; r0 = 0.9 * rnn(:Si); r1 = 1.3 * rnn(:Si)
Rs = [ (r0 + (r1 - r0) * rand()) * ACE1.Random.rand_sphere() for _=1:Nat ]
iz0 = 1
z0 = JuLIP.Potentials.i2z(mbpot, 1)
z0 = rand(AtomicNumber.(elements)) # JuLIP.Potentials.i2z(mbpot, 1)
Zs = [ rand(AtomicNumber.(elements)) for _ = 1:Nat ]

v1 = ACEbase.evaluate(mbpot, Rs, Zs, z0)
v2 = ACEbase.evaluate(ace1, Rs, Zs)
v1 = evaluate(mbpot, Rs, Zs, z0)
v2 = evaluate(uf_ace, Rs, Zs, z0)

print_tf(
@test abs(v1 - v2) / (abs(v1) + abs(v2)) < 1e-10
)
end
println()

0 comments on commit 6f65cc0

Please sign in to comment.