Skip to content

Commit

Permalink
draft gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Nov 15, 2023
1 parent ca26fbd commit a6c4fae
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/ncorr.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@


using StaticArrays, OffsetArrays, StaticPolynomials
using StaticArrays, OffsetArrays
using StaticPolynomials: Polynomial, evaluate_and_gradient
using DynamicPolynomials: @polyvar

"""
Expand Down
36 changes: 34 additions & 2 deletions src/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ function evaluate(ace, basis::SplineRadialsZ,
return Rn
end

# Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
# evaluate!(Rn, rbasis, Rs, Zs)

function evaluate!(out, basis::SplineRadialsZ, Rs, Zs)
nX = length(Rs)
Expand All @@ -50,3 +48,37 @@ function evaluate!(out, basis::SplineRadialsZ, Rs, Zs)
return out
end


function evaluate_ed(ace, rbasis, Rs, Zs)
TF = eltype(eltype(Rs))
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
dRn = acquire!(ace.pool, :dRn, (length(Rs), length(rbasis)), SVector{3, TF})
evaluate_ed!(Rn, dRn, rbasis, Rs, Zs)
return Rn, dRn
end


function evaluate_ed!(Rn, dRn, basis::SplineRadialsZ, Rs, Zs)
nX = length(Rs)
len = length(basis)
@assert length(Zs) >= nX
@assert size(Rn, 1) >= nX
@assert size(Rn, 2) >= len
@assert size(dRn, 1) >= nX
@assert size(dRn, 2) >= len

for ij = 1:nX
rij = norm(Rs[ij])
𝐫̂ij = Rs[ij] / rij
zj = Zs[ij]
i_zj = _z2i(basis, zj)
spl_ij = basis.spl[i_zj]
Rn[ij, :] .= spl_ij(rij)
g = Interpolations.gradient1(spl_ij, rij)
for n = 1:length(g)
dRn[ij, n] = g[n] * 𝐫̂ij
end
end
return nothing
end

10 changes: 8 additions & 2 deletions src/uface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs)
end


function ACEbase.evaluate_ed!(∇φ, ace::UFACE, Rs, Zs, z0)
i_z0 = _z2i(ace, z0)
ace_inner = ace.ace_inner[i_z0]
return ACEbase.evaluate_ed!(∇φ, ace_inner, Rs, Zs)
end

function ACEbase.evaluate_ed!(∇φ, ace::UFACE_inner, Rs, Zs)
TF = eltype(eltype(Rs))
rbasis = ace.rbasis
Expand All @@ -84,8 +90,8 @@ function ACEbase.evaluate_ed!(∇φ, ace::UFACE_inner, Rs, Zs)
# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations
φ, ∂φ_∂A = ace.aadot(A) # compute with gradient
# n correlations # compute with gradient
φ, ∂φ_∂A = evaluate_and_gradient(ace.aadot, A)

# backprop through A
∂φ_∂Ez = BlackHole(TF)
Expand Down

0 comments on commit a6c4fae

Please sign in to comment.