Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): consistent "frozen" model #3450

Merged
merged 7 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ def do_grad_(self, var_name: str, base: str) -> bool:
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

def get_model_def_script(self) -> str:
# TODO: implement this method; saved to model
raise NotImplementedError

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")

Expand Down
7 changes: 7 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,10 @@
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""

def __init__(self) -> None:
self.model_def_script = ""

Check warning on line 177 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L177

Added line #L177 was not covered by tests

def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.model_def_script

Check warning on line 181 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L181

Added line #L181 was not covered by tests
5 changes: 1 addition & 4 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
atomic_model_: Optional[T_AtomicModel] = None,
**kwargs,
):
BaseModel.__init__(self)

Check warning on line 76 in deepmd/dpmodel/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_model.py#L76

Added line #L76 was not covered by tests
if atomic_model_ is not None:
self.atomic_model: T_AtomicModel = atomic_model_
else:
Expand Down Expand Up @@ -452,10 +453,6 @@
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nnei()

def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.atomic_model.get_model_def_script()

def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return self.atomic_model.get_sel()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@
variables.get("b", None),
variables.get("idt", None),
)
if obj.b is not None:
obj.b = obj.b.ravel()
if obj.idt is not None:
obj.idt = obj.idt.ravel()

Check warning on line 236 in deepmd/dpmodel/utils/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/network.py#L233-L236

Added lines #L233 - L236 were not covered by tests
obj.check_shape_consistency()
return obj

Expand Down
3 changes: 0 additions & 3 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def reinit_pair_exclude(
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def get_model_def_script(self) -> str:
return self.model_def_script

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
**kwargs,
):
torch.nn.Module.__init__(self)
self.model_def_script = ""
ntypes = len(type_map)
self.type_map = type_map
self.ntypes = ntypes
Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def __init__(
):
models = [dp_model, zbl_model]
super().__init__(models, type_map, **kwargs)
self.model_def_script = ""
self.dp_model = dp_model
self.zbl_model = zbl_model

Expand Down
1 change: 0 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(
**kwargs,
):
torch.nn.Module.__init__(self)
self.model_def_script = ""
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from .ener_model import (
EnergyModel,
)
from .frozen import (

Check warning on line 40 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L40

Added line #L40 was not covered by tests
FrozenModel,
)
from .make_hessian_model import (
make_hessian_model,
)
Expand Down Expand Up @@ -173,6 +176,7 @@
"get_model",
"DPModel",
"EnergyModel",
"FrozenModel",
"SpinModel",
"SpinEnergyModel",
"DPZBLModel",
Expand Down
174 changes: 174 additions & 0 deletions deepmd/pt/model/model/frozen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import tempfile
from typing import (

Check warning on line 4 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L2-L4

Added lines #L2 - L4 were not covered by tests
Dict,
List,
Optional,
)

import torch

Check warning on line 10 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L10

Added line #L10 was not covered by tests

from deepmd.dpmodel.output_def import (

Check warning on line 12 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L12

Added line #L12 was not covered by tests
FittingOutputDef,
)
from deepmd.entrypoints.convert_backend import (

Check warning on line 15 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L15

Added line #L15 was not covered by tests
convert_backend,
)
from deepmd.pt.model.model.model import (

Check warning on line 18 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L18

Added line #L18 was not covered by tests
BaseModel,
)


@BaseModel.register("frozen")
class FrozenModel(BaseModel):

Check warning on line 24 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L23-L24

Added lines #L23 - L24 were not covered by tests
"""Load model from a frozen model, which cannot be trained.

Parameters
----------
model_file : str
The path to the frozen model
"""

def __init__(self, model_file: str, **kwargs):
super().__init__(**kwargs)
self.model_file = model_file
if model_file.endswith(".pth"):
self.model = torch.jit.load(model_file)

Check warning on line 37 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L33-L37

Added lines #L33 - L37 were not covered by tests
else:
# try to convert from other formats
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
convert_backend(INPUT=model_file, OUTPUT=f.name)
self.model = torch.jit.load(f.name)

Check warning on line 42 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L40-L42

Added lines #L40 - L42 were not covered by tests

@torch.jit.export
def fitting_output_def(self) -> FittingOutputDef:

Check warning on line 45 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L44-L45

Added lines #L44 - L45 were not covered by tests
"""Get the output def of developer implemented atomic models."""
return self.model.fitting_output_def()

Check warning on line 47 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L47

Added line #L47 was not covered by tests

@torch.jit.export
def get_rcut(self) -> float:

Check warning on line 50 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L49-L50

