Skip to content

Commit

Permalink
chore: use pickle instead of copy.deepcopy
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 7, 2024
1 parent 3236db5 commit dae0a3e
Show file tree
Hide file tree
Showing 80 changed files with 409 additions and 287 deletions.
6 changes: 4 additions & 2 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
)
Expand All @@ -15,6 +14,9 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -208,7 +210,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
Expand Down
8 changes: 5 additions & 3 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand All @@ -16,6 +15,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -305,7 +307,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 2), 2, 2)
data.pop("@class", None)
data.pop("type", None)
Expand Down Expand Up @@ -418,7 +420,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 2)
models = [
BaseAtomicModel.get_class_by_type(model["type"]).deserialize(model)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand All @@ -18,6 +17,9 @@
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_sqrt,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -174,7 +176,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 2)
data.pop("@class")
data.pop("type")
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import itertools
from typing import (
Any,
Expand Down Expand Up @@ -33,6 +32,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -460,7 +462,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -30,6 +29,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -388,7 +390,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeR":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import itertools
from typing import (
Optional,
Expand Down Expand Up @@ -30,6 +29,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -388,7 +390,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeT":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand All @@ -20,6 +19,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -156,7 +158,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
Optional,
Expand All @@ -15,6 +14,9 @@
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)
from deepmd.utils.copy import (
deepcopy,
)

if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
Expand Down
7 changes: 5 additions & 2 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -13,11 +12,15 @@
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)
from deepmd.utils.copy import (
deepcopy,
)

if TYPE_CHECKING:
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)

from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -73,7 +76,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("var_name")
data.pop("dim_out")
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from abc import (
abstractmethod,
)
Expand Down Expand Up @@ -32,6 +31,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
map_atom_exclude_types,
Expand Down Expand Up @@ -320,7 +322,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand All @@ -16,6 +15,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -179,7 +181,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
return super().deserialize(data)

Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Optional,
Expand All @@ -26,6 +25,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
)
Expand Down Expand Up @@ -197,7 +199,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 3, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Optional,
Union,
Expand All @@ -13,6 +12,9 @@
from deepmd.dpmodel.fitting.invar_fitting import (
InvarFitting,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -111,7 +113,7 @@ def __init__(

@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("dim_out")
data.pop("var_name")
Expand Down
12 changes: 7 additions & 5 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
See issue #2982 for more information.
"""

import copy
import itertools
from typing import (
Callable,
Expand All @@ -30,6 +29,9 @@
from deepmd.dpmodel.utils.seed import (
child_seed,
)
from deepmd.utils.copy import (
deepcopy,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -135,7 +137,7 @@ def deserialize(cls, data: dict) -> "NativeLayer":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
variables = data.pop("@variables")
Expand Down Expand Up @@ -404,7 +406,7 @@ def deserialize(cls, data: dict) -> "LayerNorm":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
variables = data.pop("@variables")
Expand Down Expand Up @@ -673,7 +675,7 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
layers = data.pop("layers")
Expand Down Expand Up @@ -778,7 +780,7 @@ def deserialize(cls, data: dict) -> "FittingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
data = deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class", None)
layers = data.pop("layers")
Expand Down
Loading

0 comments on commit dae0a3e

Please sign in to comment.