Skip to content

Commit

Permalink
many bugfixes related committees
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 15, 2024
1 parent 4f65594 commit b0f0fac
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 7,822 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7,882 changes: 65 additions & 7,817 deletions examples/Tutorial/ACEpotentials-Tutorial.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/ace1_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ function ace1_model(; kwargs...)
end

model_spec = Dict{Symbol, Any}(pairs(kwargs)...)
model_spec[:Eref] = ACEpotentials.Models._convert_E0s(kwargs[:Eref])
if haskey(model_spec, :Eref)
model_spec[:Eref] = ACEpotentials.Models._convert_E0s(kwargs[:Eref])
end
model_spec[:model_name] = "ACE1"

kwargs = _clean_args(kwargs)
Expand Down
20 changes: 20 additions & 0 deletions src/models/calculators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,23 @@ function energy_forces_virial_basis(

return (energy = E, forces = F, virial = V)
end


function potential_energy_basis(at, calc::ACEPotential{<: ACEModel},
ps = calc.ps, st = calc.st;
domain = 1:length(at),
nlist = PairList(at, cutoff_radius(calc)),
kwargs...)
N_basis = length_basis(calc)
_e0 = AtomsCalculators.zero_energy(at, calc)
T = typeof(ustrip(_e0))
E = fill(zero(T) * energy_unit(calc), N_basis)

for i in domain
Js, Rs, Zs, z0 = get_neighbours(at, calc, nlist, i)
v = evaluate_basis(calc.model, Rs, Zs, z0, ps, st)
E += v * energy_unit(calc)
end

return E
end
28 changes: 25 additions & 3 deletions src/models/committee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ function co_length(model::ACEPotential)
end

function committee(f, sys::AbstractSystem, model::ACEPotential)
E = f(sys, model)
if f == potential_energy
return co_potential_energy(sys, model)
end
F = f(sys, model)
ps0 = model.ps
co_E = [ (model.ps = model.co_ps[i]; f(sys, model))
co_F = [ (model.ps = model.co_ps[i]; f(sys, model))
for i = 1:length(model.co_ps) ]
model.ps = ps0
return E, co_E
return F, co_F
end

macro committee(ex)
Expand All @@ -44,3 +47,22 @@ macro committee(ex)
committee($(esc_args...))
end
end

function co_potential_energy(sys::AbstractSystem, model::ACEPotential)
basis_E = potential_energy_basis(sys, model)
eref = potential_energy(sys, model.model.Vref) * u"eV"
E = dot(basis_E, destructure(model.ps)[1]) + eref
co_E = [ dot(basis_E, destructure(model.co_ps[i])[1]) + eref
for i = 1:length(model.co_ps) ]
return E, co_E
end

function co_potential_energy_2(sys::AbstractSystem, model::ACEPotential)
f = potential_energy
F = f(sys, model)
ps0 = model.ps
co_F = [ (model.ps = model.co_ps[i]; f(sys, model))
for i = 1:length(model.co_ps) ]
model.ps = ps0
return F, co_F
end
13 changes: 13 additions & 0 deletions test/test_bugs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,16 @@ maxdiff = maximum(abs(E_per_at[i] - E_per_at[j]) for i = 1:10, j = 1:10 )
@test ustrip(u"eV", maxdiff) < 1e-9

@info(" ============================================================")
@info(" ============== Testing for no Eref bug ====================")

# there was never an issue filed for this, but it was an annoying issue
# that came up twice by making changes in the model construction heuristics

params1 = (elements = [:Si], rcut = 5.5, order = 3, totaldegree = 12)
params2 = (; :Eref => [:Si => 0.0], pairs(params1)...)
model1 = ace1_model(; params1...)
model2 = ace1_model(; params2...)
sys = bulk(:Si, cubic=true) * 2
println_slim(@test potential_energy(sys, model1) == potential_energy(sys, model2) == 0.0u"eV")

@info(" ============================================================")
15 changes: 14 additions & 1 deletion test/test_silicon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ println_slim(@test m2.ps == model.ps)

##

info("Fit a potential with committee")
@info("Fit a potential with committee")

co_size = 10
solver = ACEfit.BLR(factorization = :svd, committee_size = co_size)
Expand All @@ -116,6 +116,19 @@ acefit!(data, model;

println_slim(@test length(model.co_ps) == co_size)

E, co_E = @committee potential_energy(data[3], model)
E
co_E

using LinearAlgebra
M = ACEpotentials.Models
efv = M.energy_forces_virial_basis(data[3], model)
e = M.potential_energy_basis(data[3], model)
println_slim(@test all(efv.energy .≈ e))
e1, co_e1 = @committee potential_energy(data[3], model)
e2, co_e2 = M.co_potential_energy_2(data[3], model)
println_slim(@test e1 e2)
println_slim(@test all(co_e1 .≈ co_e2))

##
# Add a descriptor test
Expand Down

0 comments on commit b0f0fac

Please sign in to comment.