Skip to content

Commit

Permalink
add version info to NN and PairTab
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 27, 2024
1 parent 7145d32 commit aedc5bc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
30 changes: 29 additions & 1 deletion deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import h5py
import numpy as np

from deepmd.dpmodel.utils.version import (
check_version_compatibility,
)

try:
from deepmd._version import version as __version__
except ImportError:
Expand Down Expand Up @@ -189,6 +193,8 @@ def serialize(self) -> dict:
"idt": self.idt,
}
return {
"@class": "Layer",
"@version": 1,
"bias": self.b is not None,
"use_timestep": self.idt is not None,
"activation_function": self.activation_function,
Expand All @@ -208,6 +214,8 @@ def deserialize(cls, data: dict) -> "NativeLayer":
The dict to deserialize from.
"""
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
variables = data.pop("@variables")
assert variables["w"] is not None and len(variables["w"].shape) == 2
num_in, num_out = variables["w"].shape
Expand Down Expand Up @@ -349,7 +357,11 @@ def serialize(self) -> dict:
dict
The serialized network.
"""
return {"layers": [layer.serialize() for layer in self.layers]}
return {
"@class": "NN",
"@version": 1,
"layers": [layer.serialize() for layer in self.layers],
}

@classmethod
def deserialize(cls, data: dict) -> "NN":
Expand All @@ -360,6 +372,9 @@ def deserialize(cls, data: dict) -> "NN":
data : dict
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
return cls(data["layers"])

def __getitem__(self, key):
Expand Down Expand Up @@ -471,6 +486,8 @@ def serialize(self) -> dict:
The serialized network.
"""
return {
"@class": "EmbeddingNetwork",
"@version": 1,
"in_dim": self.in_dim,
"neuron": self.neuron.copy(),
"activation_function": self.activation_function,
Expand All @@ -490,6 +507,8 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
The dict to deserialize from.
"""
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
layers = data.pop("layers")
obj = cls(**data)
super(EN, obj).__init__(layers)
Expand Down Expand Up @@ -566,6 +585,8 @@ def serialize(self) -> dict:
The serialized network.
"""
return {
"@class": "FittingNetwork",
"@version": 1,
"in_dim": self.in_dim,
"out_dim": self.out_dim,
"neuron": self.neuron.copy(),
Expand All @@ -586,6 +607,8 @@ def deserialize(cls, data: dict) -> "FittingNet":
The dict to deserialize from.
"""
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
layers = data.pop("layers")
obj = cls(**data)
T_Network.__init__(obj, layers)
Expand Down Expand Up @@ -688,6 +711,8 @@ def serialize(self) -> dict:
network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()}
network_type_name = network_type_map_inv[self.network_type]
return {
"@class": "NetworkCollection",
"@version": 1,
"ndim": self.ndim,
"ntypes": self.ntypes,
"network_type": network_type_name,
Expand All @@ -703,4 +728,7 @@ def deserialize(cls, data: dict) -> "NetworkCollection":
data : dict
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
return cls(**data)
9 changes: 9 additions & 0 deletions deepmd/utils/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
CubicSpline,
)

from deepmd.dpmodel.utils.version import (
check_version_compatibility,
)

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -72,6 +76,8 @@ def reinit(self, filename: str, rcut: Optional[float] = None) -> None:

def serialize(self) -> dict:
return {
"@class": "PairTab",
"@version": 1,
"rmin": self.rmin,
"rmax": self.rmax,
"hh": self.hh,
Expand All @@ -87,6 +93,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "PairTab":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
variables = data.pop("@variables")
tab = PairTab(None, None)
tab.vdata = variables["vdata"]
Expand Down

0 comments on commit aedc5bc

Please sign in to comment.