Skip to content

Commit

Permalink
Use numba to work around int32 scipy.sparse issue
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jan 31, 2024
1 parent dbe3cab commit 7c6e658
Showing 1 changed file with 54 additions and 8 deletions.
62 changes: 54 additions & 8 deletions inferelator_velocity/utils/mcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def _molecular_split(count_data, random_seed=800, p=0.5):

rng = np.random.default_rng(random_seed)

if sps.issparse(count_data):
normalization_depth = np.median(
safer_sum(count_data, 1)
)

normalization_depth = np.median(
count_data.sum(axis=1).A1
)
if sps.issparse(count_data):

if sps.isspmatrix_csr(count_data):
mat_func = sps.csr_matrix
Expand All @@ -145,10 +145,6 @@ def _molecular_split(count_data, random_seed=800, p=0.5):

else:

normalization_depth = np.median(
count_data.sum(axis=1)
)

cv_data = np.zeros_like(count_data)

for i in range(count_data.shape[0]):
Expand Down Expand Up @@ -214,6 +210,26 @@ def _mse_rowwise(

return output

@numba.njit(parallel=False)
def _sum_columns(data, indices, n_col):

output = np.zeros(n_col, dtype=data.dtype)

for i in numba.prange(data.shape[0]):
output[indices[i]] += data[i]

return output

@numba.njit(parallel=False)
def _sum_rows(data, indptr):

output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)

for i in numba.prange(output.shape[0]):
output[i] = np.sum(data[indptr[i]:indptr[i + 1]])

return output

def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs):

if sps.issparse(x):
Expand Down Expand Up @@ -242,6 +258,27 @@ def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs):
**metric_kwargs
)

def safer_sum(sparse_array, axis=None):

if not sps.issparse(sparse_array):
return np.sum(sparse_array, axis=axis)

if axis is None:
return np.sum(sparse_array.data)

elif axis == 0:
return _sum_columns(
sparse_array.data,
sparse_array.indices,
sparse_array.shape[1]
)

elif axis == 1:
return _sum_rows(
sparse_array.data,
sparse_array.indptr
)

except ImportError:

def mcv_mse(x, pc, rotation, **metric_kwargs):
Expand All @@ -258,3 +295,12 @@ def mcv_mse(x, pc, rotation, **metric_kwargs):
metric='mse',
**metric_kwargs
)

def safer_sum(sparse_array, axis=None):

x = sparse_array.sum(axis)

if hasattr(x, 'A1'):
return x.A1
else:
return x

0 comments on commit 7c6e658

Please sign in to comment.