From f6d7464d874eadbc08202a2545e263feab48a88c Mon Sep 17 00:00:00 2001 From: Marvin Friede <51965259+marvinfriede@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:26:45 +0200 Subject: [PATCH] Bump version (#28) --- src/tad_dftd3/__version__.py | 2 +- tests/test_utils/test_cdist.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/tad_dftd3/__version__.py b/src/tad_dftd3/__version__.py index a5ae512..be774a8 100644 --- a/src/tad_dftd3/__version__.py +++ b/src/tad_dftd3/__version__.py @@ -17,6 +17,6 @@ """ import torch -__version__ = "0.1.3" +__version__ = "0.1.4" __torch_version__ = torch.__version__ diff --git a/tests/test_utils/test_cdist.py b/tests/test_utils/test_cdist.py index 9486a0c..dced45a 100644 --- a/tests/test_utils/test_cdist.py +++ b/tests/test_utils/test_cdist.py @@ -24,24 +24,28 @@ @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_all(dtype: torch.dtype) -> None: + tol = 1e-6 if dtype == torch.float else 1e-14 + x = torch.randn(2, 3, 4, dtype=dtype) d1 = util.cdist(x) d2 = util.distance.cdist_direct_expansion(x, x, p=2) d3 = util.distance.euclidean_dist_quadratic_expansion(x, x) - assert pytest.approx(d1) == d2 - assert pytest.approx(d2) == d3 - assert pytest.approx(d3) == d1 + assert pytest.approx(d1, abs=tol) == d2 + assert pytest.approx(d2, abs=tol) == d3 + assert pytest.approx(d3, abs=tol) == d1 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("p", [2, 3, 4, 5]) def test_ps(dtype: torch.dtype, p: int) -> None: + tol = 1e-6 if dtype == torch.float else 1e-14 + x = torch.randn(2, 4, 5, dtype=dtype) y = torch.randn(2, 4, 5, dtype=dtype) d1 = util.cdist(x, y, p=p) d2 = torch.cdist(x, y, p=p) - assert pytest.approx(d1) == d2 + assert pytest.approx(d1, abs=tol) == d2