Skip to content

Commit

Permalink
pt: support --init-frz-model
Browse files: browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 28, 2024
1 parent b1de9e6 commit 2166d11
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def main_parser() -> argparse.ArgumentParser:
"--init-frz-model",
type=str,
default=None,
help="(Supported backend: TensorFlow) Initialize the training from the frozen model.",
help="Initialize the training from the frozen model.",
)
parser_train_subgroup.add_argument(
"-t",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_trainer(
finetune_model=None,
model_branch="",
force_load=False,
init_frz_model=None,
):
# Initialize DDP
local_rank = os.environ.get("LOCAL_RANK")
Expand Down Expand Up @@ -200,6 +201,7 @@ def prepare_trainer_input_single(
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
init_frz_model=init_frz_model,
)
return trainer

Expand Down Expand Up @@ -243,6 +245,7 @@ def train(FLAGS):
FLAGS.finetune,
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
)
trainer.run()

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
finetune_model=None,
force_load=False,
shared_links=None,
init_frz_model=None,
):
"""Construct a DeePMD trainer.
Expand Down Expand Up @@ -394,6 +395,9 @@ def get_loss(loss_params, start_lr, _ntypes):
ntest=ntest,
bias_shift=model_params.get("bias_shift", "delta"),
)
if init_frz_model is not None:
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
self.model.load_state_dict(frz_model.state_dict())

Check warning on line 400 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L399-L400

Added lines #L399-L400 were not covered by tests.

# Set trainable params
self.wrapper.set_trainable_params()
Expand Down

0 comments on commit 2166d11

Please sign in to comment.