From 6a8b1b435f71770a3dd6bf50ab4b607f33b35a6e Mon Sep 17 00:00:00 2001
From: Chengqian-Zhang <2000011006@stu.pku.edu.cn>
Date: Mon, 20 May 2024 06:12:36 +0000
Subject: [PATCH 1/2] init 3741

---
 deepmd/pt/train/training.py | 2 ++
 deepmd/pt/utils/finetune.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 4056b30d87..9056fee19a 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -83,6 +83,7 @@
 )
 
 log = logging.getLogger(__name__)
+from IPython import embed
 
 
 class Trainer:
@@ -533,6 +534,7 @@ def update_single_finetune_params(
                         _origin_state_dict[new_key].clone().detach()
                     )
 
+        embed()
         if not self.multi_task:
             model_key = "Default"
             model_key_from = self.finetune_links[model_key]
diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py
index 2de4214070..b8f089ca6a 100644
--- a/deepmd/pt/utils/finetune.py
+++ b/deepmd/pt/utils/finetune.py
@@ -5,6 +5,7 @@
 )
 
 import torch
+from IPython import embed
 
 from deepmd.pt.utils import (
     env,

From 8a8f027be774d9e9c2b0154b3584e2cd3c076634 Mon Sep 17 00:00:00 2001
From: Chengqian-Zhang <2000011006@stu.pku.edu.cn>
Date: Mon, 20 May 2024 09:03:50 +0000
Subject: [PATCH 2/2] Change list-model-branch to command

---
 deepmd/main.py                | 17 +++++++++++++++++
 deepmd/pt/entrypoints/main.py | 11 +++++++++++
 deepmd/pt/train/training.py   |  2 +-
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/deepmd/main.py b/deepmd/main.py
index 43059a41c6..e6f4de6359 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -756,6 +756,22 @@ def main_parser() -> argparse.ArgumentParser:
     )
     parser_convert_backend.add_argument("INPUT", help="The input model file.")
     parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
+
+    # check available model branches
+    parser_list_model_branch = subparsers.add_parser(
+        "list-model-branch",
+        parents=[parser_log],
+        help="Check the available model branches in a multi-task pre-trained model",
+        formatter_class=RawTextArgumentDefaultsHelpFormatter,
+        epilog=textwrap.dedent(
+            """\
+            examples:
+                dp --pt list-model-branch model.pt
+            """
+        ),
+    )
+    parser_list_model_branch.add_argument("INPUT", help="The input multi-task pre-trained model file.")
+
     return parser
 
 
@@ -813,6 +829,7 @@ def main():
         "compress",
         "convert-from",
         "train-nvnmd",
+        "list-model-branch",
     ):
         deepmd_main = BACKENDS[args.backend]().entry_point_hook
     elif args.command is None:
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 518a11987d..3013c35253 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -74,6 +74,7 @@
 from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter
 
 log = logging.getLogger(__name__)
+from IPython import embed
 
 
 def get_trainer(
@@ -336,6 +337,16 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
         FLAGS.model = FLAGS.checkpoint_folder
         FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
         freeze(FLAGS)
+    elif FLAGS.command == "list-model-branch":
+        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
+        model_params = state_dict["_extra_state"]["model_params"]
+        finetune_from_multi_task = "model_dict" in model_params
+        # the pre-trained model must be a multi-task model
+        assert finetune_from_multi_task, "When using list-model-branch, the pre-trained model must be a multi-task model"
+        model_branch = list(model_params["model_dict"].keys())
+        log.info(f"Available model branches are {model_branch}")
     else:
         raise RuntimeError(f"Invalid command {FLAGS.command}!")
 
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 9056fee19a..2ba7789821 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -519,6 +519,7 @@ def update_single_finetune_params(
                 for i in _random_state_dict.keys()
                 if i != "_extra_state" and f".{_model_key}." in i
             ]
+            embed()
             for item_key in target_keys:
                 if _new_fitting and ".fitting_net." in item_key:
                     # print(f'Keep {item_key} in old model!')
@@ -534,7 +535,6 @@ def update_single_finetune_params(
                         _origin_state_dict[new_key].clone().detach()
                     )
 
-        embed()
         if not self.multi_task:
             model_key = "Default"
             model_key_from = self.finetune_links[model_key]