feat(pt): consistent "frozen" model (#3450)
This PR is based on #3449, as its new test needs #3449 to pass. It adds a consistent `frozen` model to the PT backend; both TF and PT now support using models in any format.

Signed-off-by: Jinzhe Zeng <[email protected]>
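As a rough illustration of what "consistent" means here, below is a hedged sketch of selecting the new model type from a model-definition dict. The keys follow the `@BaseModel.register("frozen")` decorator and the `model_file` constructor argument in the new file further down; whether `get_model` dispatches on this `type` key is an assumption inferred from how `FrozenModel.serialize` rebuilds the original model, not something stated in this commit message.

```python
# A minimal sketch, not taken from this PR's documentation.
from deepmd.pt.model.model import get_model

model_def = {
    "type": "frozen",                  # key registered by @BaseModel.register("frozen")
    "model_file": "frozen_model.pth",  # hypothetical path; non-.pth formats are converted
}
model = get_model(model_def)  # assumed to resolve the registered "frozen" class
```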
Showing 9 changed files with 387 additions and 5 deletions.
New file (`@@ -0,0 +1,174 @@`): the PT `FrozenModel` wrapper.

```python
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import tempfile
from typing import (
    Dict,
    List,
    Optional,
)

import torch

from deepmd.dpmodel.output_def import (
    FittingOutputDef,
)
from deepmd.entrypoints.convert_backend import (
    convert_backend,
)
from deepmd.pt.model.model.model import (
    BaseModel,
)


@BaseModel.register("frozen")
class FrozenModel(BaseModel):
    """Load model from a frozen model, which cannot be trained.

    Parameters
    ----------
    model_file : str
        The path to the frozen model
    """

    def __init__(self, model_file: str, **kwargs):
        super().__init__(**kwargs)
        self.model_file = model_file
        if model_file.endswith(".pth"):
            self.model = torch.jit.load(model_file)
        else:
            # try to convert from other formats
            with tempfile.NamedTemporaryFile(suffix=".pth") as f:
                convert_backend(INPUT=model_file, OUTPUT=f.name)
                self.model = torch.jit.load(f.name)

    @torch.jit.export
    def fitting_output_def(self) -> FittingOutputDef:
        """Get the output def of developer implemented atomic models."""
        return self.model.fitting_output_def()

    @torch.jit.export
    def get_rcut(self) -> float:
        """Get the cut-off radius."""
        return self.model.get_rcut()

    @torch.jit.export
    def get_type_map(self) -> List[str]:
        """Get the type map."""
        return self.model.get_type_map()

    @torch.jit.export
    def get_sel(self) -> List[int]:
        """Returns the number of selected atoms for each type."""
        return self.model.get_sel()

    @torch.jit.export
    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this atomic model."""
        return self.model.get_dim_fparam()

    @torch.jit.export
    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this atomic model."""
        return self.model.get_dim_aparam()

    @torch.jit.export
    def get_sel_type(self) -> List[int]:
        """Get the selected atom types of this model.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        If returning an empty list, all atom types are selected.
        """
        return self.model.get_sel_type()

    @torch.jit.export
    def is_aparam_nall(self) -> bool:
        """Check whether the shape of atomic parameters is (nframes, nall, ndim).

        If False, the shape is (nframes, nloc, ndim).
        """
        return self.model.is_aparam_nall()

    @torch.jit.export
    def mixed_types(self) -> bool:
        """If true, the model
        1. assumes total number of atoms aligned across frames;
        2. uses a neighbor list that does not distinguish different atomic types.

        If false, the model
        1. assumes total number of atoms of each atom type aligned across frames;
        2. uses a neighbor list that distinguishes different atomic types.
        """
        return self.model.mixed_types()

    @torch.jit.export
    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        return self.model.forward(
            coord,
            atype,
            box=box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )

    @torch.jit.export
    def get_model_def_script(self) -> str:
        """Get the model definition script."""
        # try to use the original script instead of "frozen model"
        # Note: this cannot change the script of the parent model
        # it may still try to load hard-coded filename, which might
        # be a problem
        return self.model.get_model_def_script()

    def serialize(self) -> dict:
        from deepmd.pt.model.model import (
            get_model,
        )

        # try to recover the original model
        model_def_script = json.loads(self.get_model_def_script())
        model = get_model(model_def_script)
        model.load_state_dict(self.model.state_dict())
        return model.serialize()

    @classmethod
    def deserialize(cls, data: dict):
        raise RuntimeError("Should not touch here.")

    @torch.jit.export
    def get_nnei(self) -> int:
        """Returns the total number of selected neighboring atoms in the cut-off radius."""
        return self.model.get_nnei()

    @torch.jit.export
    def get_nsel(self) -> int:
        """Returns the total number of selected neighboring atoms in the cut-off radius."""
        return self.model.get_nsel()

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class
        """
        return local_jdata

    @torch.jit.export
    def model_output_type(self) -> str:
        """Get the output type for the model."""
        return self.model.model_output_type()
```
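For illustration, a hedged usage sketch of the wrapper above. The import path and the `model.pb` filename are assumptions (the path of the new module is not shown on this page); the conversion fallback in `__init__` is what makes a TF-format file acceptable here.

```python
# Module path assumed; it is not visible in this extraction of the diff.
from deepmd.pt.model.model.frozen import FrozenModel

# "model.pb" is a hypothetical TF frozen model. Because it does not end with
# ".pth", __init__ converts it through convert_backend into a temporary .pth
# file and loads the result with torch.jit.load.
model = FrozenModel(model_file="model.pb")

# Metadata queries simply delegate to the wrapped TorchScript model.
print(model.get_type_map())
print(model.get_rcut())
print(model.get_nsel())
```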