Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: cluster: JAX support (non-jitted) #22255

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 27 additions & 42 deletions scipy/cluster/tests/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from scipy.cluster._hierarchy import Heap
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close, xp_assert_equal
import scipy._lib.array_api_extra as xpx

from threading import Lock

Expand Down Expand Up @@ -445,55 +446,47 @@ def test_is_valid_linkage_4_and_up(self, xp):
Z = linkage(y)
assert_(is_valid_linkage(Z) is True)

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_is_valid_linkage_4_and_up_neg_index_left(self, xp):
# Tests is_valid_linkage(Z) on linkage on observation sets between
# sizes 4 and 15 (step size 3) with negative indices (left).
for i in range(4, 15, 3):
y = np.random.rand(i*(i-1)//2)
y = xp.asarray(y)
Z = linkage(y)
Z[i//2,0] = -2
Z = xpx.at(Z)[i//2, 0].set(-2)
assert_(is_valid_linkage(Z) is False)
assert_raises(ValueError, is_valid_linkage, Z, throw=True)

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_is_valid_linkage_4_and_up_neg_index_right(self, xp):
# Tests is_valid_linkage(Z) on linkage on observation sets between
# sizes 4 and 15 (step size 3) with negative indices (right).
for i in range(4, 15, 3):
y = np.random.rand(i*(i-1)//2)
y = xp.asarray(y)
Z = linkage(y)
Z[i//2,1] = -2
Z = xpx.at(Z)[i//2, 1].set(-2)
assert_(is_valid_linkage(Z) is False)
assert_raises(ValueError, is_valid_linkage, Z, throw=True)

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_is_valid_linkage_4_and_up_neg_dist(self, xp):
# Tests is_valid_linkage(Z) on linkage on observation sets between
# sizes 4 and 15 (step size 3) with negative distances.
for i in range(4, 15, 3):
y = np.random.rand(i*(i-1)//2)
y = xp.asarray(y)
Z = linkage(y)
Z[i//2,2] = -0.5
Z = xpx.at(Z)[i//2, 2].set(-0.5)
assert_(is_valid_linkage(Z) is False)
assert_raises(ValueError, is_valid_linkage, Z, throw=True)

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_is_valid_linkage_4_and_up_neg_counts(self, xp):
# Tests is_valid_linkage(Z) on linkage on observation sets between
# sizes 4 and 15 (step size 3) with negative counts.
for i in range(4, 15, 3):
y = np.random.rand(i*(i-1)//2)
y = xp.asarray(y)
Z = linkage(y)
Z[i//2,3] = -2
Z = xpx.at(Z)[i//2, 3].set(-2)
assert_(is_valid_linkage(Z) is False)
assert_raises(ValueError, is_valid_linkage, Z, throw=True)

Expand Down Expand Up @@ -538,7 +531,6 @@ def test_is_valid_im_4_and_up(self, xp):
R = inconsistent(Z)
assert_(is_valid_im(R) is True)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_is_valid_im_4_and_up_neg_index_left(self, xp):
# Tests is_valid_im(R) on im on observation sets between sizes 4 and 15
# (step size 3) with negative link height means.
Expand All @@ -547,11 +539,10 @@ def test_is_valid_im_4_and_up_neg_index_left(self, xp):
y = xp.asarray(y)
Z = linkage(y)
R = inconsistent(Z)
R[i//2,0] = -2.0
R = xpx.at(R)[i//2 , 0].set(-2.0)
assert_(is_valid_im(R) is False)
assert_raises(ValueError, is_valid_im, R, throw=True)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_is_valid_im_4_and_up_neg_index_right(self, xp):
# Tests is_valid_im(R) on im on observation sets between sizes 4 and 15
# (step size 3) with negative link height standard deviations.
Expand All @@ -560,11 +551,10 @@ def test_is_valid_im_4_and_up_neg_index_right(self, xp):
y = xp.asarray(y)
Z = linkage(y)
R = inconsistent(Z)
R[i//2,1] = -2.0
R = xpx.at(R)[i//2 , 1].set(-2.0)
assert_(is_valid_im(R) is False)
assert_raises(ValueError, is_valid_im, R, throw=True)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_is_valid_im_4_and_up_neg_dist(self, xp):
# Tests is_valid_im(R) on im on observation sets between sizes 4 and 15
# (step size 3) with negative link counts.
Expand All @@ -573,7 +563,7 @@ def test_is_valid_im_4_and_up_neg_dist(self, xp):
y = xp.asarray(y)
Z = linkage(y)
R = inconsistent(Z)
R[i//2,2] = -0.5
R = xpx.at(R)[i//2, 2].set(-0.5)
assert_(is_valid_im(R) is False)
assert_raises(ValueError, is_valid_im, R, throw=True)

Expand Down Expand Up @@ -766,12 +756,11 @@ def test_is_monotonic_tdist_linkage1(self, xp):
Z = linkage(xp.asarray(hierarchy_test_data.ytdist), 'single')
assert is_monotonic(Z)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_is_monotonic_tdist_linkage2(self, xp):
# Tests is_monotonic(Z) on clustering generated by single linkage on
# tdist data set. Perturbing. Expecting False.
Z = linkage(xp.asarray(hierarchy_test_data.ytdist), 'single')
Z[2,2] = 0.0
Z = xpx.at(Z)[2, 2].set(0.0)
assert not is_monotonic(Z)

def test_is_monotonic_Q_linkage(self, xp):
Expand All @@ -790,15 +779,13 @@ def test_maxdists_empty_linkage(self, xp):
Z = xp.zeros((0, 4), dtype=xp.float64)
assert_raises(ValueError, maxdists, Z)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_maxdists_one_cluster_linkage(self, xp):
# Tests maxdists(Z) on linkage with one cluster.
Z = xp.asarray([[0, 1, 0.3, 4]], dtype=xp.float64)
MD = maxdists(Z)
expectedMD = calculate_maximum_distances(Z, xp)
xp_assert_close(MD, expectedMD, atol=1e-15)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment')
def test_maxdists_Q_linkage(self, xp):
for method in ['single', 'complete', 'ward', 'centroid', 'median']:
self.check_maxdists_Q_linkage(method, xp)
Expand Down Expand Up @@ -829,8 +816,7 @@ def test_maxinconsts_difrow_linkage(self, xp):
R = xp.asarray(R)
assert_raises(ValueError, maxinconsts, Z, R)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment',
cpu_only=True)
@skip_xp_backends(cpu_only=True, reason="implicit device->host transfer")
def test_maxinconsts_one_cluster_linkage(self, xp):
# Tests maxinconsts(Z, R) on linkage with one cluster.
Z = xp.asarray([[0, 1, 0.3, 4]], dtype=xp.float64)
Expand All @@ -839,8 +825,7 @@ def test_maxinconsts_one_cluster_linkage(self, xp):
expectedMD = calculate_maximum_inconsistencies(Z, R, xp=xp)
xp_assert_close(MD, expectedMD, atol=1e-15)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment',
cpu_only=True)
@skip_xp_backends(cpu_only=True, reason="implicit device->host transfer")
def test_maxinconsts_Q_linkage(self, xp):
for method in ['single', 'complete', 'ward', 'centroid', 'median']:
self.check_maxinconsts_Q_linkage(method, xp)
Expand Down Expand Up @@ -893,8 +878,7 @@ def check_maxRstat_difrow_linkage(self, i, xp):
R = xp.asarray(R)
assert_raises(ValueError, maxRstat, Z, R, i)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment',
cpu_only=True)
@skip_xp_backends(cpu_only=True, reason="implicit device->host transfer")
def test_maxRstat_one_cluster_linkage(self, xp):
for i in range(4):
self.check_maxRstat_one_cluster_linkage(i, xp)
Expand All @@ -907,8 +891,7 @@ def check_maxRstat_one_cluster_linkage(self, i, xp):
expectedMD = calculate_maximum_inconsistencies(Z, R, 1, xp)
xp_assert_close(MD, expectedMD, atol=1e-15)

@skip_xp_backends('jax.numpy', reason='jax arrays do not support item assignment',
cpu_only=True)
@skip_xp_backends(cpu_only=True, reason="implicit device->host transfer")
def test_maxRstat_Q_linkage(self, xp):
for method in ['single', 'complete', 'ward', 'centroid', 'median']:
for i in range(4):
Expand Down Expand Up @@ -1129,17 +1112,18 @@ def calculate_maximum_distances(Z, xp):
# Used for testing correctness of maxdists.
n = Z.shape[0] + 1
B = xp.zeros((n-1,), dtype=Z.dtype)
q = xp.zeros((3,))
for i in range(0, n - 1):
q[:] = 0.0
q = xp.zeros((3,))
left = Z[i, 0]
right = Z[i, 1]
if left >= n:
q[0] = B[xp.asarray(left, dtype=xp.int64) - n]
b_left = B[xp.asarray(left, dtype=xp.int64) - n]
q = xpx.at(q, 0).set(b_left)
if right >= n:
q[1] = B[xp.asarray(right, dtype=xp.int64) - n]
q[2] = Z[i, 2]
B[i] = xp.max(q)
b_right = B[xp.asarray(right, dtype=xp.int64) - n]
q = xpx.at(q, 1).set(b_right)
q = xpx.at(q, 2).set(Z[i, 2])
B = xpx.at(B, i).set(xp.max(q))
return B


Expand All @@ -1148,17 +1132,18 @@ def calculate_maximum_inconsistencies(Z, R, k=3, xp=np):
n = Z.shape[0] + 1
dtype = xp.result_type(Z, R)
B = xp.zeros((n-1,), dtype=dtype)
q = xp.zeros((3,))
for i in range(0, n - 1):
q[:] = 0.0
q = xp.zeros((3,))
left = Z[i, 0]
right = Z[i, 1]
if left >= n:
q[0] = B[xp.asarray(left, dtype=xp.int64) - n]
b_left = B[xp.asarray(left, dtype=xp.int64) - n]
q = xpx.at(q, 0).set(b_left)
if right >= n:
q[1] = B[xp.asarray(right, dtype=xp.int64) - n]
q[2] = R[i, k]
B[i] = xp.max(q)
b_right = B[xp.asarray(right, dtype=xp.int64) - n]
q = xpx.at(q, 1).set(b_right)
q = xpx.at(q, 2).set(R[i, k])
B = xpx.at(B, i).set(xp.max(q))
return B


Expand Down
10 changes: 0 additions & 10 deletions scipy/cluster/tests/test_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def test_whiten(self, xp):
def whiten_lock(self):
return Lock()

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_whiten_zero_std(self, xp, whiten_lock):
desired = xp.asarray([[0., 1.0, 2.86666544],
[0., 1.0, 1.32460034],
Expand Down Expand Up @@ -334,8 +332,6 @@ def test_kmeans2_high_dim(self, xp):
data = xp.reshape(data, (20, 20))[:10, :]
kmeans2(data, 2)

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_kmeans2_init(self, xp):
rng = np.random.default_rng(12345678)
data = xp.asarray(TESTDATA_2D)
Expand Down Expand Up @@ -390,8 +386,6 @@ def test_kmeans_large_thres(self, xp):
xp_assert_close(res[0], xp.asarray([4.], dtype=xp.float64))
xp_assert_close(res[1], xp.asarray(2.3999999999999999, dtype=xp.float64)[()])

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_kmeans2_kpp_low_dim(self, xp):
# Regression test for gh-11462
rng = np.random.default_rng(2358792345678234568)
Expand All @@ -401,8 +395,6 @@ def test_kmeans2_kpp_low_dim(self, xp):
xp_assert_close(res, prev_res)

@pytest.mark.thread_unsafe
@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_kmeans2_kpp_high_dim(self, xp):
# Regression test for gh-11462
rng = np.random.default_rng(23587923456834568)
Expand All @@ -427,8 +419,6 @@ def test_kmeans_diff_convergence(self, xp):
xp_assert_close(res[0], xp.asarray([-0.4, 8.], dtype=xp.float64))
xp_assert_close(res[1], xp.asarray(1.0666666666666667, dtype=xp.float64)[()])

@skip_xp_backends('jax.numpy',
reason='jax arrays do not support item assignment')
def test_kmeans_and_kmeans2_random_seed(self, xp):

seed_list = [
Expand Down
11 changes: 6 additions & 5 deletions scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def whiten(obs, check_finite=True):
obs = _asarray(obs, check_finite=check_finite, xp=xp)
std_dev = xp.std(obs, axis=0)
zero_std_mask = std_dev == 0
if xp.any(zero_std_mask):
std_dev[zero_std_mask] = 1.0
std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails on jax.jit. My current intention is to change jax.jit itself to special-case arr.at[idx].set(value) when idx is a boolean mask and value is a scalar, so that it can be rewritten as jnp.where(idx, value, arr). Failing that, I can implement the same special case in array-api-extra.

Copy link
Member

@jakevdp jakevdp Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> My current intention is to change jax.jit itself

I don't think changing jax.jit itself is a viable path here – I wouldn't suggest starting on that route.

It might be viable to make arr.at[idx].set(value) lower to lax.select rather than lax.scatter in the specific case of a boolean idx. I've tried that in the past, but it's really tricky to properly handle all corner cases of broadcasted and/or multi-dimensional indices, correctly implementing autodiff and batching rules, etc.

The easiest thing would probably be to do this at the level of xp.at, though boolean indices were specifically excluded from the initial discussions there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> I've tried that in the past

Do you have a (partial, non-functioning) PR I could start from?

if check_finite and xp.any(zero_std_mask):
Copy link
Contributor Author

@crusaderky crusaderky Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails on jax.jit with an error message that's not useful to final scipy users, unless you explicitly pass check_finite=False.
data-apis/array-api-compat#225 is a blocker for a more robust fix.

warnings.warn("Some columns have standard deviation zero. "
"The values of these columns will not change.",
RuntimeWarning, stacklevel=2)
Expand Down Expand Up @@ -607,15 +607,16 @@ def _kpp(data, k, rng, xp):

for i in range(k):
if i == 0:
init[i, :] = data[rng_integers(rng, data.shape[0]), :]

data_idx = rng_integers(rng, data.shape[0])
else:
D2 = cdist(init[:i,:], data, metric='sqeuclidean').min(axis=0)
probs = D2/D2.sum()
cumprobs = probs.cumsum()
r = rng.uniform()
cumprobs = np.asarray(cumprobs)
init[i, :] = data[np.searchsorted(cumprobs, r), :]
data_idx = np.searchsorted(cumprobs, r)

init = xpx.at(init)[i, :].set(data[data_idx, :])

if ndim == 1:
init = init[:, 0]
Expand Down
Loading