Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More improvements to test_linalg #101

Merged
merged 61 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
360df9f
Add shape testing for vector_norm()
asmeurer Feb 17, 2022
e34f492
Test the dtype and stacks in the vector_norm() test
asmeurer Feb 18, 2022
111c237
Remove an ununsed variable
asmeurer Feb 18, 2022
4f3aa54
Use a simpler strategy for ord in test_vector_norm
asmeurer Feb 19, 2022
979b81b
Skip the test_vector_norm test on the NumPy CI
asmeurer Feb 25, 2022
8df237a
Fix syntax error
asmeurer Feb 25, 2022
d11a685
Fix the input strategies for test_tensordot()
asmeurer Feb 26, 2022
a776cd4
Add a test for the tensordot result shape
asmeurer Apr 9, 2022
45b36d6
Test stacking for tensordot
asmeurer Apr 12, 2022
414b322
Add allclose() and assert_allclose() helper functions
asmeurer Apr 25, 2022
9bb8c7a
Use assert_allclose() in the linalg tests for float inputs
asmeurer Apr 25, 2022
b3fb4ec
Remove skip from test_eigh
asmeurer Apr 26, 2022
241220e
Disable eigenvectors stack test
asmeurer May 6, 2022
ca70fbe
Reduce the relative tolerance in assert_allclose
asmeurer May 6, 2022
720b309
Sort the eigenvalues when testing stacks
asmeurer May 6, 2022
17d93bf
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer May 9, 2022
f439259
Sort the results in eigvalsh before comparing
asmeurer Jun 3, 2022
75ca73a
Remove the allclose testing in linalg
asmeurer Jun 13, 2022
d86a0a1
Add (commented out) stacking tests for solve()
asmeurer Jun 16, 2022
9bccfa5
Remove unused none standin in the linalg tests
asmeurer Jun 16, 2022
f494b45
Don't compare float elements in test_tensordot
asmeurer Jun 16, 2022
74add08
Fix test_vecdot
asmeurer Jun 24, 2022
f12be47
Fix typo in test_vecdot
asmeurer Jul 5, 2022
d41d0bd
Expand vecdot tests
asmeurer Jul 5, 2022
1220d6e
Merge branch 'master' into more-linalg2
asmeurer Sep 27, 2022
a96a5df
Merge branch 'master' into more-linalg2
asmeurer Oct 20, 2022
48a8442
Check specially that the result of linalg functions is not a unnamed …
asmeurer Nov 29, 2022
fd6367f
Use a more robust fallback helper for matrix_transpose
asmeurer Mar 17, 2023
7017797
Be more constrained about constructing symmetric matrices
asmeurer Mar 20, 2023
335574e
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer Mar 21, 2023
246e38a
Don't require the arguments to assert_keepdimable_shape to be positio…
asmeurer Mar 23, 2023
02542ff
Show the arrays in the error message for assert_exactly_equal
asmeurer Mar 29, 2023
72974e0
Allow passing an extra assertion message to assert_equal in linalg an…
asmeurer Mar 29, 2023
1daba5d
Fix the true_value check for test_vecdot
asmeurer Mar 29, 2023
bbfe50f
Fix the test_diagonal true value check
asmeurer Mar 29, 2023
64b0342
Use a function instead of operation
asmeurer Mar 29, 2023
9cb58a1
Add a comment
asmeurer Apr 18, 2023
0b3e170
Merge branch 'master' into more-linalg2
asmeurer Feb 3, 2024
c51216b
Remove flaky skips from linalg tests
asmeurer Feb 3, 2024
cffd076
Fix some issues in linalg tests from recent merge
asmeurer Feb 3, 2024
3501116
Fix vector_norm to not use our custom arrays strategy
asmeurer Feb 3, 2024
5c1aa45
Update _test_stacks to use updated ndindex behavior
asmeurer Feb 3, 2024
7a46e6b
Further limit the size of n in test_matrix_power
asmeurer Feb 3, 2024
6d154f2
Fix test_trace
asmeurer Feb 3, 2024
257aa13
Fix test_vecdot to only generate axis in [-min(x1.ndim, x2.ndim), -1]
asmeurer Feb 3, 2024
afc8a25
Update test_cross to test broadcastable shapes
asmeurer Feb 3, 2024
3cb9912
Fix test_cross to use assert_dtype and assert_shape helpers
asmeurer Feb 3, 2024
012ca19
Remove some completed TODO comments
asmeurer Feb 3, 2024
5ceb81d
Update linalg tests to test complex dtypes
asmeurer Feb 3, 2024
a4d419f
Update linalg tests to use assert_dtype and assert_shape helpers
asmeurer Feb 3, 2024
6f9db94
Factor out dtype logic from test_sum() and test_prod() and apply it t…
asmeurer Feb 3, 2024
5aa9083
Remove unused allclose and assert_allclose helpers
asmeurer Feb 7, 2024
938f086
Update ndindex version requirement
asmeurer Feb 16, 2024
3856b8f
Fix linting issue
asmeurer Feb 16, 2024
ccc6ca3
Skip `test_cross` in CI
honno Feb 20, 2024
3092422
Test matmul, matrix_transpose, tensordot, and vecdot for the main and…
asmeurer Feb 23, 2024
2d918e4
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer Feb 23, 2024
3fefd20
Remove need for filtering in `invertible_matrices()`
honno Feb 26, 2024
a76e051
Merge branch 'master' into more-linalg2
honno Feb 26, 2024
268682d
Skip flaky `test_reshape`
honno Feb 26, 2024
0ddb0cd
Less filtering in `positive_definitive_matrices`
honno Feb 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions array_api_tests/array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
# These are exported here so that they can be included in the special cases
# tests from this file.
from ._array_module import logical_not, subtract, floor, ceil, where
from . import _array_module as xp
from . import dtype_helpers as dh


