Skip to content

Commit

Permalink
Bump version (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Jul 25, 2023
1 parent f9c585e commit f6d7464
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/tad_dftd3/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
"""
import torch

__version__ = "0.1.3"
__version__ = "0.1.4"

__torch_version__ = torch.__version__
12 changes: 8 additions & 4 deletions tests/test_utils/test_cdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,28 @@

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_all(dtype: torch.dtype) -> None:
tol = 1e-6 if dtype == torch.float else 1e-14

x = torch.randn(2, 3, 4, dtype=dtype)

d1 = util.cdist(x)
d2 = util.distance.cdist_direct_expansion(x, x, p=2)
d3 = util.distance.euclidean_dist_quadratic_expansion(x, x)

assert pytest.approx(d1) == d2
assert pytest.approx(d2) == d3
assert pytest.approx(d3) == d1
assert pytest.approx(d1, abs=tol) == d2
assert pytest.approx(d2, abs=tol) == d3
assert pytest.approx(d3, abs=tol) == d1


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("p", [2, 3, 4, 5])
def test_ps(dtype: torch.dtype, p: int) -> None:
tol = 1e-6 if dtype == torch.float else 1e-14

x = torch.randn(2, 4, 5, dtype=dtype)
y = torch.randn(2, 4, 5, dtype=dtype)

d1 = util.cdist(x, y, p=p)
d2 = torch.cdist(x, y, p=p)

assert pytest.approx(d1) == d2
assert pytest.approx(d1, abs=tol) == d2

0 comments on commit f6d7464

Please sign in to comment.