Commit
Merge branch 'devel' into rf_finetune
iProzd authored May 30, 2024
2 parents bc8bdf8 + dd7f27a commit 297b5d6
Showing 6 changed files with 324 additions and 5 deletions.
24 changes: 24 additions & 0 deletions deepmd/main.py
@@ -750,6 +750,29 @@ def main_parser() -> argparse.ArgumentParser:
)
parser_convert_backend.add_argument("INPUT", help="The input model file.")
parser_convert_backend.add_argument("OUTPUT", help="The output model file.")

# * show model ******************************************************************
parser_show = subparsers.add_parser(
"show",
parents=[parser_log],
help="(Supported backend: PyTorch) Show the information of a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp --pt show model.pt model-branch type-map descriptor fitting-net
dp --pt show frozen_model.pth type-map descriptor fitting-net
"""
),
)
parser_show.add_argument(
"INPUT", help="The input checkpoint file or frozen model file"
)
parser_show.add_argument(
"ATTRIBUTES",
choices=["model-branch", "type-map", "descriptor", "fitting-net"],
nargs="+",
)
return parser
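
For orientation, a hedged sketch of what the new subcommand contributes to the parsed namespace, based only on the arguments defined above (values are illustrative; the global backend and logging options are omitted):

```python
# Illustrative only: the attributes that `dp --pt show model.pt type-map descriptor`
# adds to the parsed namespace, per the subparser definition above.
from argparse import Namespace

FLAGS = Namespace(
    command="show",                         # subcommand name, used for dispatch
    INPUT="model.pt",                       # checkpoint (.pt) or frozen model (.pth)
    ATTRIBUTES=["type-map", "descriptor"],  # any subset of the four allowed choices
)
```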


@@ -807,6 +830,7 @@ def main():
"compress",
"convert-from",
"train-nvnmd",
"show",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
66 changes: 66 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -44,6 +44,9 @@
from deepmd.pt.train import (
training,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)
@@ -282,6 +285,67 @@ def freeze(FLAGS):
)


def show(FLAGS):
    if FLAGS.INPUT.split(".")[-1] == "pt":
        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
    elif FLAGS.INPUT.split(".")[-1] == "pth":
        model_params_string = torch.jit.load(
            FLAGS.INPUT, map_location=env.DEVICE
        ).model_def_script
        model_params = json.loads(model_params_string)
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )
    model_is_multi_task = "model_dict" in model_params
    if model_is_multi_task:
        log.info("This is a multitask model")
    else:
        log.info("This is a singletask model")

    if "model-branch" in FLAGS.ATTRIBUTES:
        # The model must be a multitask model
        if not model_is_multi_task:
            raise RuntimeError(
                "The 'model-branch' option requires a multitask model."
                " The provided model does not meet this criterion."
            )
        model_branches = list(model_params["model_dict"].keys())
        log.info(f"Available model branches are {model_branches}")
    if "type-map" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                type_map = model_params["model_dict"][branch]["type_map"]
                log.info(f"The type_map of branch {branch} is {type_map}")
        else:
            type_map = model_params["type_map"]
            log.info(f"The type_map is {type_map}")
    if "descriptor" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                descriptor = model_params["model_dict"][branch]["descriptor"]
                log.info(
                    f"The descriptor parameter of branch {branch} is {descriptor}"
                )
        else:
            descriptor = model_params["descriptor"]
            log.info(f"The descriptor parameter is {descriptor}")
    if "fitting-net" in FLAGS.ATTRIBUTES:
        if model_is_multi_task:
            model_branches = list(model_params["model_dict"].keys())
            for branch in model_branches:
                fitting_net = model_params["model_dict"][branch]["fitting_net"]
                log.info(
                    f"The fitting_net parameter of branch {branch} is {fitting_net}"
                )
        else:
            fitting_net = model_params["fitting_net"]
            log.info(f"The fitting_net parameter is {fitting_net}")


@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
@@ -304,6 +368,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

9 changes: 5 additions & 4 deletions doc/train/finetuning.md
@@ -101,11 +101,12 @@ $ dp --pt train input.json --finetune multitask_pretrained.pt --model-branch CHO
```

:::{note}
To check the available model branches, you can typically refer to the documentation of the pre-trained model.
If you're still unsure about the available branches, you can try inputting an arbitrary branch name.
This will prompt an error message that displays a list of all the available model branches.
One can check the available model branches in a multi-task pre-trained model by referring to the documentation of the pre-trained model or by using the following command:

```bash
$ dp --pt show multitask_pretrained.pt model-branch
```

Please note that this feature will be improved in the upcoming version to provide a more user-friendly experience.
:::

This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -98,7 +98,7 @@ lmp = [
"lammps~=2023.8.2.3.0",
]
ipi = [
"i-PI",
"ipi",
]
gui = [
"dpgui",
29 changes: 29 additions & 0 deletions source/tests/pt/common.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.entrypoints.main import (
main,
)


def run_dp(cmd: str) -> int:
    """Run DP directly from the entry point instead of the subprocess.

    It is quite slow to start DeePMD-kit via a subprocess.

    Parameters
    ----------
    cmd : str
        The command to run.

    Returns
    -------
    int
        Always returns 0.
    """
    cmds = cmd.split()
    if cmds[0] == "dp":
        cmds = cmds[1:]
    else:
        raise RuntimeError("The command is not dp")

    main(cmds)
    return 0
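
A brief usage sketch of this helper (the tests in the next file call it in exactly this way); note that `cmd.split()` is whitespace-based, so paths containing spaces are not supported:

```python
# Hedged sketch: drive the new `show` subcommand in-process instead of spawning `dp`.
# run_dp is the helper above; "model.pt" is a placeholder for an existing checkpoint.
from .common import run_dp  # relative import, as used by the tests in this package

assert run_dp("dp --pt show model.pt type-map descriptor") == 0
```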
199 changes: 199 additions & 0 deletions source/tests/pt/test_dp_show.py
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import io
import json
import os
import shutil
import unittest
from contextlib import (
redirect_stderr,
)
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

from deepmd.pt.entrypoints.main import (
get_trainer,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)

from .common import (
run_dp,
)
from .model.test_permutation import (
model_se_e2_a,
)


class TestSingleTaskModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
data_file = [str(Path(__file__).parent / "water/data/single")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_se_e2_a)
self.config["model"]["type_map"] = ["O", "H", "Au"]
trainer = get_trainer(deepcopy(self.config))
trainer.run()
run_dp("dp --pt freeze")

def test_checkpoint(self):
INPUT = "model.pt"
ATTRIBUTES = "type-map descriptor fitting-net"
with redirect_stderr(io.StringIO()) as f:
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
results = f.getvalue().split("\n")[:-1]
assert "This is a singletask model" in results[-4]
assert "The type_map is ['O', 'H', 'Au']" in results[-3]
        assert all(
            key in results[-2]
            for key in ("{'type': 'se_e2_a'", "'sel': [46, 92, 4]", "'rcut': 4.0")
        )
assert (
"The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
in results[-1]
)

def test_frozen_model(self):
INPUT = "frozen_model.pth"
ATTRIBUTES = "type-map descriptor fitting-net"
with redirect_stderr(io.StringIO()) as f:
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
results = f.getvalue().split("\n")[:-1]
assert "This is a singletask model" in results[-4]
assert "The type_map is ['O', 'H', 'Au']" in results[-3]
        assert all(
            key in results[-2]
            for key in ("{'type': 'se_e2_a'", "'sel': [46, 92, 4]", "'rcut': 4.0")
        )
assert (
"The fitting_net parameter is {'neuron': [24, 24, 24], 'resnet_dt': True, 'seed': 1}"
in results[-1]
)

def test_checkpoint_error(self):
INPUT = "model.pt"
ATTRIBUTES = "model-branch type-map descriptor fitting-net"
with self.assertRaisesRegex(
RuntimeError, "The 'model-branch' option requires a multitask model"
):
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith("pt"):
os.remove(f)
if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)


class TestMultiTaskModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/multitask.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[
"descriptor"
]
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.stat_files = "se_e2_a"
os.makedirs(self.stat_files, exist_ok=True)
self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = (
data_file
)
self.config["training"]["data_dict"]["model_1"]["validation_data"][
"systems"
] = data_file
self.config["training"]["data_dict"]["model_1"]["stat_file"] = (
f"{self.stat_files}/model_1"
)
self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = (
data_file
)
self.config["training"]["data_dict"]["model_2"]["validation_data"][
"systems"
] = data_file
self.config["training"]["data_dict"]["model_2"]["stat_file"] = (
f"{self.stat_files}/model_2"
)
self.config["model"]["model_dict"]["model_1"]["fitting_net"] = {
"neuron": [1, 2, 3],
"seed": 678,
}
self.config["model"]["model_dict"]["model_2"]["fitting_net"] = {
"neuron": [9, 8, 7],
"seed": 1111,
}
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.origin_config = deepcopy(self.config)
self.config["model"], self.shared_links = preprocess_shared_params(
self.config["model"]
)
trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
trainer.run()
run_dp("dp --pt freeze --head model_1")

def test_checkpoint(self):
INPUT = "model.ckpt.pt"
ATTRIBUTES = "model-branch type-map descriptor fitting-net"
with redirect_stderr(io.StringIO()) as f:
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
results = f.getvalue().split("\n")[:-1]
assert "This is a multitask model" in results[-8]
assert "Available model branches are ['model_1', 'model_2']" in results[-7]
assert "The type_map of branch model_1 is ['O', 'H', 'B']" in results[-6]
assert "The type_map of branch model_2 is ['O', 'H', 'B']" in results[-5]
        assert all(
            key in results[-4]
            for key in (
                "model_1",
                "'type': 'se_e2_a'",
                "'sel': [46, 92, 4]",
                "'rcut_smth': 0.5",
            )
        )
        assert all(
            key in results[-3]
            for key in (
                "model_2",
                "'type': 'se_e2_a'",
                "'sel': [46, 92, 4]",
                "'rcut_smth': 0.5",
            )
        )
assert (
"The fitting_net parameter of branch model_1 is {'neuron': [1, 2, 3], 'seed': 678}"
in results[-2]
)
assert (
"The fitting_net parameter of branch model_2 is {'neuron': [9, 8, 7], 'seed': 1111}"
in results[-1]
)

def test_frozen_model(self):
INPUT = "frozen_model.pth"
ATTRIBUTES = "type-map descriptor fitting-net"
with redirect_stderr(io.StringIO()) as f:
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
results = f.getvalue().split("\n")[:-1]
assert "This is a singletask model" in results[-4]
assert "The type_map is ['O', 'H', 'B']" in results[-3]
        assert all(
            key in results[-2]
            for key in ("'type': 'se_e2_a'", "'sel': [46, 92, 4]", "'rcut_smth': 0.5")
        )
assert (
"The fitting_net parameter is {'neuron': [1, 2, 3], 'seed': 678}"
in results[-1]
)

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith("pt"):
os.remove(f)
if f in ["lcurve.out", "frozen_model.pth", "checkpoint", "output.txt"]:
os.remove(f)
if f in ["stat_files", self.stat_files]:
shutil.rmtree(f)
