Merge branch 'devel' into data_load

iProzd authored Feb 11, 2024
2 parents 0bcd55c + c131c8f commit 5e8734c
Showing 15 changed files with 1,157 additions and 1,119 deletions.
1,056 changes: 1,056 additions & 0 deletions deepmd/entrypoints/test.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion deepmd/main.py
@@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser:
         "--checkpoint",
         type=str,
         default=".",
-        help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file",
+        help="Path to checkpoint, either a folder containing checkpoint or the checkpoint prefix",
     )
     parser_frz.add_argument(
         "-o",

21 changes: 5 additions & 16 deletions deepmd/pt/entrypoints/main.py
@@ -28,6 +28,9 @@
 from deepmd.entrypoints.gui import (
     start_dpgui,
 )
+from deepmd.entrypoints.test import (
+    test,
+)
 from deepmd.infer.model_devi import (
     make_model_devi,
 )
@@ -269,20 +272,6 @@ def train(FLAGS):
     trainer.run()
 
 
-def test(FLAGS):
-    trainer = inference.Tester(
-        FLAGS.model,
-        input_script=FLAGS.input_script,
-        system=FLAGS.system,
-        datafile=FLAGS.datafile,
-        numb_test=FLAGS.numb_test,
-        detail_file=FLAGS.detail_file,
-        shuffle_test=FLAGS.shuffle_test,
-        head=FLAGS.head,
-    )
-    trainer.run()
-
-
 def freeze(FLAGS):
     model = torch.jit.script(
         inference.Tester(FLAGS.model, numb_test=1, head=FLAGS.head).model
@@ -312,8 +301,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
     if FLAGS.command == "train":
         train(FLAGS)
     elif FLAGS.command == "test":
-        FLAGS.output = str(Path(FLAGS.model).with_suffix(".pt"))
-        test(FLAGS)
+        dict_args["output"] = str(Path(FLAGS.model).with_suffix(".pt"))
+        test(**dict_args)
     elif FLAGS.command == "freeze":
         if Path(FLAGS.checkpoint_folder).is_dir():
             checkpoint_path = Path(FLAGS.checkpoint_folder)

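The test subcommand now dispatches to the shared deepmd.entrypoints.test.test entry point with keyword arguments instead of the removed PyTorch-local test(FLAGS) helper. A minimal sketch of the namespace-to-kwargs pattern, assuming dict_args is built from the parsed arguments (e.g. via vars(FLAGS)); the parser flags and the test stub below are illustrative, not the real deepmd CLI:

import argparse
from pathlib import Path


def test(*, model: str, output: str, **kwargs) -> None:
    # Stand-in for deepmd.entrypoints.test.test, which accepts keyword arguments.
    print(f"testing {model}, writing details to {output}")


parser = argparse.ArgumentParser()
parser.add_argument("command")
parser.add_argument("--model", default="frozen_model.pt")
FLAGS = parser.parse_args(["test", "--model", "model.ckpt"])

# Convert the argparse namespace to a plain dict, patch one entry, and dispatch.
dict_args = vars(FLAGS)
dict_args["output"] = str(Path(FLAGS.model).with_suffix(".pt"))  # "model.pt"
test(**dict_args)
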
33 changes: 30 additions & 3 deletions deepmd/pt/infer/deep_eval.py
@@ -18,12 +18,26 @@
     OutputVariableCategory,
     OutputVariableDef,
 )
+from deepmd.infer.deep_dipole import (
+    DeepDipole,
+)
+from deepmd.infer.deep_dos import (
+    DeepDOS,
+)
+from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper
 from deepmd.infer.deep_eval import (
     DeepEvalBackend,
 )
+from deepmd.infer.deep_polar import (
+    DeepGlobalPolar,
+    DeepPolar,
+)
 from deepmd.infer.deep_pot import (
     DeepPot,
 )
+from deepmd.infer.deep_wfc import (
+    DeepWFC,
+)
 from deepmd.pt.model.model import (
     get_model,
 )
@@ -44,8 +58,6 @@
 if TYPE_CHECKING:
     import ase.neighborlist
 
-    from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper
-
 
 class DeepEval(DeepEvalBackend):
     """PyTorch backend implementaion of DeepEval.
@@ -127,7 +139,22 @@ def get_dim_aparam(self) -> int:
     @property
     def model_type(self) -> "DeepEvalWrapper":
         """The the evaluator of the model type."""
-        return DeepPot
+        output_def = self.dp.model["Default"].model_output_def()
+        var_defs = output_def.var_defs
+        if "energy" in var_defs:
+            return DeepPot
+        elif "dos" in var_defs:
+            return DeepDOS
+        elif "dipole" in var_defs:
+            return DeepDipole
+        elif "polar" in var_defs:
+            return DeepPolar
+        elif "global_polar" in var_defs:
+            return DeepGlobalPolar
+        elif "wfc" in var_defs:
+            return DeepWFC
+        else:
+            raise RuntimeError("Unknown model type")
 
     def get_sel_type(self) -> List[int]:
         """Get the selected atom types of this model.

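The new model_type property inspects the model's output variable definitions and picks the matching evaluator class. A minimal, self-contained sketch of the same name-to-class dispatch written as a lookup table; the class names follow the diff, but the stub classes and the evaluator_for helper are hypothetical and not part of deepmd:

from typing import Dict, Type


# Stand-ins for the evaluator classes imported from deepmd.infer in the diff above.
class DeepPot: ...
class DeepDOS: ...
class DeepDipole: ...
class DeepPolar: ...
class DeepGlobalPolar: ...
class DeepWFC: ...


# Insertion order reproduces the priority of the if/elif chain.
_EVALUATORS: Dict[str, Type] = {
    "energy": DeepPot,
    "dos": DeepDOS,
    "dipole": DeepDipole,
    "polar": DeepPolar,
    "global_polar": DeepGlobalPolar,
    "wfc": DeepWFC,
}


def evaluator_for(var_defs) -> Type:
    # Return the evaluator class for the first recognized output variable name.
    for name, cls in _EVALUATORS.items():
        if name in var_defs:
            return cls
    raise RuntimeError("Unknown model type")


print(evaluator_for({"energy": None}))  # <class '__main__.DeepPot'>
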
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -43,6 +43,8 @@
 
+log = logging.getLogger(__name__)
+
 log = logging.getLogger(__name__)
 
 
 @Descriptor.register("se_e2_a")
 class DescrptSeA(Descriptor):

Check warning (Code scanning / CodeQL): Variable defined multiple times. This assignment to 'log' is unnecessary as it is redefined before this value is used.

1 change: 1 addition & 0 deletions deepmd/pt/model/model/make_model.py
@@ -53,6 +53,7 @@ def __init__(
             **kwargs,
         )
 
+    @torch.jit.export
     def model_output_def(self):
         """Get the output def for the model."""
         return ModelOutputDef(self.fitting_output_def())

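Exporting model_output_def matters because torch.jit.script only compiles forward and the methods reachable from it; other methods are omitted from the scripted module unless marked with @torch.jit.export. A small illustrative sketch, assuming PyTorch is installed (the Toy module is hypothetical):

import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.export
    def model_output_def(self) -> str:
        # Without @torch.jit.export this method would not be compiled,
        # so it could not be called on the scripted (frozen) module.
        return "energy"


scripted = torch.jit.script(Toy())
print(scripted.model_output_def())  # energy
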
12 changes: 6 additions & 6 deletions deepmd/pt/model/model/pairtab_atomic_model.py
@@ -62,11 +62,11 @@ def __init__(
                 tab_info,
                 tab_data,
             ) = self.tab.get()  # this returns -> Tuple[np.array, np.array]
-            self.tab_info = torch.from_numpy(tab_info)
-            self.tab_data = torch.from_numpy(tab_data)
+            self.register_buffer("tab_info", torch.from_numpy(tab_info))
+            self.register_buffer("tab_data", torch.from_numpy(tab_data))
         else:
-            self.tab_info = None
-            self.tab_data = None
+            self.register_buffer("tab_info", None)
+            self.register_buffer("tab_data", None)
 
         # self.model_type = "ener"
         # self.model_version = MODEL_VERSION ## this shoud be in the parent class
@@ -118,8 +118,8 @@ def deserialize(cls, data) -> "PairTabModel":
         tab = PairTab.deserialize(data["tab"])
         tab_model = cls(None, rcut, sel)
         tab_model.tab = tab
-        tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info)
-        tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data)
+        tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info))
+        tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data))
         return tab_model
 
     def forward_atomic(

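Storing the tables with register_buffer instead of plain attribute assignment means they follow the module through .to()/.cuda() and are saved and restored with state_dict(), while still not being trainable parameters. A small sketch of that difference, assuming PyTorch is installed (the TableHolder module is hypothetical):

import torch


class TableHolder(torch.nn.Module):
    def __init__(self, tab):
        super().__init__()
        # Buffer: moved by .to()/.cuda(), included in state_dict(), not trained.
        self.register_buffer("tab_data", torch.as_tensor(tab))
        # Plain attribute: ignored by .to() and state_dict().
        self.plain_copy = torch.as_tensor(tab)


m = TableHolder([1.0, 2.0, 3.0])
print("tab_data" in m.state_dict())    # True
print("plain_copy" in m.state_dict())  # False
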
12 changes: 9 additions & 3 deletions deepmd/tf/entrypoints/freeze.py
@@ -12,6 +12,9 @@
 from os.path import (
     abspath,
 )
+from pathlib import (
+    Path,
+)
 from typing import (
     List,
     Optional,
@@ -479,7 +482,7 @@ def freeze(
     Parameters
     ----------
     checkpoint_folder : str
-        location of the folder with model
+        location of either the folder with checkpoint or the checkpoint prefix
     output : str
         output file name
     node_names : Optional[str], optional
@@ -492,8 +495,11 @@
         other arguments
     """
     # We retrieve our checkpoint fullpath
-    checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
-    input_checkpoint = checkpoint.model_checkpoint_path
+    if Path(checkpoint_folder).is_dir():
+        checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
+        input_checkpoint = checkpoint.model_checkpoint_path
+    else:
+        input_checkpoint = checkpoint_folder
 
     # expand the output file to full path
     output_graph = abspath(output)

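With this change the TensorFlow freeze entry point accepts either a checkpoint directory or a checkpoint prefix. A minimal sketch of the two resolution paths, assuming TensorFlow is installed; the resolve_checkpoint helper and the example paths are illustrative, not the repository's code:

from pathlib import Path

import tensorflow as tf


def resolve_checkpoint(checkpoint_folder: str) -> str:
    if Path(checkpoint_folder).is_dir():
        # A directory: read its "checkpoint" file and use the latest
        # checkpoint prefix recorded there, e.g. ".../model.ckpt-10000".
        state = tf.train.get_checkpoint_state(checkpoint_folder)
        return state.model_checkpoint_path
    # Anything else is treated as the checkpoint prefix itself.
    return checkpoint_folder
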

0 comments on commit 5e8734c
