Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 7, 2024
1 parent 590aeba commit 7748723
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions source/tests/pt/model/test_polar_stat.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np
import torch

from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.model.task.polarizability import PolarFittingNet
from deepmd.tf.fit.polar import PolarFittingSeA
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.tf.fit.polar import (
PolarFittingSeA,
)


class TestConsistency(unittest.TestCase):
def setUp(self) -> None:
types = torch.randint(0,4,(1,5), device=env.DEVICE)
types = torch.cat((types,types,types), dim=0)
types = torch.randint(0, 4, (1, 5), device=env.DEVICE)
types = torch.cat((types, types, types), dim=0)
ntypes = 4
atomic_polarizability = torch.rand(3,5,9)
polarizability = torch.rand(3,9)
atomic_polarizability = torch.rand(3, 5, 9)
polarizability = torch.rand(3, 9)
find_polarizability = torch.rand(1)
find_atomic_polarizability = torch.rand(1)
self.sampled = [{"type": types,
"find_atomic_polarizability": find_atomic_polarizability,
"atomic_polarizability": atomic_polarizability,
"polarizability": polarizability,
"find_polarizability": find_polarizability}]
self.sampled = [
{
"type": types,
"find_atomic_polarizability": find_atomic_polarizability,
"atomic_polarizability": atomic_polarizability,
"polarizability": polarizability,
"find_polarizability": find_polarizability,
}
]
self.all_stat = {k: [v.numpy()] for d in self.sampled for k, v in d.items()}
self.tfpolar = PolarFittingSeA(
ntypes=ntypes,
Expand All @@ -36,7 +48,6 @@ def setUp(self) -> None:
dim_descrpt=1,
embedding_width=1,
)


def test_atomic_consistency(self):
self.tfpolar.compute_output_stats(self.all_stat)
Expand All @@ -48,12 +59,16 @@ def test_atomic_consistency(self):

def test_global_consistency(self):
self.sampled[0]["find_atomic_polarizability"] = -1
self.sampled[0]["polarizability"] = self.sampled[0]["atomic_polarizability"].sum(dim=1)
self.sampled[0]["polarizability"] = self.sampled[0][
"atomic_polarizability"
].sum(dim=1)
self.all_stat["find_atomic_polarizability"] = [-1]
self.all_stat["polarizability"] = list(self.all_stat["atomic_polarizability"][0].sum(axis=1))
self.all_stat["polarizability"] = list(
self.all_stat["atomic_polarizability"][0].sum(axis=1)
)
self.tfpolar.compute_output_stats(self.all_stat)
tfbias = self.tfpolar.constant_matrix
self.ptpolar.compute_output_stats(self.sampled)
ptbias = self.ptpolar.constant_matrix
print(tfbias, to_numpy_array(ptbias))
np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))
np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))

0 comments on commit 7748723

Please sign in to comment.