Skip to content

Commit

Permalink
update version
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jun 4, 2024
1 parent a394d97 commit 4d09586
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeA:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
embedding_net_variables = cls.deserialize_network(
Expand Down Expand Up @@ -1427,7 +1427,7 @@ def serialize(self, suffix: str = "") -> dict:
return {
"@class": "Descriptor",
"type": "se_e2_a",
"@version": 1,
"@version": 2,
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeR:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
embedding_net_variables = cls.deserialize_network(
data.pop("embeddings"), suffix=suffix
)
Expand Down Expand Up @@ -772,7 +772,7 @@ def serialize(self, suffix: str = "") -> dict:
return {
"@class": "Descriptor",
"type": "se_r",
"@version": 1,
"@version": 2,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel_r,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeT:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
embedding_net_variables = cls.deserialize_network(
Expand Down Expand Up @@ -926,7 +926,7 @@ def serialize(self, suffix: str = "") -> dict:
return {
"@class": "Descriptor",
"type": "se_e3",
"@version": 1,
"@version": 2,
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def serialize(self, suffix: str) -> dict:
data = {
"@class": "Fitting",
"type": "dipole",
"@version": 1,
"@version": 2,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"embedding_width": self.dim_rot_mat_1,
Expand Down Expand Up @@ -406,7 +406,7 @@ def deserialize(cls, data: dict, suffix: str):
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
data["nets"],
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data["numb_dos"] = data.pop("dim_out")
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
Expand All @@ -700,7 +700,7 @@ def serialize(self, suffix: str = "") -> dict:
data = {
"@class": "Fitting",
"type": "dos",
"@version": 1,
"@version": 2,
"var_name": "dos",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
The deserialized model
"""
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
check_version_compatibility(data.pop("@version", 1), 2, 1)
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
data["nets"],
Expand All @@ -900,7 +900,7 @@ def serialize(self, suffix: str = "") -> dict:
data = {
"@class": "Fitting",
"type": "ener",
"@version": 1,
"@version": 2,
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def serialize(self, suffix: str) -> dict:
data = {
"@class": "Fitting",
"type": "polar",
"@version": 1,
"@version": 3,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"embedding_width": self.dim_rot_mat_1,
Expand Down Expand Up @@ -603,7 +603,7 @@ def deserialize(cls, data: dict, suffix: str):
"""
data = data.copy()
check_version_compatibility(
data.pop("@version", 1), 2, 1
data.pop("@version", 1), 3, 1
) # to allow PT version.
fitting = cls(**data)
fitting.fitting_net_variables = cls.deserialize_network(
Expand Down

0 comments on commit 4d09586

Please sign in to comment.