__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
Expand Down Expand Up @@ -164,19 +164,21 @@ def notequal(x, y):

return not_equal(x, y)

def assert_exactly_equal(x, y):
def assert_exactly_equal(x, y, msg_extra=None):
"""
Test that the arrays x and y are exactly equal.

If x and y do not have the same shape and dtype, they are not considered
equal.

"""
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
extra = '' if not msg_extra else f' ({msg_extra})'

assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"

assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"

assert all(exactly_equal(x, y)), "The input arrays have different values"
assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"

def assert_finite(x):
"""
Expand Down Expand Up @@ -306,3 +308,13 @@ def same_sign(x, y):
def assert_same_sign(x, y):
assert all(same_sign(x, y)), "The input arrays do not have the same sign"

def _matrix_transpose(x):
if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
return xp.matrix_transpose(x)
if hasattr(x, 'mT'):
return x.mT
if not isinstance(xp.permute_dims, xp._UndefinedStub):
perm = list(range(x.ndim))
perm[-1], perm[-2] = perm[-2], perm[-1]
return xp.permute_dims(x, axes=tuple(perm))
raise NotImplementedError("No way to compute matrix transpose")
51 changes: 51 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,57 @@ class MinMax(NamedTuple):
{"complex64": xp.float32, "complex128": xp.float64}
)

def as_real_dtype(dtype):
"""
Return the corresponding real dtype for a given floating-point dtype.
"""
if dtype in real_float_dtypes:
return dtype
elif dtype_to_name[dtype] in complex_names:
return dtype_components[dtype]
else:
raise ValueError("as_real_dtype requires a floating-point dtype")

def accumulation_result_dtype(x_dtype, dtype_kwarg):
"""
Result dtype logic for sum(), prod(), and trace()

Note: may return None if a default uint cannot exist (e.g., for pytorch
which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106

"""
if dtype_kwarg is None:
if is_int_dtype(x_dtype):
if x_dtype in uint_dtypes:
default_dtype = default_uint
else:
default_dtype = default_int
if default_dtype is None:
_dtype = None
else:
m, M = dtype_ranges[x_dtype]
d_m, d_M = dtype_ranges[default_dtype]
if m < d_m or M > d_M:
_dtype = x_dtype
else:
_dtype = default_dtype
elif is_float_dtype(x_dtype, include_complex=False):
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
_dtype = x_dtype
else:
_dtype = default_float
elif api_version > "2021.12":
# Complex dtype
if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
_dtype = x_dtype
else:
_dtype = default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype_kwarg

return _dtype

if not hasattr(xp, "asarray"):
default_int = xp.int32
Expand Down
30 changes: 21 additions & 9 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
sampled_from, shared, builds)

from . import _array_module as xp, api_version
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import xps
Expand Down Expand Up @@ -211,6 +212,7 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)

# Use this to avoid memory errors with NumPy.
# See https://github.com/numpy/numpy/issues/15753
# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2)
def shapes(**kw):
kw.setdefault('min_dims', 0)
kw.setdefault('min_side', 0)
Expand Down Expand Up @@ -280,25 +282,29 @@ def mutually_broadcastable_shapes(

# Note: This should become hermitian_matrices when complex dtypes are added
@composite
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
shape = draw(square_matrix_shapes)
dtype = draw(dtypes)
if not isinstance(finite, bool):
finite = draw(finite)
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
upper = xp.triu(a)
lower = xp.triu(a, k=1).mT
return upper + lower
at = ah._matrix_transpose(a)
H = (a + at)*0.5
if finite:
assume(not xp.any(xp.isinf(H)))
assume(xp.all((H == 0.) | ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound))))
return H

@composite
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
# For now just generate stacks of identity matrices
# TODO: Generate arbitrary positive definite matrices, for instance, by
# using something like
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
n = draw(integers(0))
shape = draw(shapes()) + (n, n)
base_shape = draw(shapes())
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
shape = base_shape + (n, n)
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
dtype = draw(dtypes)
return broadcast_to(eye(n, dtype=dtype), shape)
Expand All @@ -308,12 +314,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
# For now, just generate stacks of diagonal matrices.
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
stack_shape = draw(stack_shapes)
d = draw(arrays(dtypes, shape=(*stack_shape, 1, n),
elements=dict(allow_nan=False, allow_infinity=False)))
dtype = draw(dtypes)
elements = one_of(
from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),
from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False),
)
d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements))

# Functions that require invertible matrices may do anything when it is
# singular, including raising an exception, so we make sure the diagonals
# are sufficiently nonzero to avoid any numerical issues.
assume(xp.all(xp.abs(d) > 0.5))
assert xp.all(xp.abs(d) >= 0.5)

diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
return xp.where(diag_mask, d, xp.zeros_like(d))

Expand Down
16 changes: 16 additions & 0 deletions array_api_tests/meta/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from hypothesis import given

from ..hypothesis_helpers import symmetric_matrices
from .. import array_helpers as ah
from .. import _array_module as xp

@pytest.mark.xp_extension('linalg')
@given(x=symmetric_matrices(finite=True))
def test_symmetric_matrices(x):
upper = xp.triu(x)
lower = xp.tril(x)
lowerT = ah._matrix_transpose(lower)

ah.assert_exactly_equal(upper, lowerT)
Loading
Loading