Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add ZBL weighted DP model #3210

Merged
merged 111 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 96 commits
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
47cff4b
feat: add pair table model to pytorch
Jan 28, 2024
04b6f57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
eb59d87
fix: typo
Jan 28, 2024
b7cbbd5
fix: typo
Jan 28, 2024
a1a76bb
Merge branch 'devel' into devel
anyangml Jan 28, 2024
84767f3
fix: update ruct extrapolation
Jan 28, 2024
8fee8fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
ff08515
fix: update allclose precision
Jan 28, 2024
f4b3720
Merge branch 'devel' into devel
anyangml Jan 29, 2024
451916e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
0968eaa
Merge branch 'devel' into devel
anyangml Jan 29, 2024
6b0559e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
8cbb98c
chore: refactor common method to PairTab
Jan 29, 2024
a08092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
d3090b9
fix: update unit tests
Jan 29, 2024
daf2fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
399c278
fix: revert padding zero mask change
Jan 29, 2024
59abe43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
8f1cdc8
Merge branch 'devel' into devel
anyangml Jan 30, 2024
88936cc
Merge branch 'devel' into devel
anyangml Jan 30, 2024
1c4ee0d
feat: redo extrapolation with cubic spline for smoothness
Jan 30, 2024
5793828
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
27f3559
Merge branch 'devel' into devel
anyangml Jan 30, 2024
92dec18
chore: refactor _make_data in PairTab
Jan 30, 2024
bc04359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
4433035
chore: move file
Jan 30, 2024
2ba0318
Merge branch 'devel' into devel
anyangml Jan 30, 2024
f2c40e6
Merge branch 'devel' into devel
anyangml Jan 31, 2024
4851a0a
chore: refactor extrapolation code
Jan 31, 2024
ddbe7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2024
29d95db
Merge branch 'devel' into devel
anyangml Jan 31, 2024
365c20d
feat: add zbl weighted model
Feb 1, 2024
fb4ae7d
Merge branch 'deepmodeling:devel' into devel
anyangml Feb 1, 2024
9e6d8f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2024
56a34a0
Merge branch 'devel' into devel
anyangml Feb 2, 2024
294a91e
fix: linear model
Feb 2, 2024
99beff0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
3f4f0a5
fix: import
anyangml Feb 4, 2024
d89534c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
dc3c05a
Merge branch 'devel' into devel
anyangml Feb 4, 2024
c3d65ab
feat: add weight calculation and tests
anyangml Feb 4, 2024
5c1c18d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
379e48e
feat: add numpy implementation
anyangml Feb 4, 2024
96644ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
3881e72
Merge branch 'devel' into devel
anyangml Feb 4, 2024
2c91aff
feat: add numpy implementation
anyangml Feb 4, 2024
bba4856
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
935a149
fix: overwriting nlist bug
anyangml Feb 4, 2024
c6319d7
chore: rename variables
anyangml Feb 5, 2024
6138ab8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
b05c572
Merge branch 'devel' into devel
anyangml Feb 5, 2024
267a8bd
chore: refactor code to add general linear model
anyangml Feb 5, 2024
d140002
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
2253825
fix: precommit
anyangml Feb 5, 2024
7fe41a1
fix: precommit
anyangml Feb 5, 2024
8999623
chore: refactor
anyangml Feb 5, 2024
0f748b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
80f9ca1
chore: refactor code
anyangml Feb 5, 2024
6626e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
7006d4e
feat: add model wrapper
anyangml Feb 5, 2024
3b29d06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
edd61f8
Merge branch 'devel' into devel
anyangml Feb 5, 2024
ddb56ee
chore: mute test warning
anyangml Feb 5, 2024
e6908f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
0603b61
chore: mute test warning
anyangml Feb 5, 2024
577a3eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
de8d1fb
fix: distinguish type nlist
anyangml Feb 6, 2024
63cd163
chore: merge conflict
anyangml Feb 6, 2024
aaff157
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
caa3202
Merge branch 'devel' into devel
anyangml Feb 6, 2024
36ef6d2
chore: remove print
anyangml Feb 6, 2024
319460f
Merge remote-tracking branch 'upstream/devel' into devel
anyangml Feb 6, 2024
446bd1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
94e8ee6
fix: import bug
anyangml Feb 6, 2024
0a02dc8
Merge branch 'devel' into devel
anyangml Feb 6, 2024
4a1f922
chore: refactor
anyangml Feb 6, 2024
32f6fa9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
7d458d7
feat: add UTs
anyangml Feb 6, 2024
72061f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
afc2af4
fix: precommit
anyangml Feb 6, 2024
f88c34c
feat: add UTs
anyangml Feb 6, 2024
95f1c62
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
3c6cf5b
feat: add UTs
anyangml Feb 6, 2024
e8bcf43
fix: test zbl model
anyangml Feb 6, 2024
d2674f8
fix: test zbl model
anyangml Feb 6, 2024
2ddc8c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
525f598
Merge branch 'devel' into devel
anyangml Feb 6, 2024
5fc617a
fix: UTs
anyangml Feb 6, 2024
7921605
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
a72c8e8
chore: refactor
anyangml Feb 6, 2024
95d693f
chore: refactor
anyangml Feb 6, 2024
f2abaaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
98f9eb2
fix: remove autodiff UTs, no force and virial
anyangml Feb 6, 2024
396ad3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
c12e41a
fix: refactor
anyangml Feb 7, 2024
1d61a2c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2024
4233c3a
support separate r_differentiable and c_differentiable (#3240)
wanghan-iapcm Feb 7, 2024
8e03440
fix: autograd to nan
anyangml Feb 7, 2024
2ec84ca
Merge remote-tracking branch 'upstream/devel' into devel
anyangml Feb 7, 2024
10116e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2024
442b87e
fix: autograd
anyangml Feb 7, 2024
97af8aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2024
6313c41
fix: autograd nan
anyangml Feb 7, 2024
d1b6a9d
fix: jit
anyangml Feb 7, 2024
7339d81
fix: jit
anyangml Feb 7, 2024
4c7de6e
fix: jit
anyangml Feb 7, 2024
4edff13
fix: jit
anyangml Feb 7, 2024
656d979
fix: jit
anyangml Feb 8, 2024
ccf63ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2024
3fa5d08
fix: add UTs
anyangml Feb 8, 2024
ae93085
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 288 additions & 0 deletions deepmd/dpmodel/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
abstractmethod,
)
from typing import (
Dict,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from deepmd.dpmodel import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.dpmodel.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
nlist_distinguish_types,
)

