Skip to content

Commit

Permalink
bugfixes to enable duals
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Jun 19, 2024
1 parent 6c54b84 commit c8d2aa8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ end
function eval_and_grad!(∇φ_A, aadot::AADot, A)
φ = aadot(A)
∇φ_A_1 = P4ML._pb_evaluate(aadot.aabasis, aadot.cc, A)
∇φ_A .= unwrap(∇φ_A_1)
# ∇φ_A .= unwrap(∇φ_A_1)
for n = 1:length(A)
∇φ_A[n] = ∇φ_A_1[n]
end
release!(∇φ_A_1)
return φ
end
9 changes: 7 additions & 2 deletions src/ncorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ function generate_AA_dot(spec, c)
return Polynomial(dynamic_poly)
end

function eval_and_grad!(∇_A, aadot, A)
return evaluate_and_gradient!(∇_A, aadot, A)
function eval_and_grad!(∇φ_A, aadot, A)
# evaluate_and_gradient!(∇_A, aadot, A)
φ, ∇φ_A_1 = evaluate_and_gradient(aadot, A)
for n = 1:length(A)
∇φ_A[n] = ∇φ_A_1[n]
end
return φ
end
7 changes: 5 additions & 2 deletions src/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ function evaluate_ed!(Rn, dRn, basis::SplineRadialsZ, Rs, Zs)
zj = Zs[ij]
i_zj = _z2i(basis, zj)
spl_ij = basis.spl[i_zj]
Rn[ij, :] .= spl_ij(rij)
# Rn[ij, :] .= spl_ij(rij)
Rn_ij = spl_ij(rij)
g = Interpolations.gradient1(spl_ij, rij)
for n = 1:length(g)
@assert length(Rn_ij) == length(g)
for n = 1:length(Rn_ij)
Rn[ij, n] = Rn_ij[n]
dRn[ij, n] = g[n] * 𝐫̂ij
end
end
Expand Down

0 comments on commit c8d2aa8

Please sign in to comment.