Skip to content

Commit

Permalink
chore: avoid unnecessary deepcopy (deepmodeling#4327)
Browse files Browse the repository at this point in the history
This PR saves about 1 minute in each CI job. See devel vs this branch:

![image](https://github.com/user-attachments/assets/13fd74b6-bec0-4e44-9f5e-54c06be24531)

![image](https://github.com/user-attachments/assets/7987ab4a-3b4b-4585-9503-cccdfce81a26)


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced data handling in various models by replacing deep copy with
shallow copy in the `deserialize` method, improving performance.
- Added `exclude_types` parameter in multiple fitting classes, allowing
for the exclusion of specific atom types during fitting.

- **Bug Fixes**
- Improved robustness in output definition retrieval by adding fallback
mechanisms in certain models.

- **Documentation**
- Updated comments and docstrings for clarity in methods across several
classes.

- **Refactor**
- Streamlined data handling processes in multiple models by eliminating
unnecessary deep copying.
- Adjusted version compatibility checks in several `deserialize`
methods.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 9, 2024
1 parent 8f11bc7 commit cb3e39e
Show file tree
Hide file tree
Showing 35 changed files with 69 additions and 116 deletions.
3 changes: 1 addition & 2 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
)
Expand Down Expand Up @@ -208,7 +207,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
Expand Down
5 changes: 2 additions & 3 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand Down Expand Up @@ -305,7 +304,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 2), 2, 2)
data.pop("@class", None)
data.pop("type", None)
Expand Down Expand Up @@ -418,7 +417,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 2)
models = [
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand Down Expand Up @@ -174,7 +173,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import itertools
from typing import (
Any,
Expand Down Expand Up @@ -460,7 +459,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -388,7 +387,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeR":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import itertools
from typing import (
Optional,
Expand Down Expand Up @@ -388,7 +387,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeT":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -156,7 +155,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
Optional,
Expand Down Expand Up @@ -73,7 +72,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -73,7 +72,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("var_name")
data.pop("dim_out")
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from abc import (
abstractmethod,
)
Expand Down Expand Up @@ -320,7 +319,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -179,7 +178,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
return super().deserialize(data)

Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -197,7 +196,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 3, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
Expand Down
3 changes: 1 addition & 2 deletions deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand Down Expand Up @@ -111,7 +110,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("dim_out")
data.pop("var_name")
Expand Down
9 changes: 4 additions & 5 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
See issue #2982 for more information.
"""

import copy
import itertools
from typing import (
Callable,
Expand Down Expand Up @@ -135,7 +134,7 @@ def deserialize(cls, data: dict) -> "NativeLayer":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
variables = data.pop("@variables")
Expand Down Expand Up @@ -404,7 +403,7 @@ def deserialize(cls, data: dict) -> "LayerNorm":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
variables = data.pop("@variables")
Expand Down Expand Up @@ -673,7 +672,7 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
layers = data.pop("layers")
Expand Down Expand Up @@ -778,7 +777,7 @@ def deserialize(cls, data: dict) -> "FittingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
layers = data.pop("layers")
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import copy
import logging
from typing import (
Callable,
Expand Down Expand Up @@ -331,7 +330,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "BaseAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
variables = data.pop("@variables", None)
variables = (
{"out_bias": None, "out_std": None} if variables is None else variables
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import functools
import logging
from typing import (
Expand Down Expand Up @@ -149,7 +148,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Callable,
Optional,
Expand Down Expand Up @@ -369,7 +368,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 2), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down Expand Up @@ -563,7 +562,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
models = [
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Callable,
Optional,
Expand Down Expand Up @@ -195,7 +194,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
tab = PairTab.deserialize(data.pop("tab"))
data.pop("@class", None)
Expand Down
15 changes: 6 additions & 9 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)
from typing import (
Optional,
)
Expand Down Expand Up @@ -40,19 +37,19 @@ def __init__(
def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"dipole": deepcopy(out_def_data["dipole"]),
"global_dipole": deepcopy(out_def_data["dipole_redu"]),
"dipole": out_def_data["dipole"],
"global_dipole": out_def_data["dipole_redu"],
}
if self.do_grad_r("dipole"):
output_def["force"] = deepcopy(out_def_data["dipole_derv_r"])
output_def["force"] = out_def_data["dipole_derv_r"]
output_def["force"].squeeze(-2)
if self.do_grad_c("dipole"):
output_def["virial"] = deepcopy(out_def_data["dipole_derv_c_redu"])
output_def["virial"] = out_def_data["dipole_derv_c_redu"]
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["dipole_derv_c"])
output_def["atom_virial"] = out_def_data["dipole_derv_c"]
output_def["atom_virial"].squeeze(-3)
if "mask" in out_def_data:
output_def["mask"] = deepcopy(out_def_data["mask"])
output_def["mask"] = out_def_data["mask"]
return output_def

def forward(
Expand Down
9 changes: 3 additions & 6 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)
from typing import (
Optional,
)
Expand Down Expand Up @@ -40,11 +37,11 @@ def __init__(
def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_dos": deepcopy(out_def_data["dos"]),
"dos": deepcopy(out_def_data["dos_redu"]),
"atom_dos": out_def_data["dos"],
"dos": out_def_data["dos_redu"],
}
if "mask" in out_def_data:
output_def["mask"] = deepcopy(out_def_data["mask"])
output_def["mask"] = out_def_data["mask"]
return output_def

def forward(
Expand Down
Loading

0 comments on commit cb3e39e

Please sign in to comment.