from .base_atomic_model import (
BaseAtomicModel,
)
from .dp_atomic_model import (
DPAtomicModel,
)
from .pairtab_atomic_model import (
PairTabModel,
)


class LinearAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.

Parameters
----------
models : list[DPAtomicModel or PairTabModel]
A list of models to be combined. PairTabModel must be used together with a DPAtomicModel.
"""

def __init__(
self,
models: List[BaseAtomicModel],
**kwargs,
):
super().__init__()
self.models = models
self.distinguish_type_list = [
model.distinguish_types() for model in self.models
]

def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
return False

Check warning on line 58 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L58

Added line #L58 was not covered by tests

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_rcuts())

Check warning on line 62 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L62

Added line #L62 was not covered by tests

def get_rcuts(self) -> List[float]:
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]

def get_sel(self) -> List[int]:
return [model.get_nsel() for model in self.models]
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def get_original_sels(self) -> List[Union[int, List[int]]]:
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

def _sort_rcuts_sels(self) -> Tuple[List[int], List[float]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
zipped = sorted(
zip(self.get_rcuts(), self.get_sel()), key=lambda x: (x[1], x[0])
)
return [p[0] for p in zipped], [p[1] for p in zipped]

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Return atomic prediction.

Parameters
----------
extended_coord
coodinates in extended region, (nframes, nall * 3)
extended_atype
atomic type in extended region, (nframes, nall)
nlist
neighbor list, (nframes, nloc, nsel).
mapping
mapps the extended indices to local indices.
fparam
frame parameter. (nframes, ndf)
aparam
atomic parameter. (nframes, nloc, nda)

Returns
-------
result_dict
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
self.extended_coord = extended_coord.reshape(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
self.extended_coord,
nlist,
sorted_rcuts,
sorted_sels,
)
raw_nlists = [
nlists[get_multiple_nlist_key(rcut, sel)]
for rcut, sel in zip(self.get_rcuts(), self.get_sel())
]
self.nlists_ = [
nl if not dt else nlist_distinguish_types(nl, extended_atype, sel)
for dt, nl, sel in zip(
self.distinguish_type_list, raw_nlists, self.get_original_sels()
)
]
ener_list = [
model.forward_atomic(
self.extended_coord,
extended_atype,
nl,
mapping,
)["energy"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
for model, nl in zip(self.models, self.nlists_)
]
self.weights = self._compute_weight()
self.atomic_bias = None
if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")

Check warning on line 144 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L144

Added line #L144 was not covered by tests
else:
fit_ret = {
"energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0),
} # (nframes, nloc, 1)
return fit_ret

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(

Check warning on line 152 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L152

Added line #L152 was not covered by tests
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
anyangml marked this conversation as resolved.
Show resolved Hide resolved
)
]
)

def serialize(self) -> dict:
return {

Check warning on line 161 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L161

Added line #L161 was not covered by tests
"models": [model.serialize() for model in self.models],
}

@classmethod
def deserialize(cls, data) -> "LinearAtomicModel":
models = [DPAtomicModel.deserialize(model) for model in data["models"]]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return cls(models)

Check warning on line 168 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L167-L168

Added lines #L167 - L168 were not covered by tests

@abstractmethod
def _compute_weight(self) -> np.ndarray:
Fixed Show fixed Hide fixed
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError

Check warning on line 173 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L173

Added line #L173 was not covered by tests


class DPZBLLinearAtomicModel(LinearAtomicModel):
"""Model linearly combine a list of AtomicModels.

