Skip to content

Commit

Permalink
Update linalg tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Jan 19, 2025
1 parent 5ba2f41 commit f73bcf3
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a)
dp_rank = dpnp.linalg.matrix_rank(a_dp)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
Expand All @@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize(
"high_tol, low_tol",
Expand Down Expand Up @@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
dp_rank_high_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_high_tol
)
assert np_rank_high_tol == dp_rank_high_tol
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol

np_rank_low_tol = numpy.linalg.matrix_rank(
a, hermitian=True, tol=low_tol
)
dp_rank_low_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_low_tol
)
assert np_rank_low_tol == dp_rank_low_tol
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol

# rtol kwarg was added in numpy 2.0
@testing.with_requires("numpy>=2.0")
Expand Down Expand Up @@ -2789,15 +2789,14 @@ def check_decomposition(
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
dpnp_diag_s[..., i, i] = dp_s[..., i]
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
# TODO: use assert dpnp.allclose() inside check_decomposition()
# when it will support complex dtypes
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)

if compute_vt:
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
np_u[..., :, i] = -np_u[..., :, i]
np_vt[..., i, :] = -np_vt[..., i, :]
for i in range(numpy.count_nonzero(np_s > tol)):
Expand Down

0 comments on commit f73bcf3

Please sign in to comment.