Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Apr 1, 2024
1 parent f2dfe45 commit eee1b87
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions source/tests/pt/test_stat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import os
import tempfile
import unittest
from abc import (
ABC,
Expand All @@ -11,6 +12,7 @@
)

import dpdata
import h5py
import numpy as np
import torch

Expand Down Expand Up @@ -55,7 +57,7 @@
DataRequirementItem,
)
from deepmd.utils.path import (
DPH5Path,
DPPath,
)

CUR_DIR = os.path.dirname(__file__)
Expand Down Expand Up @@ -336,26 +338,32 @@ def tf_compute_input_stats(self):


class TestOutputStat(unittest.TestCase):
def test(self):
def setUp(self):
self.data_file = [str(Path(__file__).parent / "water/data/data_0")]
type_map = ["O", "H"] # by dataset
self.type_map = ["O", "H"] # by dataset
self.data = DpLoaderSet(
self.data_file,
batch_size=1,
type_map=type_map,
type_map=self.type_map,
)
self.data.add_data_requirement(energy_data_requirement)
self.sampled = make_stat_input(
self.data.systems,
self.data.dataloaders,
nbatches=1,
)
stat_file_name = "my_output_stat"
if os.path.isfile(stat_file_name):
os.remove(stat_file_name)
# Path(stat_file_name).mkdir(exist_ok=True)
stat_file_path = DPH5Path(stat_file_name, "a")
atom_ener = np.array([3.0, 5.0]).reshape(2, 1)
self.tempdir = tempfile.TemporaryDirectory()
h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve())
with h5py.File(h5file, "w") as f:
pass
self.stat_file_path = DPPath(h5file, "a")

def tearDown(self):
self.tempdir.cleanup()

def test_calc_and_load(self):
stat_file_path = self.stat_file_path
type_map = self.type_map

# compute from sample
ret0 = compute_output_stats(
Expand All @@ -379,7 +387,7 @@ def test(self):
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10
)
self.assertTrue(stat_file_path.is_dir())
# self.assertTrue(stat_file_path.is_dir())

def raise_error():
raise RuntimeError
Expand All @@ -397,7 +405,11 @@ def raise_error():
np.testing.assert_almost_equal(
to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10
)
os.remove(stat_file_name)

def test_assigned(self):
atom_ener = np.array([3.0, 5.0]).reshape(2, 1)
stat_file_path = self.stat_file_path
type_map = self.type_map

# from assigned atom_ener
ret2 = compute_output_stats(
Expand All @@ -411,7 +423,6 @@ def raise_error():
np.testing.assert_almost_equal(
to_numpy_array(ret2["energy"]), atom_ener, decimal=10
)
os.remove(stat_file_name)


if __name__ == "__main__":
Expand Down

0 comments on commit eee1b87

Please sign in to comment.