Merge branch '3741' of github.com:Chengqian-Zhang/deepmd-kit into 3742
Chengqian-Zhang committed May 20, 2024
2 parents 065cc2d + 8a8f027 commit 84e8698
Showing 4 changed files with 31 additions and 0 deletions.
17 changes: 17 additions & 0 deletions deepmd/main.py
@@ -756,6 +756,22 @@ def main_parser() -> argparse.ArgumentParser:
    )
    parser_convert_backend.add_argument("INPUT", help="The input model file.")
    parser_convert_backend.add_argument("OUTPUT", help="The output model file.")

    # check available model branches
    parser_list_model_branch = subparsers.add_parser(
        "list-model-branch",
        parents=[parser_log],
        help="Check the available model branches in multi-task pre-trained model",
        formatter_class=RawTextArgumentDefaultsHelpFormatter,
        epilog=textwrap.dedent(
            """\
            examples:
                dp --pt list-model-branch model.pt
            """
        ),
    )
    parser_list_model_branch.add_argument("INPUT", help="The input multi-task pre-trained model file")

    return parser


@@ -813,6 +829,7 @@ def main():
        "compress",
        "convert-from",
        "train-nvnmd",
        "list-model-branch"
    ):
        deepmd_main = BACKENDS[args.backend]().entry_point_hook
    elif args.command is None:
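
For reference, the two hunks above register a new list-model-branch subcommand and route it to the selected backend's entry point. Below is a simplified, self-contained argparse sketch of that pattern; the build_parser name and the inline dispatch comment are illustrative only, not the actual deepmd-kit API.

import argparse
import textwrap


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical, trimmed-down stand-in for deepmd.main.main_parser().
    parser = argparse.ArgumentParser(prog="dp")
    subparsers = parser.add_subparsers(dest="command")

    # Mirrors the new "list-model-branch" subparser added in deepmd/main.py.
    parser_list = subparsers.add_parser(
        "list-model-branch",
        help="Check the available model branches in a multi-task pre-trained model",
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=textwrap.dedent(
            """\
            examples:
                dp --pt list-model-branch model.pt
            """
        ),
    )
    parser_list.add_argument("INPUT", help="The input multi-task pre-trained model file")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["list-model-branch", "model.pt"])
    # In deepmd/main.py, "list-model-branch" is also added to the tuple of
    # commands that are forwarded to the backend's entry_point_hook.
    print(args.command, args.INPUT)
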
11 changes: 11 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -74,6 +74,7 @@
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

log = logging.getLogger(__name__)
from IPython import embed


def get_trainer(
@@ -336,6 +337,16 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
        FLAGS.model = FLAGS.checkpoint_folder
        FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
        freeze(FLAGS)
    elif FLAGS.command == "list-model-branch":
        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
        finetune_from_multi_task = "model_dict" in model_params
        # Pretrained model must be multitask mode
        assert finetune_from_multi_task, "When using --list-model-branch, the pretrained model must be multitask model"
        model_branch = list(model_params["model_dict"].keys())
        log.info(f"Available model branches are {model_branch}")
    else:
        raise RuntimeError(f"Invalid command {FLAGS.command}!")

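
The handler above shows where the branch names live in a saved checkpoint: state_dict["_extra_state"]["model_params"]["model_dict"], with one key per branch. A rough standalone sketch of the same inspection follows, assuming a PyTorch checkpoint laid out that way; the list_model_branches helper and the model.pt path are placeholders, not part of the commit.

import torch


def list_model_branches(checkpoint_path: str) -> list:
    """Return the branch names stored in a multi-task pre-trained checkpoint."""
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Training checkpoints wrap the weights under a "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_params = state_dict["_extra_state"]["model_params"]
    # Only multi-task models carry a "model_dict" with one sub-model per branch.
    if "model_dict" not in model_params:
        raise ValueError("Not a multi-task pre-trained model.")
    return list(model_params["model_dict"].keys())


if __name__ == "__main__":
    print(list_model_branches("model.pt"))  # placeholder checkpoint path
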
2 changes: 2 additions & 0 deletions deepmd/pt/train/training.py
@@ -83,6 +83,7 @@
)

log = logging.getLogger(__name__)
from IPython import embed


class Trainer:
@@ -518,6 +519,7 @@ def update_single_finetune_params(
                    for i in _random_state_dict.keys()
                    if i != "_extra_state" and f".{_model_key}." in i
                ]
                embed()
                for item_key in target_keys:
                    if _new_fitting and ".fitting_net." in item_key:
                        # print(f'Keep {item_key} in old model!')
1 change: 1 addition & 0 deletions deepmd/pt/utils/finetune.py
@@ -5,6 +5,7 @@
)

import torch
from IPython import embed

from deepmd.pt.utils import (
    env,
