Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 17, 2024
1 parent 68a3cd1 commit 7bbb2e6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def train(
# update init_model or init_frz_model config if necessary
if (init_model is not None or init_frz_model is not None) and use_pretrain_script:
if init_model is not None:
init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True)
init_state_dict = torch.load(
init_model, map_location=DEVICE, weights_only=True
)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
Expand Down Expand Up @@ -380,7 +382,9 @@ def change_bias(
output: Optional[str] = None,
):
if input_file.endswith(".pt"):
old_state_dict = torch.load(input_file, map_location=env.DEVICE, weights_only=True)
old_state_dict = torch.load(
input_file, map_location=env.DEVICE, weights_only=True
)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif input_file.endswith(".pth"):
Expand Down

0 comments on commit 7bbb2e6

Please sign in to comment.