diff --git a/deepmd/main.py b/deepmd/main.py
index 43059a41c6..e6f4de6359 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -756,6 +756,22 @@ def main_parser() -> argparse.ArgumentParser:
     )
     parser_convert_backend.add_argument("INPUT", help="The input model file.")
     parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
+
+    # check available model branches
+    parser_list_model_branch = subparsers.add_parser(
+        "list-model-branch",
+        parents=[parser_log],
+        help="Check the available model branches in multi-task pre-trained model",
+        formatter_class=RawTextArgumentDefaultsHelpFormatter,
+        epilog=textwrap.dedent(
+            """\
+            examples:
+            dp --pt list-model-branch model.pt
+            """
+        ),
+    )
+    parser_list_model_branch.add_argument("INPUT", help="The input multi-task pre-trained model file.")
+
     return parser
 
 
@@ -813,6 +829,7 @@ def main():
         "compress",
         "convert-from",
         "train-nvnmd",
+        "list-model-branch",
     ):
         deepmd_main = BACKENDS[args.backend]().entry_point_hook
     elif args.command is None:
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index faa606e986..b47f07f847 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -336,6 +336,16 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
         FLAGS.model = FLAGS.checkpoint_folder
         FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
         freeze(FLAGS)
+    elif FLAGS.command == "list-model-branch":
+        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
+        if "model" in state_dict:
+            state_dict = state_dict["model"]
+        model_params = state_dict["_extra_state"]["model_params"]
+        finetune_from_multi_task = "model_dict" in model_params
+        # Pretrained model must be multitask mode
+        assert finetune_from_multi_task, "When using --list-model-branch, the pretrained model must be multitask model"
+        model_branch = list(model_params["model_dict"].keys())
+        log.info(f"Available model branches are {model_branch}")
     else:
         raise RuntimeError(f"Invalid command {FLAGS.command}!")