Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): allow using directories to store stat #3633

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,12 @@
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
)
if not Path(stat_file_path_single).is_file():
with h5py.File(stat_file_path_single, "w") as f:
pass
if not Path(stat_file_path_single).exists():
if stat_file_path_single.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_path_single, "w") as f:
pass

Check warning on line 129 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L128-L129

Added lines #L128 - L129 were not covered by tests
else:
Path(stat_file_path_single).mkdir()
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2179,7 +2179,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
doc_stat_file = (
"The file path for saving the data statistics results. "
"If set, the results will be saved and directly loaded during the next training session, "
"avoiding the need to recalculate the statistics"
"avoiding the need to recalculate the statistics. "
"If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; "
"otherwise, a directory containing NumPy binary files is used."
)
doc_opt_type = "The type of optimizer to use."
doc_kf_blocksize = "The blocksize for the Kalman filter."
Expand Down
6 changes: 3 additions & 3 deletions deepmd/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def save_stats(self, path: DPPath) -> None:

Parameters
----------
path : DPH5Path
path : DPPath
The path to save the statistics of the environment matrix.
"""
if len(self.stats) == 0:
Expand All @@ -146,7 +146,7 @@ def load_stats(self, path: DPPath) -> None:

Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
"""
if len(self.stats) > 0:
Expand All @@ -166,7 +166,7 @@ def load_or_compute_stats(

Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
data : List[Dict[str, np.ndarray]]
The environment matrix.
Expand Down
2 changes: 1 addition & 1 deletion examples/water/dpa2/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa2",
"stat_file": "./dpa2.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_atten/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa1",
"stat_file": "./dpa1.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_e2_a/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./se_e2_a",
"stat_file": "./se_e2_a.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_saveload_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_wrapper(self, read: bool):
model_config = copy.deepcopy(self.config["model"])
model_config["resuming"] = read
model_config["stat_file_dir"] = "stat_files"
model_config["stat_file"] = "stat.npz"
model_config["stat_file"] = "stat.hdf5"
model_config["stat_file_path"] = os.path.join(
model_config["stat_file_dir"], model_config["stat_file"]
)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/water/multitask.json
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
},
"data_dict": {
"model_1": {
"stat_file": "./stat_files/model_1",
"stat_file": "./stat_files/model_1.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand All @@ -112,7 +112,7 @@
}
},
"model_2": {
"stat_file": "./stat_files/model_2",
"stat_file": "./stat_files/model_2.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand Down