Skip to content

Commit

Permalink
Merge pull request #269 from ACEsuit/co/fixes
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
cortner authored Sep 20, 2024
2 parents ec34a73 + 8161809 commit bddf71d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
27 changes: 21 additions & 6 deletions src/ace1_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,19 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut])
end


function _rin0cuts_rcut(zlist, cutoffs::Dict)
function _rin0cuts_rcut(zlist, cutoffs::Dict, kwargs = nothing)
function _get_r0(zi, zj)
if kwargs == nothing
return DefaultHypers.bond_len(zi, zj)
elseif kwargs[:r0] == :bondlen
return DefaultHypers.bond_len(zi, zj)
elseif kwargs[:r0] isa Number
return kwargs[:r0]
end
error("Cannot determine r0($zi, $zj) from the arguments provided.")
end
function rin0cut(zi, zj)
r0 = DefaultHypers.bond_len(zi, zj)
r0 = _get_r0(zi, zj)
rin, rcut = cutoffs[zi, zj]
return (rin = rin, r0 = r0, rcut = rcut)
end
Expand All @@ -166,18 +176,23 @@ function _rin0cuts_rcut(zlist, cutoffs::Dict)
end


function _ace1_rin0cuts(kwargs; rcutkey = :rcut)
function _ace1_rin0cuts(kwargs; rcutkey = :rcut, rinkey = :rin)
elements = _get_elements(kwargs)
rcut = _get_all_rcut(kwargs; _rcut = kwargs[rcutkey])
if kwargs[:rin] isa Number
rin = kwargs[:rin]
else
error("Cannot read rin; please provide a number of file an issue if a more general mechanism is needed.")
end
if rcut isa Number
cutoffs = Dict([ (s1, s2) => (0.0, rcut) for s1 in elements, s2 in elements]...)
cutoffs = Dict([ (s1, s2) => (rin, rcut) for s1 in elements, s2 in elements]...)
else
cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...)
cutoffs = Dict([ (s1, s2) => (rin, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...)
end
# rcut = maximum(values(rcut)) # multitransform wants a single cutoff.

# construct the rin0cut structures
rin0cuts = _rin0cuts_rcut(elements, cutoffs)
rin0cuts = _rin0cuts_rcut(elements, cutoffs, kwargs)
end


Expand Down
4 changes: 1 addition & 3 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ export acefit!, assemble, compute_errors

_get_Vref(model::ACEPotential) = model.model.Vref

__set_params!(model::ACEPotential, coeffs) = ACEpotentials.Models.set_parameters!(model, coeffs)

default_weights() = Dict("default"=>Dict("E"=>30.0, "F"=>1.0, "V"=>1.0))

function _make_prior(model::ACEpotentials.Models.ACEPotential, smoothness, P)
Expand Down Expand Up @@ -163,7 +161,7 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
coeffs = P \ result["C"]

# dispatch setting of parameters
__set_params!(model, coeffs)
ACEpotentials.Models.set_linear_parameters!(model, coeffs)

if haskey(result, "committee")
co_coeffs = result["committee"]
Expand Down
11 changes: 11 additions & 0 deletions src/models/calculators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ function set_parameters!(V::ACEPotential, θ::AbstractVector)
return set_parameters!(V, ps)
end

function set_linear_parameters!(V::ACEPotential{<: ACEModel}, θ::AbstractVector)
ps = V.ps
ps1 = (WB = ps.WB, Wpair = ps.Wpair,)
ps1_vec, _restruct = destructure(ps1)
ps2 = _restruct(θ)
ps3 = deepcopy(ps)
ps3.WB[:] = ps2.WB
ps3.Wpair[:] = ps2.Wpair
return set_parameters!(V, ps3)
end

# ---------------------------------------------------------------
# AtomsCalculatorsUtilities / SitePotential based implementation
#
Expand Down

0 comments on commit bddf71d

Please sign in to comment.