diff --git a/src/uface.jl b/src/uface.jl index 7bee757..b1be9d9 100644 --- a/src/uface.jl +++ b/src/uface.jl @@ -27,16 +27,16 @@ UFACE_inner(rbasis, ybasis, abasis, aadot) = TSafe(ArrayPool(FlexArrayCache)), Dict()) -struct UFACE{N, TR, TY, TA, TAA} - _i2z::NTuple{N, Int} - ace_inner::NTuple{N, UFACE_inner{TR, TY, TA, TAA}} +struct UFACE{NZ, INNER} + _i2z::NTuple{NZ, Int} + ace_inner::INNER end -function evaluate(ace::UFACE, Rs, Zs, zi) +function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi) i_zi = _z2i(ace, zi) ace_inner = ace.ace_inner[i_zi] - return evaluate(ace_inner, Rs, Zs) + return ACEbase.evaluate(ace_inner, Rs, Zs) end @@ -157,3 +157,13 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100) return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot) end + + +function uface_from_ace1(mbpot; n_spl_points = 100) + NZ = length(mbpot.pibasis.zlist) + _i2z = tuple(Int.(mbpot.pibasis.zlist.list)...) + ace_inner = tuple( + [ uface_from_ace1_inner(mbpot, iz; n_spl_points = n_spl_points) + for iz = 1:NZ ]... ) + return UFACE(_i2z, ace_inner) +end \ No newline at end of file diff --git a/test/test_import_ace1.jl b/test/test_import_ace1.jl index 1002b6a..e0ce56c 100644 --- a/test/test_import_ace1.jl +++ b/test/test_import_ace1.jl @@ -12,10 +12,10 @@ model = acemodel(; elements = elements, order = 3, totaldegree = 10) pot = model.potential mbpot = pot.components[2] -## - +# conver to UFACE format ace1 = UltraFastACE.uface_from_ace1_inner(mbpot, 1; n_spl_points = 10_000) +uf_ace = UltraFastACE.uface_from_ace1(mbpot; n_spl_points = 10_000) ## ------------------------------------ @@ -24,14 +24,15 @@ for ntest = 1:30 Nat = 12; r0 = 0.9 * rnn(:Si); r1 = 1.3 * rnn(:Si) Rs = [ (r0 + (r1 - r0) * rand()) * ACE1.Random.rand_sphere() for _=1:Nat ] iz0 = 1 - z0 = JuLIP.Potentials.i2z(mbpot, 1) + z0 = rand(AtomicNumber.(elements)) # JuLIP.Potentials.i2z(mbpot, 1) Zs = [ rand(AtomicNumber.(elements)) for _ = 1:Nat ] - v1 = ACEbase.evaluate(mbpot, Rs, Zs, z0) - v2 = ACEbase.evaluate(ace1, Rs, Zs) + v1 = evaluate(mbpot, Rs, Zs, z0) + v2 = evaluate(uf_ace, Rs, Zs, z0) print_tf( @test abs(v1 - v2) / (abs(v1) + abs(v2)) < 1e-10 ) end println() +