Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: refactor LinearAtomicModel serialize/deserialize #3451

Merged
merged 13 commits into from
Mar 13, 2024
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)


@BaseAtomicModel.register("standard")
class DPAtomicModel(BaseAtomicModel):
"""Model give atomic prediction of some physical property.

Expand Down
47 changes: 20 additions & 27 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -225,40 +221,38 @@
]
)

@staticmethod
def serialize(models, type_map) -> dict:
def serialize(self) -> dict:
return {
"@class": "Model",
"type": "linear",
"@version": 1,
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}

@staticmethod
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
type_map = data["type_map"]
type_map = data.pop("type_map")

Check warning on line 239 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L239

Added line #L239 was not covered by tests
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
return models, type_map
data.pop("models")
return cls(models, type_map, **data)

Check warning on line 245 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L244-L245

Added lines #L244 - L245 were not covered by tests

@abstractmethod
def _compute_weight(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlists_: List[np.ndarray],
) -> np.ndarray:
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError
nmodels = len(self.models)
return [np.ones(1) / nmodels for _ in range(nmodels)]

Check warning on line 255 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L254-L255

Added lines #L254 - L255 were not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down Expand Up @@ -335,10 +329,10 @@
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"@version": 2,
"models": LinearEnergyAtomicModel(
models=[self.models[0], self.models[1]], type_map=self.type_map
).serialize(),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -349,16 +343,15 @@
@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)

Check warning on line 346 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L346

Added line #L346 was not covered by tests
data.pop("@class")
data.pop("type")
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

([dp_model, zbl_model], type_map) = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)
linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models"))
dp_model, zbl_model = linear_model.models
type_map = linear_model.type_map

Check warning on line 354 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L352-L354

Added lines #L352 - L354 were not covered by tests

return cls(
dp_model=dp_model,
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)


@BaseAtomicModel.register("pairtab")
class PairTabAtomicModel(BaseAtomicModel):
"""Pairwise tabulation energy model.

Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
log = logging.getLogger(__name__)


@BaseAtomicModel.register("standard")

Check warning on line 36 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L36

Added line #L36 was not covered by tests
class DPAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Model give atomic prediction of some physical property.

Expand Down
49 changes: 22 additions & 27 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -260,35 +256,35 @@
]
)

@staticmethod
def serialize(models, type_map) -> dict:
def serialize(self) -> dict:

Check warning on line 259 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L259

Added line #L259 was not covered by tests
return {
"@class": "Model",
"@version": 1,
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
"models": [model.serialize() for model in self.models],
"type_map": self.type_map,
}

@staticmethod
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":

Check warning on line 269 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L268-L269

Added lines #L268 - L269 were not covered by tests
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
model_names = data["model_name"]
type_map = data["type_map"]
data.pop("@class")
data.pop("type")
type_map = data.pop("type_map")

Check warning on line 274 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L272-L274

Added lines #L272 - L274 were not covered by tests
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
for model in data["models"]
]
return models, type_map
data.pop("models")
return cls(models, type_map, **data)

Check warning on line 280 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L279-L280

Added lines #L279 - L280 were not covered by tests

@abstractmethod
def _compute_weight(
self, extended_coord, extended_atype, nlists_
) -> List[torch.Tensor]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError
nmodels = len(self.models)
return [torch.ones(1) / nmodels for _ in range(nmodels)]

Check warning on line 287 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L286-L287

Added lines #L286 - L287 were not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down Expand Up @@ -400,11 +396,11 @@
dd.update(
{
"@class": "Model",
"@version": 1,
"@version": 2,
"type": "zbl",
"models": LinearEnergyAtomicModel.serialize(
[self.models[0], self.models[1]], self.type_map
),
"models": LinearEnergyAtomicModel(
models=[self.models[0], self.models[1]], type_map=self.type_map
).serialize(),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -415,14 +411,13 @@
@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)

Check warning on line 414 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L414

Added line #L414 was not covered by tests
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

[dp_model, zbl_model], type_map = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)
linear_model = LinearEnergyAtomicModel.deserialize(data.pop("models"))
dp_model, zbl_model = linear_model.models
type_map = linear_model.type_map

Check warning on line 420 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L418-L420

Added lines #L418 - L420 were not covered by tests

data.pop("@class", None)
data.pop("type", None)
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)


@BaseAtomicModel.register("pairtab")

Check warning on line 38 in deepmd/pt/model/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/pairtab_atomic_model.py#L38

Added line #L38 was not covered by tests
class PairTabAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Pairwise tabulation energy model.

Expand Down