diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index d447ab3fcf..2362821dfa 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json import os +import tempfile import unittest from abc import ( ABC, @@ -11,6 +12,7 @@ ) import dpdata +import h5py import numpy as np import torch @@ -55,7 +57,7 @@ DataRequirementItem, ) from deepmd.utils.path import ( - DPH5Path, + DPPath, ) CUR_DIR = os.path.dirname(__file__) @@ -336,13 +338,13 @@ def tf_compute_input_stats(self): class TestOutputStat(unittest.TestCase): - def test(self): + def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - type_map = ["O", "H"] # by dataset + self.type_map = ["O", "H"] # by dataset self.data = DpLoaderSet( self.data_file, batch_size=1, - type_map=type_map, + type_map=self.type_map, ) self.data.add_data_requirement(energy_data_requirement) self.sampled = make_stat_input( @@ -350,12 +352,18 @@ def test(self): self.data.dataloaders, nbatches=1, ) - stat_file_name = "my_output_stat" - if os.path.isfile(stat_file_name): - os.remove(stat_file_name) - # Path(stat_file_name).mkdir(exist_ok=True) - stat_file_path = DPH5Path(stat_file_name, "a") - atom_ener = np.array([3.0, 5.0]).reshape(2, 1) + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def tearDown(self): + self.tempdir.cleanup() + + def test_calc_and_load(self): + stat_file_path = self.stat_file_path + type_map = self.type_map # compute from sample ret0 = compute_output_stats( @@ -379,7 +387,7 @@ def test(self): np.testing.assert_almost_equal( to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10 ) - self.assertTrue(stat_file_path.is_dir()) + # self.assertTrue(stat_file_path.is_dir()) def raise_error(): raise RuntimeError @@ -397,7 +405,11 @@ def raise_error(): np.testing.assert_almost_equal( to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10 ) - os.remove(stat_file_name) + + def test_assigned(self): + atom_ener = np.array([3.0, 5.0]).reshape(2, 1) + stat_file_path = self.stat_file_path + type_map = self.type_map # from assigned atom_ener ret2 = compute_output_stats( @@ -411,7 +423,6 @@ def raise_error(): np.testing.assert_almost_equal( to_numpy_array(ret2["energy"]), atom_ener, decimal=10 ) - os.remove(stat_file_name) if __name__ == "__main__":