Skip to content

Commit

Permalink
Add dp test ut for spin
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 6, 2024
1 parent 3c09148 commit a73c314
Show file tree
Hide file tree
Showing 17 changed files with 172 additions and 32 deletions.
Binary file added source/tests/pt/NiO/data/data_0/set.000/box.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/data_0/set.000/coord.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/data_0/set.000/energy.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/data_0/set.000/force.npy
Binary file not shown.
Binary file not shown.
Binary file added source/tests/pt/NiO/data/data_0/set.000/spin.npy
Binary file not shown.
32 changes: 32 additions & 0 deletions source/tests/pt/NiO/data/data_0/type.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2 changes: 2 additions & 0 deletions source/tests/pt/NiO/data/data_0/type_map.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Ni
O
Binary file added source/tests/pt/NiO/data/single/set.000/box.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/single/set.000/coord.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/single/set.000/energy.npy
Binary file not shown.
Binary file added source/tests/pt/NiO/data/single/set.000/force.npy
Binary file not shown.
Binary file not shown.
Binary file added source/tests/pt/NiO/data/single/set.000/spin.npy
Binary file not shown.
32 changes: 32 additions & 0 deletions source/tests/pt/NiO/data/single/type.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2 changes: 2 additions & 0 deletions source/tests/pt/NiO/data/single/type_map.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Ni
O
136 changes: 104 additions & 32 deletions source/tests/pt/test_dp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
Expand All @@ -13,59 +14,130 @@
import numpy as np
import torch

from deepmd.entrypoints.test import test as dp_test
from deepmd.pt.entrypoints.main import (
get_trainer,
)
from deepmd.pt.infer import (
inference,
from deepmd.pt.utils.utils import (
to_numpy_array,
)

from .model.test_permutation import (
model_se_e2_a,
model_spin,
)

class TestDPTest(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = [
str(Path(__file__).parent / "water/data/single")
]
self.input_json = "test_dp_test.json"
with open(self.input_json, "w") as fp:
json.dump(self.config, fp, indent=4)

def test_dp_test(self):
class DPTest:
def test_dp_test_1_frame(self):
trainer = get_trainer(deepcopy(self.config))
trainer.run()

with torch.device("cpu"):
input_dict, label_dict, _ = trainer.get_data(is_train=False)
_, _, more_loss = trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0)

tester = inference.Tester("model.pt", input_script=self.input_json)
try:
res = tester.run()
except StopIteration:
raise StopIteration("Unexpected stop iteration.(test step < total batch)")
for k, v in res.items():
if k == "rmse" or "mae" in k or k not in more_loss:
continue
np.testing.assert_allclose(
v, more_loss[k].cpu().detach().numpy(), rtol=1e-04, atol=1e-07
has_spin = getattr(trainer.model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
if not has_spin:
input_dict.pop("spin", None)
input_dict["do_atomic_virial"] = True
result = trainer.model(**input_dict)
model = torch.jit.script(trainer.model)
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
torch.jit.save(model, tmp_model.name)
dp_test(
model=tmp_model.name,
system=self.config["training"]["validation_data"]["systems"][0],
datafile=None,
set_prefix="set",
numb_test=0,
rand_seed=None,
shuffle_test=False,
detail_file=self.detail_file,
atomic=False,
)
os.unlink(tmp_model.name)
natom = input_dict["atype"].shape[1]
pred_e = np.loadtxt(self.detail_file + ".e.out", ndmin=2)[0, 1]
np.testing.assert_almost_equal(
pred_e,
to_numpy_array(result["energy"])[0][0],
)
pred_e_peratom = np.loadtxt(self.detail_file + ".e_peratom.out", ndmin=2)[0, 1]
np.testing.assert_almost_equal(pred_e_peratom, pred_e / natom)
if not has_spin:
pred_f = np.loadtxt(self.detail_file + ".f.out", ndmin=2)[:, 3:6]
np.testing.assert_almost_equal(
pred_f,
to_numpy_array(result["force"]).reshape(-1, 3),
)
pred_v = np.loadtxt(self.detail_file + ".v.out", ndmin=2)[:, 9:18]
np.testing.assert_almost_equal(
pred_v,
to_numpy_array(result["virial"]),
)
pred_v_peratom = np.loadtxt(self.detail_file + ".v_peratom.out", ndmin=2)[
:, 9:18
]
np.testing.assert_almost_equal(pred_v_peratom, pred_v / natom)
else:
pred_fr = np.loadtxt(self.detail_file + ".fr.out", ndmin=2)[:, 3:6]
np.testing.assert_almost_equal(
pred_fr,
to_numpy_array(result["force"]).reshape(-1, 3),
)
pred_fm = np.loadtxt(self.detail_file + ".fm.out", ndmin=2)[:, 3:6]
np.testing.assert_almost_equal(
pred_fm,
to_numpy_array(
result["force_mag"][result["mask_mag"].bool().squeeze(-1)]
).reshape(-1, 3),
)

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f.startswith(self.detail_file):
os.remove(f)
if f in ["lcurve.out", self.input_json]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)


class TestDPTestSeA(unittest.TestCase, DPTest):

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning test

Base classes have conflicting values for attribute 'tearDown':
Function tearDown
and
Function tearDown
.
def setUp(self):
self.detail_file = "test_dp_test_ener_detail"
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
data_file = [str(Path(__file__).parent / "water/data/single")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_se_e2_a)
self.input_json = "test_dp_test.json"
with open(self.input_json, "w") as fp:
json.dump(self.config, fp, indent=4)


class TestDPTestSeASpin(unittest.TestCase, DPTest):

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning test

Base classes have conflicting values for attribute 'tearDown':
Function tearDown
and
Function tearDown
.
def setUp(self):
self.detail_file = "test_dp_test_ener_spin_detail"
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
data_file = [str(Path(__file__).parent / "NiO/data/single")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_spin)
self.config["model"]["type_map"] = ["Ni", "O", "B"]
self.input_json = "test_dp_test.json"
with open(self.input_json, "w") as fp:
json.dump(self.config, fp, indent=4)


if __name__ == "__main__":
unittest.main()

0 comments on commit a73c314

Please sign in to comment.