Added lines #L49 - L50 were not covered by tests
"""Get the cut-off radius."""
return self.model.get_rcut()

Check warning on line 52 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L52

Added line #L52 was not covered by tests

@torch.jit.export
def get_type_map(self) -> List[str]:

Check warning on line 55 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L54-L55

Added lines #L54 - L55 were not covered by tests
"""Get the type map."""
return self.model.get_type_map()

Check warning on line 57 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L57

Added line #L57 was not covered by tests

@torch.jit.export
def get_sel(self) -> List[int]:

Check warning on line 60 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L59-L60

Added lines #L59 - L60 were not covered by tests
"""Returns the number of selected atoms for each type."""
return self.model.get_sel()

Check warning on line 62 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L62

Added line #L62 was not covered by tests

@torch.jit.export
def get_dim_fparam(self) -> int:

Check warning on line 65 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L64-L65

Added lines #L64 - L65 were not covered by tests
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.model.get_dim_fparam()

Check warning on line 67 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L67

Added line #L67 was not covered by tests

@torch.jit.export
def get_dim_aparam(self) -> int:

Check warning on line 70 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L69-L70

Added lines #L69 - L70 were not covered by tests
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.model.get_dim_aparam()

Check warning on line 72 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L72

Added line #L72 was not covered by tests

@torch.jit.export
def get_sel_type(self) -> List[int]:

Check warning on line 75 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L74-L75

Added lines #L74 - L75 were not covered by tests
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self.model.get_sel_type()

Check warning on line 82 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L82

Added line #L82 was not covered by tests

@torch.jit.export
def is_aparam_nall(self) -> bool:

Check warning on line 85 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L84-L85

Added lines #L84 - L85 were not covered by tests
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""
return self.model.is_aparam_nall()

Check warning on line 90 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L90

Added line #L90 was not covered by tests

@torch.jit.export
def mixed_types(self) -> bool:

Check warning on line 93 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L92-L93

Added lines #L92 - L93 were not covered by tests
"""If true, the model
1. assumes total number of atoms aligned across frames;
2. uses a neighbor list that does not distinguish different atomic types.

If false, the model
1. assumes total number of atoms of each atom type aligned across frames;
2. uses a neighbor list that distinguishes different atomic types.

"""
return self.model.mixed_types()

Check warning on line 103 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L103

Added line #L103 was not covered by tests

