Skip to content

Commit

Permalink
fix(pt): make state_dict safe for weights_only (#4148)
Browse files Browse the repository at this point in the history
See #4147 and #4143.
We can first make the saved `state_dict` safe to load with `weights_only=True`, then make a
breaking change to how `state_dict` is loaded in the future.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Enhanced model saving functionality by ensuring learning rates are
    consistently stored as floats, improving type consistency.

- **Bug Fixes**
  - Updated model loading behavior in tests to focus solely on model
    weights, which may resolve issues related to state dictionary loading.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Sep 21, 2024
1 parent d224fdd commit 532e309
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
7 changes: 5 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,10 +1030,13 @@ def save_model(self, save_path, lr=0.0, step=0):
if dist.is_available() and dist.is_initialized()
else self.wrapper
)
module.train_infos["lr"] = lr
module.train_infos["lr"] = float(lr)
module.train_infos["step"] = step
optim_state_dict = deepcopy(self.optimizer.state_dict())
for item in optim_state_dict["param_groups"]:
item["lr"] = float(item["lr"])
torch.save(
{"model": module.state_dict(), "optimizer": self.optimizer.state_dict()},
{"model": module.state_dict(), "optimizer": optim_state_dict},
save_path,
)
checkpoint_dir = save_path.parent
Expand Down
10 changes: 7 additions & 3 deletions source/tests/pt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def test_change_bias_with_data(self):
run_dp(
f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}"
)
state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE)
state_dict = torch.load(
str(self.model_path_data_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
wrapper = ModelWrapper(model_for_wrapper)
Expand All @@ -114,7 +116,7 @@ def test_change_bias_with_data_sys_file(self):
f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}"
)
state_dict = torch.load(
str(self.model_path_data_file_bias), map_location=DEVICE
str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
Expand All @@ -134,7 +136,9 @@ def test_change_bias_with_user_defined(self):
run_dp(
f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}"
)
state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)
state_dict = torch.load(
str(self.model_path_user_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
wrapper = ModelWrapper(model_for_wrapper)
Expand Down

0 comments on commit 532e309

Please sign in to comment.