Skip to content

Commit

Permalink
feat(pt): allow using directories to store stat (#3633)
Browse files Browse the repository at this point in the history
If the stat file ends with .h5 or .hdf5, an HDF5 file is used;
otherwise, a directory containing NumPy binary files is used. This gives
flexibility to users.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Apr 2, 2024
1 parent 7d67e0d commit e752e27
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 17 deletions.
13 changes: 6 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,12 @@ def prepare_trainer_input_single(
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
)
if not Path(stat_file_path_single).is_file():
with h5py.File(stat_file_path_single, "w") as f:
pass
if not Path(stat_file_path_single).exists():
if stat_file_path_single.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_path_single, "w") as f:
pass
else:
Path(stat_file_path_single).mkdir()
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2179,7 +2179,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
doc_stat_file = (
"The file path for saving the data statistics results. "
"If set, the results will be saved and directly loaded during the next training session, "
"avoiding the need to recalculate the statistics"
"avoiding the need to recalculate the statistics. "
"If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; "
"otherwise, a directory containing NumPy binary files are used."
)
doc_opt_type = "The type of optimizer to use."
doc_kf_blocksize = "The blocksize for the Kalman filter."
Expand Down
6 changes: 3 additions & 3 deletions deepmd/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def save_stats(self, path: DPPath) -> None:
Parameters
----------
path : DPH5Path
path : DPPath
The path to save the statistics of the environment matrix.
"""
if len(self.stats) == 0:
Expand All @@ -146,7 +146,7 @@ def load_stats(self, path: DPPath) -> None:
Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
"""
if len(self.stats) > 0:
Expand All @@ -166,7 +166,7 @@ def load_or_compute_stats(
Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
data : List[Dict[str, np.ndarray]]
The environment matrix.
Expand Down
2 changes: 1 addition & 1 deletion examples/water/dpa2/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa2",
"stat_file": "./dpa2.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_atten/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa1",
"stat_file": "./dpa1.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_e2_a/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./se_e2_a",
"stat_file": "./se_e2_a.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_saveload_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_wrapper(self, read: bool):
model_config = copy.deepcopy(self.config["model"])
model_config["resuming"] = read
model_config["stat_file_dir"] = "stat_files"
model_config["stat_file"] = "stat.npz"
model_config["stat_file"] = "stat.hdf5"
model_config["stat_file_path"] = os.path.join(
model_config["stat_file_dir"], model_config["stat_file"]
)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/water/multitask.json
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
},
"data_dict": {
"model_1": {
"stat_file": "./stat_files/model_1",
"stat_file": "./stat_files/model_1.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand All @@ -112,7 +112,7 @@
}
},
"model_2": {
"stat_file": "./stat_files/model_2",
"stat_file": "./stat_files/model_2.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand Down

0 comments on commit e752e27

Please sign in to comment.