diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index 1a684b53..b6ddb839 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -155,9 +155,19 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut]) end -function _rin0cuts_rcut(zlist, cutoffs::Dict) +function _rin0cuts_rcut(zlist, cutoffs::Dict, kwargs = nothing) + function _get_r0(zi, zj) + if kwargs == nothing + return DefaultHypers.bond_len(zi, zj) + elseif kwargs[:r0] == :bondlen + return DefaultHypers.bond_len(zi, zj) + elseif kwargs[:r0] isa Number + return kwargs[:r0] + end + error("Cannot determine r0($zi, $zj) from the arguments provided.") + end function rin0cut(zi, zj) - r0 = DefaultHypers.bond_len(zi, zj) + r0 = _get_r0(zi, zj) rin, rcut = cutoffs[zi, zj] return (rin = rin, r0 = r0, rcut = rcut) end @@ -166,18 +176,23 @@ function _rin0cuts_rcut(zlist, cutoffs::Dict) end -function _ace1_rin0cuts(kwargs; rcutkey = :rcut) +function _ace1_rin0cuts(kwargs; rcutkey = :rcut, rinkey = :rin) elements = _get_elements(kwargs) rcut = _get_all_rcut(kwargs; _rcut = kwargs[rcutkey]) + if kwargs[:rin] isa Number + rin = kwargs[:rin] + else + error("Cannot read rin; please provide a number of file an issue if a more general mechanism is needed.") + end if rcut isa Number - cutoffs = Dict([ (s1, s2) => (0.0, rcut) for s1 in elements, s2 in elements]...) + cutoffs = Dict([ (s1, s2) => (rin, rcut) for s1 in elements, s2 in elements]...) else - cutoffs = Dict([ (s1, s2) => (0.0, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...) + cutoffs = Dict([ (s1, s2) => (rin, rcut[(s1, s2)]) for s1 in elements, s2 in elements]...) end # rcut = maximum(values(rcut)) # multitransform wants a single cutoff. # construct the rin0cut structures - rin0cuts = _rin0cuts_rcut(elements, cutoffs) + rin0cuts = _rin0cuts_rcut(elements, cutoffs, kwargs) end diff --git a/src/fit_model.jl b/src/fit_model.jl index 24ca1b13..c830fb57 100644 --- a/src/fit_model.jl +++ b/src/fit_model.jl @@ -14,8 +14,6 @@ export acefit!, assemble, compute_errors _get_Vref(model::ACEPotential) = model.model.Vref -__set_params!(model::ACEPotential, coeffs) = ACEpotentials.Models.set_parameters!(model, coeffs) - default_weights() = Dict("default"=>Dict("E"=>30.0, "F"=>1.0, "V"=>1.0)) function _make_prior(model::ACEpotentials.Models.ACEPotential, smoothness, P) @@ -163,7 +161,7 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model; coeffs = P \ result["C"] # dispatch setting of parameters - __set_params!(model, coeffs) + ACEpotentials.Models.set_linear_parameters!(model, coeffs) if haskey(result, "committee") co_coeffs = result["committee"] diff --git a/src/models/calculators.jl b/src/models/calculators.jl index 4019567c..0904d359 100644 --- a/src/models/calculators.jl +++ b/src/models/calculators.jl @@ -54,6 +54,17 @@ function set_parameters!(V::ACEPotential, θ::AbstractVector) return set_parameters!(V, ps) end +function set_linear_parameters!(V::ACEPotential{<: ACEModel}, θ::AbstractVector) + ps = V.ps + ps1 = (WB = ps.WB, Wpair = ps.Wpair,) + ps1_vec, _restruct = destructure(ps1) + ps2 = _restruct(θ) + ps3 = deepcopy(ps) + ps3.WB[:] = ps2.WB + ps3.Wpair[:] = ps2.Wpair + return set_parameters!(V, ps3) +end + # --------------------------------------------------------------- # AtomsCalculatorsUtilities / SitePotential based implementation #