Skip to content

Commit

Permalink
fix(pt/dp): make dpa2 convertable to .dp format (#4324)
Browse files Browse the repository at this point in the history
Fix #4295. BTW, I found that there seems no universal uts for
`convert-backend` command.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Updated `RepformerLayer` class to version 2, enhancing serialization
and deserialization processes.
- Introduced a new structure for residual variables within the
serialized data, improving organization and clarity.

- **Bug Fixes**
- Adjusted version compatibility checks in the `deserialize` method to
align with the new versioning scheme.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Nov 8, 2024
1 parent 0199ad5 commit 15bb00c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
19 changes: 11 additions & 8 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ def serialize(self) -> dict:
"""
data = {
"@class": "RepformerLayer",
"@version": 1,
"@version": 2,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand Down Expand Up @@ -1877,9 +1877,11 @@ def serialize(self) -> dict:
if self.update_style == "res_residual":
data.update(
{
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
"@variables": {
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
}
}
)
return data
Expand All @@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
linear1 = data.pop("linear1")
update_chnnl_2 = data["update_chnnl_2"]
Expand All @@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
attn2_ev_apply = data.pop("attn2_ev_apply", None)
loc_attn = data.pop("loc_attn", None)
g1_self_mlp = data.pop("g1_self_mlp", None)
g1_residual = data.pop("g1_residual", [])
g2_residual = data.pop("g2_residual", [])
h2_residual = data.pop("h2_residual", [])
variables = data.pop("@variables", {})
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))

obj = cls(**data)
obj.linear1 = NativeLayer.deserialize(linear1)
Expand Down
19 changes: 11 additions & 8 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ def serialize(self) -> dict:
"""
data = {
"@class": "RepformerLayer",
"@version": 1,
"@version": 2,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand Down Expand Up @@ -1380,9 +1380,11 @@ def serialize(self) -> dict:
if self.update_style == "res_residual":
data.update(
{
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
"@variables": {
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
}
}
)
return data
Expand All @@ -1397,7 +1399,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
linear1 = data.pop("linear1")
update_chnnl_2 = data["update_chnnl_2"]
Expand All @@ -1418,9 +1420,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
attn2_ev_apply = data.pop("attn2_ev_apply", None)
loc_attn = data.pop("loc_attn", None)
g1_self_mlp = data.pop("g1_self_mlp", None)
g1_residual = data.pop("g1_residual", [])
g2_residual = data.pop("g2_residual", [])
h2_residual = data.pop("h2_residual", [])
variables = data.pop("@variables", {})
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))

obj = cls(**data)
obj.linear1 = MLPLayer.deserialize(linear1)
Expand Down

0 comments on commit 15bb00c

Please sign in to comment.