diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index adaec0968a..77091c3cf7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -123,13 +123,12 @@ def prepare_trainer_input_single( if rank != 0: stat_file_path_single = None elif stat_file_path_single is not None: - if Path(stat_file_path_single).is_dir(): - raise ValueError( - f"stat_file should be a file, not a directory: {stat_file_path_single}" - ) - if not Path(stat_file_path_single).is_file(): - with h5py.File(stat_file_path_single, "w") as f: - pass + if not Path(stat_file_path_single).exists(): + if stat_file_path_single.endswith((".h5", ".hdf5")): + with h5py.File(stat_file_path_single, "w") as f: + pass + else: + Path(stat_file_path_single).mkdir() stat_file_path_single = DPPath(stat_file_path_single, "a") # validation and training data diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 2a98bee6fe..07add486c1 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2179,7 +2179,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_stat_file = ( "The file path for saving the data statistics results. " "If set, the results will be saved and directly loaded during the next training session, " - "avoiding the need to recalculate the statistics" + "avoiding the need to recalculate the statistics. " + "If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; " + "otherwise, a directory containing NumPy binary files are used." ) doc_opt_type = "The type of optimizer to use." doc_kf_blocksize = "The blocksize for the Kalman filter." diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 217c46844b..bbb43fd703 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -132,7 +132,7 @@ def save_stats(self, path: DPPath) -> None: Parameters ---------- - path : DPH5Path + path : DPPath The path to save the statistics of the environment matrix. """ if len(self.stats) == 0: @@ -146,7 +146,7 @@ def load_stats(self, path: DPPath) -> None: Parameters ---------- - path : DPH5Path + path : DPPath The path to load the statistics of the environment matrix. """ if len(self.stats) > 0: @@ -166,7 +166,7 @@ def load_or_compute_stats( Parameters ---------- - path : DPH5Path + path : DPPath The path to load the statistics of the environment matrix. data : List[Dict[str, np.ndarray]] The environment matrix. diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch.json index 108e75df62..e94086b898 100644 --- a/examples/water/dpa2/input_torch.json +++ b/examples/water/dpa2/input_torch.json @@ -69,7 +69,7 @@ "_comment": " that's all" }, "training": { - "stat_file": "./dpa2", + "stat_file": "./dpa2.hdf5", "training_data": { "systems": [ "../data/data_0", diff --git a/examples/water/se_atten/input_torch.json b/examples/water/se_atten/input_torch.json index 7e9cf06f35..4160feda17 100644 --- a/examples/water/se_atten/input_torch.json +++ b/examples/water/se_atten/input_torch.json @@ -61,7 +61,7 @@ "_comment": " that's all" }, "training": { - "stat_file": "./dpa1", + "stat_file": "./dpa1.hdf5", "training_data": { "systems": [ "../data/data_0", diff --git a/examples/water/se_e2_a/input_torch.json b/examples/water/se_e2_a/input_torch.json index fe424afed3..3fc889212c 100644 --- a/examples/water/se_e2_a/input_torch.json +++ b/examples/water/se_e2_a/input_torch.json @@ -52,7 +52,7 @@ "_comment": " that's all" }, "training": { - "stat_file": "./se_e2_a", + "stat_file": "./se_e2_a.hdf5", "training_data": { "systems": [ "../data/data_0", diff --git a/source/tests/pt/model/test_saveload_dpa1.py b/source/tests/pt/model/test_saveload_dpa1.py index 712b44485e..3da06938b5 100644 --- a/source/tests/pt/model/test_saveload_dpa1.py +++ b/source/tests/pt/model/test_saveload_dpa1.py @@ -101,7 +101,7 @@ def create_wrapper(self, read: bool): model_config = copy.deepcopy(self.config["model"]) model_config["resuming"] = read model_config["stat_file_dir"] = "stat_files" - model_config["stat_file"] = "stat.npz" + model_config["stat_file"] = "stat.hdf5" model_config["stat_file_path"] = os.path.join( model_config["stat_file_dir"], model_config["stat_file"] ) diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json index 2f706e4cd9..c59618145d 100644 --- a/source/tests/pt/model/water/multitask.json +++ b/source/tests/pt/model/water/multitask.json @@ -95,7 +95,7 @@ }, "data_dict": { "model_1": { - "stat_file": "./stat_files/model_1", + "stat_file": "./stat_files/model_1.hdf5", "training_data": { "systems": [ "pt/water/data/data_0" @@ -112,7 +112,7 @@ } }, "model_2": { - "stat_file": "./stat_files/model_2", + "stat_file": "./stat_files/model_2.hdf5", "training_data": { "systems": [ "pt/water/data/data_0"