diff --git a/inferelator_velocity/utils/mcv.py b/inferelator_velocity/utils/mcv.py index ea3229f..f43e549 100644 --- a/inferelator_velocity/utils/mcv.py +++ b/inferelator_velocity/utils/mcv.py @@ -118,11 +118,11 @@ def _molecular_split(count_data, random_seed=800, p=0.5): rng = np.random.default_rng(random_seed) - if sps.issparse(count_data): + normalization_depth = np.median( + safer_sum(count_data, 1) + ) - normalization_depth = np.median( - count_data.sum(axis=1).A1 - ) + if sps.issparse(count_data): if sps.isspmatrix_csr(count_data): mat_func = sps.csr_matrix @@ -145,10 +145,6 @@ def _molecular_split(count_data, random_seed=800, p=0.5): else: - normalization_depth = np.median( - count_data.sum(axis=1) - ) - cv_data = np.zeros_like(count_data) for i in range(count_data.shape[0]): @@ -214,6 +210,26 @@ def _mse_rowwise( return output + @numba.njit(parallel=False) + def _sum_columns(data, indices, n_col): + + output = np.zeros(n_col, dtype=data.dtype) + + for i in numba.prange(data.shape[0]): + output[indices[i]] += data[i] + + return output + + @numba.njit(parallel=False) + def _sum_rows(data, indptr): + + output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype) + + for i in numba.prange(output.shape[0]): + output[i] = np.sum(data[indptr[i]:indptr[i + 1]]) + + return output + def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs): if sps.issparse(x): @@ -242,6 +258,27 @@ def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs): **metric_kwargs ) + def safer_sum(sparse_array, axis=None): + + if not sps.issparse(sparse_array): + return np.sum(sparse_array, axis=axis) + + if axis is None: + return np.sum(sparse_array.data) + + elif axis == 0: + return _sum_columns( + sparse_array.data, + sparse_array.indices, + sparse_array.shape[1] + ) + + elif axis == 1: + return _sum_rows( + sparse_array.data, + sparse_array.indptr + ) + except ImportError: def mcv_mse(x, pc, rotation, **metric_kwargs): @@ -258,3 +295,12 @@ def mcv_mse(x, pc, rotation, **metric_kwargs): metric='mse', **metric_kwargs ) + + def safer_sum(sparse_array, axis=None): + + x = sparse_array.sum(axis) + + if hasattr(x, 'A1'): + return x.A1 + else: + return x