From 135c5371dc40b9095450b455e26a726f97dc3968 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 7 Jan 2025 11:16:25 -0800 Subject: [PATCH] fix casting bug --- apis/python/src/tiledbsoma/_indexer.py | 8 ++++---- apis/python/tests/test_indexer.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/apis/python/src/tiledbsoma/_indexer.py b/apis/python/src/tiledbsoma/_indexer.py index 7d0649d2a4..c2e29ea54d 100644 --- a/apis/python/src/tiledbsoma/_indexer.py +++ b/apis/python/src/tiledbsoma/_indexer.py @@ -74,12 +74,12 @@ def __init__( # TODO: the map_locations interface does not accept chunked arrays. It would # save a copy (reduce memory usage) if they were natively supported. - if isinstance( - data, (pa.Array, pa.ChunkedArray, pd.arrays.IntegerArray, pd.Series) - ): + if isinstance(data, (pa.Array, pa.ChunkedArray)): data = data.to_numpy() elif isinstance(data, list): data = np.array(data, dtype=np.int64) + elif isinstance(data, (pd.arrays.IntegerArray, pd.Series)): + data = data.to_numpy(dtype=np.int64, copy=False) self._reindexer.map_locations(data) @@ -95,7 +95,7 @@ def get_indexer(self, target: IndexerDataType) -> npt.NDArray[np.intp]: return self._reindexer.get_indexer_pyarrow(target) if isinstance(target, (pd.arrays.IntegerArray, pd.Series)): - target = target.to_numpy() + target = target.to_numpy(dtype=np.int64, copy=False) elif isinstance(target, list): target = np.array(target, dtype=np.int64) diff --git a/apis/python/tests/test_indexer.py b/apis/python/tests/test_indexer.py index 23cff3ebe3..758be0b6ac 100644 --- a/apis/python/tests/test_indexer.py +++ b/apis/python/tests/test_indexer.py @@ -102,9 +102,12 @@ def test_indexer(contextual: bool, keys: np.array, lookups: np.array): num_threads = 10 def target(): - indexer = IntIndexer(keys, context=context) - results = indexer.get_indexer(lookups) - all_results.append(results) + try: + indexer = IntIndexer(keys, context=context) + results = indexer.get_indexer(lookups) + all_results.append(results) + except Exception as e: + all_results.append(e) for t in range(num_threads): thread = threading.Thread(target=target, args=()) @@ -113,6 +116,8 @@ def target(): panda_indexer = pd.Index(keys) panda_results = panda_indexer.get_indexer(lookups) for i in range(num_threads): + if isinstance(all_results[i], Exception): + raise all_results[i] np.testing.assert_equal(all_results[i].all(), panda_results.all())