diff --git a/inferelator_velocity/tests/test_mcv.py b/inferelator_velocity/tests/test_mcv.py
index 93ebfae..8e7eedb 100644
--- a/inferelator_velocity/tests/test_mcv.py
+++ b/inferelator_velocity/tests/test_mcv.py
@@ -1,14 +1,44 @@
 import unittest
 import numpy as np
+import numpy.testing as npt
 import scipy.sparse as sps
+import anndata as ad
 
-from inferelator_velocity.utils.mcv import mcv_pcs
+from inferelator_velocity.utils import TruncRobustScaler
+from inferelator_velocity.utils.mcv import (
+    mcv_pcs,
+    mcv_mse,
+    standardize_data
+)
+from inferelator_velocity.utils.misc import (
+    _normalize_for_pca
+)
 
 from ._stubs import (
     COUNTS
 )
 
 
+def _safe_sum(x, axis):
+
+    sums = x.sum(axis)
+
+    try:
+        sums = sums.A1
+    except AttributeError:
+        pass
+
+    return sums
+
+
+def _safe_dense(x):
+
+    try:
+        return x.A
+    except AttributeError:
+        return x
+
+
 class TestMCV(unittest.TestCase):
 
     def test_sparse_log(self):
@@ -29,3 +59,96 @@ def test_sparse_log_scale(self):
             ),
             0
         )
+
+
+class TestMCVMetrics(unittest.TestCase):
+
+    def testMSErow(self):
+        data = sps.csr_matrix(COUNTS)
+
+        mse = mcv_mse(
+            data,
+            data @ np.zeros((10, 10)),
+            np.eye(10),
+            by_row=True
+        )
+
+        npt.assert_almost_equal(
+            data.power(2).sum(axis=1).A1 / 10,
+            mse
+        )
+
+
+class TestMCVStandardization(unittest.TestCase):
+
+    tol = 6
+
+    def setUp(self):
+        super().setUp()
+        self.data = ad.AnnData(sps.csr_matrix(COUNTS))
+
+    def test_depth(self):
+
+        _normalize_for_pca(self.data, target_sum=100, log=False, scale=False)
+
+        rowsums = _safe_sum(self.data.X, 1)
+
+        npt.assert_almost_equal(
+            rowsums,
+            np.full_like(rowsums, 100.),
+            decimal=self.tol
+        )
+
+    def test_depth_log(self):
+
+        _normalize_for_pca(self.data, target_sum=100, log=True, scale=False)
+
+        rowsums = _safe_sum(self.data.X, 1)
+
+        npt.assert_almost_equal(
+            rowsums,
+            np.log1p(100 * COUNTS / np.sum(COUNTS, axis=1)[:, None]).sum(1),
+            decimal=self.tol
+        )
+
+    def test_depth_scale(self):
+
+        _normalize_for_pca(self.data, target_sum=100, log=False, scale=True)
+
+        npt.assert_almost_equal(
+            _safe_dense(self.data.X),
+            TruncRobustScaler(with_centering=False).fit_transform(
+                100 * COUNTS / np.sum(COUNTS, axis=1)[:, None]
+            ),
+            decimal=self.tol
+        )
+
+    def test_depth_log_scale(self):
+
+        _normalize_for_pca(self.data, target_sum=100, log=True, scale=True)
+
+        npt.assert_almost_equal(
+            _safe_dense(self.data.X),
+            TruncRobustScaler(with_centering=False).fit_transform(
+                np.log1p(100 * COUNTS / np.sum(COUNTS, axis=1)[:, None])
+            ),
+            decimal=self.tol
+        )
+
+
+class TestMCVStandardizationDense(TestMCVStandardization):
+
+    tol = 4
+
+    def setUp(self):
+        super().setUp()
+        self.data = ad.AnnData(COUNTS.copy())
+
+
+class TestMCVStandardizationCSC(TestMCVStandardization):
+
+    tol = 4
+
+    def setUp(self):
+        super().setUp()
+        self.data = ad.AnnData(sps.csc_matrix(COUNTS))
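[Reviewer note, not part of the patch] The expected value in testMSErow
follows from the setup: with pc = data @ 0 and rotation = I, the
reconstruction pc @ rotation is all zeros, so the row-wise MSE reduces to
the mean of each row's squared entries. A minimal numpy/scipy sketch of
that identity, using stand-in data rather than the COUNTS stub:

    import numpy as np
    import scipy.sparse as sps

    x = sps.random(20, 10, density=0.5, format='csr', random_state=0)

    # reconstruction is all zeros, exactly as in testMSErow
    residual = x.toarray() - np.zeros((20, 10))

    # row-wise MSE == sum of squares per row / n_cols
    expected = x.power(2).sum(axis=1).A1 / 10
    np.testing.assert_almost_equal((residual ** 2).mean(axis=1), expected)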
diff --git a/inferelator_velocity/utils/math.py b/inferelator_velocity/utils/math.py
index 929deaf..1d3fab9 100644
--- a/inferelator_velocity/utils/math.py
+++ b/inferelator_velocity/utils/math.py
@@ -1,5 +1,4 @@
 import numpy as np
-import numba
 import scipy.sparse as sps
 from .misc import (
     make_vector_2D,
@@ -140,7 +139,7 @@ def _calc_se(x, y, slope):
     return se_y / mse_x
 
 
-def _log_loss(x, y):
+def _log_loss(x, y, axis=1):
 
     if y is None:
         raise ValueError(
@@ -168,12 +167,12 @@ def _log_loss(x, y):
         x,
         np.log(y)
     )
-    err = err.sum(axis=1)
+    err = err.sum(axis=axis)
     err *= -1
 
     return err
 
 
-def _mse(x, y):
+def _mse(x, y, axis=1):
 
     if y is not None:
         ssr = x - y
@@ -188,17 +187,17 @@ def _mse(x, y):
     else:
         ssr **= 2
 
-    return ssr.sum(axis=1)
+    return ssr.sum(axis=axis)
 
 
-def _mae(x, y):
+def _mae(x, y, axis=1):
 
     if y is not None:
         ssr = x - y
     else:
         ssr = x
 
-    return ssr.sum(axis=1)
+    return ssr.sum(axis=axis)
 
 
 def pairwise_metric(
@@ -206,6 +205,7 @@ def pairwise_metric(
     x,
     y,
     metric='mse',
    by_row=False,
+    axis=1,
     **kwargs
 ):
     """
@@ -233,6 +233,7 @@ def pairwise_metric(
     loss = metric(
         x,
         y,
+        axis=axis,
         **kwargs
     )
 
@@ -241,7 +242,7 @@ def pairwise_metric(
     except AttributeError:
         pass
 
-    loss = loss / x.shape[1]
+    loss = loss / x.shape[axis]
 
     if by_row:
         return loss
@@ -343,77 +344,50 @@ def coefficient_of_variation(
     )
 
 
-@numba.njit(parallel=False)
-def _csr_row_divide(
-    a_data,
-    a_indptr,
-    row_vec
+def mcv_mse(
+    x,
+    pc,
+    rotation,
+    by_row=False,
+    axis=1,
+    **metric_kwargs
 ):
-    n_row = row_vec.shape[0]
-
-    for i in numba.prange(n_row):
-        a_data[a_indptr[i]:a_indptr[i + 1]] /= row_vec[i]
-
+    if sps.issparse(x):
 
-@numba.njit(parallel=False)
-def _mse_rowwise(
-    a_data,
-    a_indices,
-    a_indptr,
-    b_pcs,
-    b_rotation
-):
+        from .sparse_math import mcv_mse_sparse
 
-    n_row = b_pcs.shape[0]
+        return mcv_mse_sparse(
+            x,
+            pc,
+            rotation,
+            by_row=by_row,
+            axis=axis,
+            **metric_kwargs
+        )
 
-    output = np.zeros(n_row, dtype=float)
+    else:
 
-    for i in numba.prange(n_row):
+        return pairwise_metric(
+            x,
+            pc @ rotation,
+            metric='mse',
+            by_row=by_row,
+            axis=axis,
+            **metric_kwargs
+        )
 
-        _idx_a = a_indices[a_indptr[i]:a_indptr[i + 1]]
-        _nnz_a = _idx_a.shape[0]
 
-        row = b_pcs[i, :] @ b_rotation
+def array_sum(array, axis=None, squared=False):
 
-        if _nnz_a == 0:
-            continue
+    if not is_csr(array):
+        if squared and not sps.issparse(array):
+            _sums = (array ** 2).sum(axis=axis)
+        elif squared:
+            _sums = array.power(2).sum(axis=axis)
         else:
-
-            row[_idx_a] -= a_data[a_indptr[i]:a_indptr[i + 1]]
-
-        output[i] = np.mean(row ** 2)
-
-    return output
-
-
-@numba.njit(parallel=False)
-def _sum_columns(data, indices, n_col):
-
-    output = np.zeros(n_col, dtype=data.dtype)
-
-    for i in numba.prange(data.shape[0]):
-        output[indices[i]] += data[i]
-
-    return output
-
-
-@numba.njit(parallel=False)
-def _sum_rows(data, indptr):
-
-    output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)
-
-    for i in numba.prange(output.shape[0]):
-        output[i] = np.sum(data[indptr[i]:indptr[i + 1]])
-
-    return output
-
-
-def array_sum(array, axis=None):
-
-    if not is_csr(array):
-        _sums = array.sum(axis=axis)
+            _sums = array.sum(axis=axis)
 
         try:
             _sums = _sums.A1
         except AttributeError:
@@ -423,44 +397,11 @@ def array_sum(array, axis=None):
     if axis is None:
         return np.sum(array.data)
 
-    elif axis == 0:
-        return _sum_columns(
-            array.data,
-            array.indices,
-            array.shape[1]
-        )
-
-    elif axis == 1:
-        return _sum_rows(
-            array.data,
-            array.indptr
-        )
-
-
-def mcv_mse(x, pc, rotation, by_row=False, **metric_kwargs):
-
-    if sps.issparse(x):
-
-        y = _mse_rowwise(
-            x.data,
-            x.indices,
-            x.indptr,
-            np.ascontiguousarray(pc),
-            np.ascontiguousarray(rotation, dtype=pc.dtype)
-        )
-
-        if by_row:
-            return y
-
-        else:
-            return np.mean(y)
-    else:
+    else:
 
-        return pairwise_metric(
-            x,
-            pc @ rotation,
-            metric='mse',
-            by_row=by_row,
-            **metric_kwargs
+        from .sparse_math import sparse_sum
+
+        return sparse_sum(
+            array,
+            axis=axis,
+            squared=squared
         )
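[Reviewer note, not part of the patch] The new axis argument threads through
_log_loss, _mse, _mae, and pairwise_metric, and the final division by
x.shape[axis] turns the axis-wise sum into an axis-wise mean: axis=1 yields
one error per observation (row), axis=0 one error per feature (column).
A plain numpy sketch of the reduction, without calling the package:

    import numpy as np

    x = np.arange(12, dtype=float).reshape(3, 4)
    y = np.ones_like(x)
    sq_err = (x - y) ** 2

    row_mse = sq_err.sum(axis=1) / x.shape[1]  # axis=1: one value per row
    col_mse = sq_err.sum(axis=0) / x.shape[0]  # axis=0: one value per column

    # both reductions agree on the grand mean, so mcv_mse with by_row=False
    # returns the same scalar for either axis
    assert np.isclose(row_mse.mean(), col_mse.mean())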
diff --git a/inferelator_velocity/utils/mcv.py b/inferelator_velocity/utils/mcv.py
index 96a519e..f6cc18b 100644
--- a/inferelator_velocity/utils/mcv.py
+++ b/inferelator_velocity/utils/mcv.py
@@ -9,7 +9,8 @@
 )
 from inferelator_velocity.utils.math import (
     pairwise_metric,
-    mcv_mse
+    mcv_mse,
+    array_sum
 )
 
 
@@ -121,7 +122,7 @@ def _molecular_split(count_data, random_seed=800, p=0.5):
 
     if sps.issparse(count_data):
         normalization_depth = np.median(
-            count_data.sum(axis=1).A1
+            array_sum(count_data, axis=1)
         )
 
         if sps.isspmatrix_csr(count_data):
@@ -162,19 +163,55 @@ def _molecular_split(count_data, random_seed=800, p=0.5):
     return count_data, cv_data, normalization_depth
 
 
-def mcv_comp(x, pc, rotation, metric, **metric_kwargs):
+def mcv_comp(
+    x,
+    pc,
+    rotation,
+    metric,
+    calculate_r2=False,
+    column_tss=None,
+    **metric_kwargs
+):
 
     if metric != 'mse':
-        return pairwise_metric(
+        metric_arr = pairwise_metric(
             x,
             pc @ rotation,
             metric=metric,
             **metric_kwargs
         )
     else:
-        return mcv_mse(
+        metric_arr = mcv_mse(
+            x,
+            pc,
+            rotation,
+            **metric_kwargs
+        )
+
+    if calculate_r2:
+        if column_tss is None:
+            column_tss = array_sum(x, axis=0, squared=True)
+
+        r2_array = mcv_mse(
             x,
             pc,
             rotation,
+            axis=0,
             **metric_kwargs
         )
+
+        np.divide(
+            r2_array,
+            column_tss,
+            where=column_tss != 0,
+            out=r2_array
+        )
+
+        r2_array[column_tss == 0] = 0.
+        r2_array *= -1
+        r2_array += 1
+
+        return metric_arr, r2_array
+
+    else:
+        return metric_arr
diff --git a/inferelator_velocity/utils/misc.py b/inferelator_velocity/utils/misc.py
index 6a736db..fc74478 100644
--- a/inferelator_velocity/utils/misc.py
+++ b/inferelator_velocity/utils/misc.py
@@ -226,12 +226,15 @@ def fit(self, X, y=None):
         return self
 
 
-def _normalize_for_pca_log(
+def _normalize_for_pca(
     count_data,
-    target_sum=None
+    target_sum=None,
+    log=False,
+    scale=False
 ):
     """
     Depth normalize and log pseudocount
+    This operation is performed entirely in place
 
     :param count_data: Integer data
     :type count_data: ad.AnnData
@@ -239,34 +242,38 @@ def _normalize_for_pca_log(
     :rtype: np.ad.AnnData
     """
 
-    sc.pp.normalize_total(
-        count_data,
-        target_sum=target_sum
-    )
-    sc.pp.log1p(count_data)
-    return count_data
+    if is_csr(count_data.X):
+        from .sparse_math import sparse_normalize_total
+        sparse_normalize_total(
+            count_data.X,
+            target_sum=target_sum
+        )
+    else:
+        sc.pp.normalize_total(
+            count_data,
+            target_sum=target_sum
+        )
 
-
-def _normalize_for_pca_scale(
-    count_data,
-    target_sum=None
-):
-    """
-    Depth normalize and scale using truncated robust scaling
+    if log:
+        sc.pp.log1p(count_data)
 
-    :param count_data: Integer data
-    :type count_data: ad.AnnData
-    :return: Standardized data
-    :rtype: ad.AnnData
-    """
+    if scale:
+        scaler = TruncRobustScaler(with_centering=False)
+        scaler.fit(count_data.X)
+
+        if is_csr(count_data.X):
+            from .sparse_math import _csr_column_divide
+            _csr_column_divide(
+                count_data.X.data,
+                count_data.X.indices,
+                scaler.scale_
+            )
+        else:
+            count_data.X = scaler.transform(
+                count_data.X
+            )
 
-    sc.pp.normalize_total(
-        count_data,
-        target_sum=target_sum
-    )
-    count_data.X = TruncRobustScaler(with_centering=False).fit_transform(
-        count_data.X
-    )
     return count_data
 
 
@@ -277,29 +284,29 @@ def standardize_data(
 ):
 
     if method == 'log':
-        return _normalize_for_pca_log(
+        return _normalize_for_pca(
             count_data,
-            target_sum
+            target_sum,
+            log=True
         )
 
     elif method == 'scale':
-        return _normalize_for_pca_scale(
+        return _normalize_for_pca(
             count_data,
-            target_sum
+            target_sum,
+            scale=True
         )
 
     elif method == 'log_scale':
-        data = _normalize_for_pca_log(
+        return _normalize_for_pca(
             count_data,
-            target_sum
-        )
-        data.X = TruncRobustScaler(with_centering=False).fit_transform(
-            data.X
+            target_sum,
+            log=True,
+            scale=True
         )
-        return data
 
     elif method is None:
         return count_data
 
     else:
         raise ValueError(
-            f'method must be `log`, `scale`, or `log_scale`, '
+            f'method must be None, `log`, `scale`, or `log_scale`, '
             f'{method} provided'
         )
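[Reviewer note, not part of the patch] The calculate_r2 branch of mcv_comp
computes 1 - (column-wise error / column total sum of squares) entirely in
place, using the where= argument of np.divide so zero-variance columns are
never divided; note that under this ordering those columns come out of the
final sign flip reporting R² = 1. A small numpy sketch of the same
arithmetic with made-up values:

    import numpy as np

    mse = np.array([0.5, 0.0, 2.0])
    tss = np.array([2.0, 0.0, 4.0])  # second column has zero variance

    r2 = mse.copy()
    np.divide(r2, tss, where=tss != 0, out=r2)  # skip tss == 0 entries
    r2[tss == 0] = 0.                           # then zero them explicitly
    r2 *= -1
    r2 += 1                                     # r2 = 1 - mse / tss

    np.testing.assert_almost_equal(r2, [0.75, 1.0, 0.5])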
diff --git a/inferelator_velocity/utils/noise2self.py b/inferelator_velocity/utils/noise2self.py
index 43a412b..ce9e06d 100644
--- a/inferelator_velocity/utils/noise2self.py
+++ b/inferelator_velocity/utils/noise2self.py
@@ -15,8 +15,7 @@
 from .math import (
     dot,
     pairwise_metric,
-    array_sum,
-    _csr_row_divide
+    array_sum
 )
 from .misc import (
     vprint,
@@ -398,6 +397,8 @@ def _dist_to_row_stochastic(graph):
     # Somehow faster then element-wise \_o_/
 
     if is_csr(graph):
+        from .sparse_math import _csr_row_divide
+
         _csr_row_divide(
             graph.data,
             graph.indptr,
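[Reviewer note, not part of the patch] _csr_row_divide, now housed in
sparse_math, divides each row's stored values by a per-row constant without
densifying, which is what makes the neighbor graph row-stochastic here. An
equivalent numpy expression for the same operation on a CSR matrix:

    import numpy as np
    import scipy.sparse as sps

    g = sps.random(5, 5, density=0.4, format='csr', random_state=0)
    row_sums = np.asarray(g.sum(axis=1)).ravel()

    # np.repeat expands one divisor per stored value; empty rows contribute
    # zero repeats, so they never trigger a divide-by-zero
    g.data /= np.repeat(row_sums, np.diff(g.indptr))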
diff --git a/inferelator_velocity/utils/sparse_math.py b/inferelator_velocity/utils/sparse_math.py
new file mode 100644
index 0000000..9ada783
--- /dev/null
+++ b/inferelator_velocity/utils/sparse_math.py
@@ -0,0 +1,215 @@
+import numpy as np
+import scipy.sparse as sps
+import numba
+
+
+def mcv_mse_sparse(
+    x,
+    pc,
+    rotation,
+    by_row=False,
+    axis=1,
+    **metric_kwargs
+):
+
+    if axis == 1:
+        func = _mse_rowwise
+    elif axis == 0:
+        func = _mse_columnwise
+    else:
+        raise ValueError(
+            f'axis must be 0 or 1, {axis} provided'
+        )
+
+    y = func(
+        x.data,
+        x.indices,
+        x.indptr,
+        np.ascontiguousarray(pc),
+        np.ascontiguousarray(rotation, dtype=pc.dtype),
+        x.shape[1]
+    )
+
+    if by_row:
+        return y
+
+    else:
+        return np.mean(y)
+
+
+def sparse_sum(sparse_array, axis=None, squared=False):
+
+    if not sps.issparse(sparse_array):
+        raise ValueError("sparse_sum requires a sparse array")
+
+    if axis is None:
+        return np.sum(sparse_array.data)
+
+    elif axis == 0:
+        func = _sum_columns_squared if squared else _sum_columns
+        return func(
+            sparse_array.data,
+            sparse_array.indices,
+            sparse_array.shape[1]
+        )
+
+    elif axis == 1:
+        func = _sum_rows_squared if squared else _sum_rows
+        return func(
+            sparse_array.data,
+            sparse_array.indptr
+        )
+
+
+def sparse_normalize_total(sparse_array, target_sum=10_000):
+
+    if not is_csr(sparse_array):
+        raise ValueError("sparse_normalize_total requires a CSR array")
+
+    if sparse_array.data.dtype == np.int32:
+        dtype = np.float32
+    elif sparse_array.data.dtype == np.int64:
+        dtype = np.float64
+    else:
+        dtype = None
+
+    if dtype is not None:
+        float_view = sparse_array.data.view(dtype)
+        float_view[:] = sparse_array.data
+        sparse_array.data = float_view
+
+    n_counts = sparse_sum(sparse_array, axis=1)
+
+    if target_sum is None:
+        target_sum = np.median(n_counts)
+
+    _csr_row_divide(
+        sparse_array.data,
+        sparse_array.indptr,
+        n_counts / target_sum
+    )
+
+
+@numba.njit(parallel=False)
+def _mse_rowwise(
+    a_data,
+    a_indices,
+    a_indptr,
+    b_pcs,
+    b_rotation,
+    n_cols
+):
+
+    n_row = b_pcs.shape[0]
+
+    output = np.zeros(n_row, dtype=float)
+
+    for i in numba.prange(n_row):
+
+        _idx_a = a_indices[a_indptr[i]:a_indptr[i + 1]]
+        _nnz_a = _idx_a.shape[0]
+
+        row = b_pcs[i, :] @ b_rotation
+
+        if _nnz_a == 0:
+            pass
+        else:
+            row[_idx_a] -= a_data[a_indptr[i]:a_indptr[i + 1]]
+
+        output[i] = np.mean(row ** 2)
+
+    return output
+
+
+@numba.njit(parallel=False)
+def _mse_columnwise(
+    a_data,
+    a_indices,
+    a_indptr,
+    b_pcs,
+    b_rotation,
+    n_cols
+):
+
+    n_row = b_pcs.shape[0]
+    output = np.zeros(n_cols, dtype=float)
+
+    for i in numba.prange(n_row):
+
+        _idx_a = a_indices[a_indptr[i]:a_indptr[i + 1]]
+        _nnz_a = _idx_a.shape[0]
+
+        row = b_pcs[i, :] @ b_rotation
+
+        if _nnz_a == 0:
+            pass
+        else:
+            row[_idx_a] -= a_data[a_indptr[i]:a_indptr[i + 1]]
+
+        output += row ** 2
+
+    return output / n_row
+
+
+@numba.njit(parallel=False)
+def _sum_columns(data, indices, n_col):
+
+    output = np.zeros(n_col, dtype=data.dtype)
+
+    for i in numba.prange(data.shape[0]):
+        output[indices[i]] += data[i]
+
+    return output
+
+
+@numba.njit(parallel=False)
+def _sum_columns_squared(data, indices, n_col):
+
+    output = np.zeros(n_col, dtype=data.dtype)
+
+    for i in numba.prange(data.shape[0]):
+        output[indices[i]] += data[i] ** 2
+
+    return output
+
+
+@numba.njit(parallel=False)
+def _sum_rows(data, indptr):
+
+    output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)
+
+    for i in numba.prange(output.shape[0]):
+        output[i] = np.sum(data[indptr[i]:indptr[i + 1]])
+
+    return output
+
+
+@numba.njit(parallel=False)
+def _sum_rows_squared(data, indptr):
+
+    output = np.zeros(indptr.shape[0] - 1, dtype=data.dtype)
+
+    for i in numba.prange(output.shape[0]):
+        output[i] = np.sum(data[indptr[i]:indptr[i + 1]] ** 2)
+
+    return output
+
+
+@numba.njit(parallel=False)
+def _csr_row_divide(data, indptr, row_normalization_vec):
+
+    for i in numba.prange(indptr.shape[0] - 1):
+        data[indptr[i]:indptr[i + 1]] /= row_normalization_vec[i]
+
+
+@numba.njit(parallel=False)
+def _csr_column_divide(data, indices, column_normalization_vec):
+
+    for i, idx in enumerate(indices):
+        data[i] /= column_normalization_vec[idx]
+
+
+def is_csr(x):
+    return sps.isspmatrix_csr(x) or isinstance(x, sps.csr_array)
+
+
+def is_csc(x):
+    return sps.isspmatrix_csc(x) or isinstance(x, sps.csc_array)
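[Reviewer note, not part of the patch] The dtype handling in
sparse_normalize_total is worth calling out: int32 and int64 count buffers
are reinterpreted as same-width floats and the integer values are written
back through the view, converting the matrix in place without allocating a
second full-size data array (other dtypes pass through unchanged). A
minimal sketch of the same trick on a bare array:

    import numpy as np

    counts = np.array([1, 2, 3], dtype=np.int32)

    # same bytes reinterpreted as float32 (same itemsize), then the integer
    # values are cast back into the shared buffer
    float_view = counts.view(np.float32)
    float_view[:] = counts

    assert float_view.dtype == np.float32
    np.testing.assert_array_equal(float_view, [1.0, 2.0, 3.0])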