@torch.jit.export
def forward(

Check warning on line 106 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L105-L106

Added lines #L105 - L106 were not covered by tests
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
return self.model.forward(

Check warning on line 115 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L115

Added line #L115 was not covered by tests
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

@torch.jit.export
def get_model_def_script(self) -> str:

Check warning on line 125 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L124-L125

Added lines #L124 - L125 were not covered by tests
"""Get the model definition script."""
# try to use the original script instead of "frozen model"
# Note: this cannot change the script of the parent model
# it may still try to load hard-coded filename, which might
# be a problem
return self.model.get_model_def_script()

Check warning on line 131 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L131

Added line #L131 was not covered by tests

def serialize(self) -> dict:
from deepmd.pt.model.model import (

Check warning on line 134 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L133-L134

Added lines #L133 - L134 were not covered by tests
get_model,
)

# try to recover the original model
model_def_script = json.loads(self.get_model_def_script())
model = get_model(model_def_script)
model.load_state_dict(self.model.state_dict())
return model.serialize()

Check warning on line 142 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L139-L142

Added lines #L139 - L142 were not covered by tests

@classmethod
def deserialize(cls, data: dict):
raise RuntimeError("Should not touch here.")

Check warning on line 146 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L144-L146

Added lines #L144 - L146 were not covered by tests

@torch.jit.export
def get_nnei(self) -> int:

Check warning on line 149 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L148-L149

Added lines #L148 - L149 were not covered by tests
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.model.get_nnei()

Check warning on line 151 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L151

Added line #L151 was not covered by tests

@torch.jit.export
def get_nsel(self) -> int:

Check warning on line 154 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L153-L154

Added lines #L153 - L154 were not covered by tests
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.model.get_nsel()

Check warning on line 156 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L156

Added line #L156 was not covered by tests

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):

Check warning on line 159 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L158-L159

Added lines #L158 - L159 were not covered by tests
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
return local_jdata

Check warning on line 169 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L169

Added line #L169 was not covered by tests

@torch.jit.export
def model_output_type(self) -> str:

Check warning on line 172 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L171-L172

Added lines #L171 - L172 were not covered by tests
"""Get the output type for the model."""
return self.model.model_output_type()

Check warning on line 174 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L174

Added line #L174 was not covered by tests
5 changes: 0 additions & 5 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,11 +469,6 @@ def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return self.atomic_model.get_nnei()

@torch.jit.export
def get_model_def_script(self) -> str:
"""Get the model definition script."""
return self.atomic_model.get_model_def_script()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
def __init__(self, *args, **kwargs):
"""Construct a basic model for different tasks."""
torch.nn.Module.__init__(self)
self.model_def_script = ""

Check warning on line 20 in deepmd/pt/model/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/model.py#L20

Added line #L20 was not covered by tests

def compute_or_load_stat(
self,
Expand All @@ -39,3 +40,8 @@
The path to the statistics files.
"""
raise NotImplementedError

@torch.jit.export
def get_model_def_script(self) -> str:

Check warning on line 45 in deepmd/pt/model/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/model.py#L44-L45

Added lines #L44 - L45 were not covered by tests
"""Get the model definition script."""
return self.model_def_script

Check warning on line 47 in deepmd/pt/model/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/model.py#L47

Added line #L47 was not covered by tests
2 changes: 1 addition & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@
else:
self.atom_ener.append(None)
self.useBN = False
self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64)
self.bias_atom_e = np.zeros((self.ntypes, 1), dtype=np.float64)

Check warning on line 216 in deepmd/tf/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/fit/ener.py#L216

Added line #L216 was not covered by tests
# data requirement
if self.numb_fparam > 0:
add_data_requirement(
Expand Down
35 changes: 34 additions & 1 deletion deepmd/tf/model/frozen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile

Check warning on line 4 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L2-L4

Added lines #L2 - L4 were not covered by tests
from enum import (
Enum,
)
Expand All @@ -7,6 +10,9 @@
Union,
)

from deepmd.entrypoints.convert_backend import (

Check warning on line 13 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L13

Added line #L13 was not covered by tests
convert_backend,
)
from deepmd.infer.deep_pot import (
DeepPot,
)
Expand All @@ -24,6 +30,10 @@
from deepmd.tf.loss.loss import (
Loss,
)
from deepmd.tf.utils.graph import (

Check warning on line 33 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L33

Added line #L33 was not covered by tests
get_tensor_by_name_from_graph,
load_graph_def,
)

from .model import (
Model,
Expand All @@ -43,7 +53,14 @@
def __init__(self, model_file: str, **kwargs):
super().__init__(**kwargs)
self.model_file = model_file
self.model = DeepPotential(model_file)
if not model_file.endswith(".pb"):

Check warning on line 56 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L56

Added line #L56 was not covered by tests
# try to convert from other formats
with tempfile.NamedTemporaryFile(

Check warning on line 58 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L58

Added line #L58 was not covered by tests
suffix=".pb", dir=os.curdir, delete=False
) as f:
convert_backend(INPUT=model_file, OUTPUT=f.name)
self.model_file = f.name
self.model = DeepPotential(self.model_file)

Check warning on line 63 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L61-L63

Added lines #L61 - L63 were not covered by tests
if isinstance(self.model, DeepPot):
self.model_type = "ener"
else:
Expand Down Expand Up @@ -228,3 +245,19 @@
"""
# we don't know how to compress it, so no neighbor statistics here
return local_jdata

def serialize(self, suffix: str = "") -> dict:

Check warning on line 249 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L249

Added line #L249 was not covered by tests
# try to recover the original model
# the current graph contains a prefix "load",
# so it cannot used to recover the original model
graph, graph_def = load_graph_def(self.model_file)
t_jdata = get_tensor_by_name_from_graph(graph, "train_attr/training_script")
jdata = json.loads(t_jdata)
model = Model(**jdata["model"])

Check warning on line 256 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L253-L256

Added lines #L253 - L256 were not covered by tests
# important! must be called before serialize
model.init_variables(graph=graph, graph_def=graph_def)
return model.serialize()

Check warning on line 259 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L258-L259

Added lines #L258 - L259 were not covered by tests

@classmethod
def deserialize(cls, data: dict, suffix: str = ""):
raise RuntimeError("Should not touch here.")

Check warning on line 263 in deepmd/tf/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/frozen.py#L261-L263

Added lines #L261 - L263 were not covered by tests
3 changes: 2 additions & 1 deletion deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Model":
"""
if cls is Model:
return Model.get_class_by_type(data.get("type", "standard")).deserialize(
data
data,
suffix=suffix,
)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Expand Down
1 change: 0 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,6 @@ def frozen_model_args() -> Argument:
[
Argument("model_file", str, optional=False, doc=doc_model_file),
],
doc=doc_only_tf_supported,
)
return ca

Expand Down
Loading
Loading