Parameters
----------
models
This linear model should take a DPAtomicModel and a PairTable model.
"""

def __init__(
self,
dp_model: DPAtomicModel,
zbl_model: PairTabModel,
sw_rmin: float,
sw_rmax: float,
smin_alpha: Optional[float] = 0.1,
**kwargs,
):
models = [dp_model, zbl_model]
super().__init__(models, **kwargs)
self.dp_model = dp_model
self.zbl_model = zbl_model

self.sw_rmin = sw_rmin
self.sw_rmax = sw_rmax
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
return {
"dp_model": self.dp_model.serialize(),
"zbl_model": self.zbl_model.serialize(),
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]

dp_model = DPAtomicModel.deserialize(data["dp_model"])
zbl_model = PairTabModel.deserialize(data["zbl_model"])
anyangml marked this conversation as resolved.
Show resolved Hide resolved

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
smin_alpha=smin_alpha,
)

def _compute_weight(self) -> List[np.ndarray]:
"""ZBL weight.

Returns
-------
List[np.ndarray]
the atomic ZBL weight for interpolation. (nframes, nloc, 1)
"""
assert (
self.sw_rmax > self.sw_rmin
), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."

dp_nlist = self.nlists_[0]
zbl_nlist = self.nlists_[1]

zbl_nnei = zbl_nlist.shape[-1]
dp_nnei = dp_nlist.shape[-1]

# use the larger rr based on nlist
nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
nloc = nlist_larger.shape[1]
Fixed Show fixed Hide fixed
masked_nlist = np.clip(nlist_larger, 0, None)
pairwise_rr = np.sqrt(
np.sum(
np.power(
(
np.expand_dims(self.extended_coord, 2)
- np.expand_dims(self.extended_coord, 1)
),
2,
),
axis=-1,
)
)

rr = np.take_along_axis(pairwise_rr[:, :nloc, :], masked_nlist, 2)

numerator = np.sum(
rr * np.exp(-rr / self.smin_alpha), axis=-1
) # masked nnei will be zero, no need to handle
denominator = np.sum(
np.where(
nlist_larger != -1,
np.exp(-rr / self.smin_alpha),
np.zeros_like(nlist_larger),
),
axis=-1,
) # handle masked nnei.
sigma = numerator / denominator
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
coef = np.zeros_like(u)
left_mask = sigma < self.sw_rmin
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
right_mask = sigma >= self.sw_rmax
coef[left_mask] = 1
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
self.zbl_weight = coef
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def fitting_output_def(self) -> FittingOutputDef:
def get_rcut(self) -> float:
return self.rcut

def get_sel(self) -> int:
def get_sel(self) -> List[int]:
return [self.sel]

def get_nsel(self) -> int:
return self.sel

def distinguish_types(self) -> bool:
Expand Down Expand Up @@ -196,7 +199,7 @@ def _pair_tabulated_inter(
i_type, j_type, idx, self.tab_data, self.nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)
ener = self._calculate_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
Expand Down Expand Up @@ -275,7 +278,7 @@ def _extract_spline_coefficient(
return final_coef

@staticmethod
def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray:
def _calculate_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray:
"""Calculate energy using spline coeeficients.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def build_multiple_neighbor_list(
nall = coord1.shape[1]
coord0 = coord1[:, :nloc, :]
nlist_mask = nlist == -1
tnlist_0 = nlist
tnlist_0 = nlist.copy()
tnlist_0[nlist_mask] = 0
index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3])
coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3)
Expand Down
44 changes: 44 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,62 @@
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
)
from deepmd.pt.model.model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.pairtab_atomic_model import (
PairTabModel,
)
from deepmd.pt.model.task import (
Fitting,
)

from .ener import (
EnergyModel,
ZBLModel,
)
from .model import (
BaseModel,
)


def get_zbl_model(model_params, sampled=None):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
descriptor = Descriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", None)
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["distinguish_types"] = descriptor.distinguish_types()
fitting_net["embedding_width"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True

Check warning on line 42 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L40-L42

Added lines #L40 - L42 were not covered by tests
fitting = Fitting(**fitting_net)
dp_model = DPAtomicModel(
descriptor, fitting, type_map=model_params["type_map"], resuming=True
)
# pairtab
filepath = model_params["use_srtab"]
pt_model = PairTabModel(
filepath, model_params["descriptor"]["rcut"], model_params["descriptor"]["sel"]
)

rmin = model_params["sw_rmin"]
rmax = model_params["sw_rmax"]
return ZBLModel(
dp_model,
pt_model,
rmin,
rmax,
)


def get_model(model_params, sampled=None):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
Expand Down
Loading
Loading