From 5139e64af74c30497df11b121af7bd8d6784cd1f Mon Sep 17 00:00:00 2001 From: dbdimitrov Date: Tue, 11 Jun 2024 12:05:42 +0200 Subject: [PATCH] replace squidpy with liana nn in MistyGeneric --- docs/source/notebooks/bivariate.ipynb | 39 +++++++++++++++++++-- liana/method/sp/_misty/_Misty.py | 3 +- liana/method/sp/_misty/_misty_constructs.py | 24 +++++-------- liana/tests/test_bivar.py | 2 +- liana/tests/test_misty.py | 15 +++----- liana/utils/query_bandwidth.py | 9 +++-- liana/utils/spatial_neighbors.py | 10 +++--- 7 files changed, 62 insertions(+), 40 deletions(-) diff --git a/docs/source/notebooks/bivariate.ipynb b/docs/source/notebooks/bivariate.ipynb index c80be02..d84db8f 100644 --- a/docs/source/notebooks/bivariate.ipynb +++ b/docs/source/notebooks/bivariate.ipynb @@ -624,8 +624,8 @@ "text": [ "Using `.X`!\n", "Using resource `consensus`.\n", - "100%|██████████| 100/100 [01:56<00:00, 1.17s/it]\n", - "100%|██████████| 100/100 [01:01<00:00, 1.64it/s]\n" + "100%|██████████| 100/100 [01:06<00:00, 1.51it/s]\n", + "100%|██████████| 100/100 [00:38<00:00, 2.60it/s]\n" ] } ], @@ -651,6 +651,39 @@ "Now that this is done, we can extract and explore the newly-created AnnData object that counts our local scores" ] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AnnData object with n_obs × n_vars = 4113 × 17703\n", + " obs: 'in_tissue', 'array_row', 'array_col', 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'mt_frac', 'celltype_niche', 'molecular_niche'\n", + " var: 'gene_ids', 'feature_types', 'genome', 'SYMBOL', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'mt', 'rps', 'mrp', 'rpl', 'duplicated'\n", + " uns: 'spatial', 'log1p', 
'celltype_niche_colors'\n", + " obsm: 'compositions', 'mt', 'spatial', 'local_scores'\n", + " layers: 'counts'\n", + " obsp: 'spatial_connectivities'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 10, @@ -2094,7 +2127,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.8.17" }, "orig_nbformat": 4 }, diff --git a/liana/method/sp/_misty/_Misty.py b/liana/method/sp/_misty/_Misty.py index 43389ee..429e917 100644 --- a/liana/method/sp/_misty/_Misty.py +++ b/liana/method/sp/_misty/_Misty.py @@ -23,7 +23,8 @@ def __init__(self, obs:(pd.DataFrame | None)=None, spatial_key:str=K.spatial_key, enforce_obs:bool=True, - **kwargs): + **kwargs + ): """ Construct a MistyData object from a dictionary of views (anndatas). diff --git a/liana/method/sp/_misty/_misty_constructs.py b/liana/method/sp/_misty/_misty_constructs.py index 9bb8181..5cc8e09 100644 --- a/liana/method/sp/_misty/_misty_constructs.py +++ b/liana/method/sp/_misty/_misty_constructs.py @@ -13,7 +13,6 @@ from liana.resource import select_resource from liana.method._pipe_utils import prep_check_adata from liana.method.sp._utils import _add_complexes_to_var -from liana._logging import _check_if_installed def _make_view(adata, nz_threshold=0.1, add_obs=False, use_raw=False, layer=None, connecitivity=None, spatial_key=None, verbose=False): @@ -57,8 +56,7 @@ def genericMistyData(intra, cutoff = 0.1, add_juxta=True, n_neighs = 6, - verbose=False, - **kwargs, + verbose=False ): """ @@ -96,14 +94,12 @@ def genericMistyData(intra, cutoff : `float`, optional (default: 0.1) The cutoff for the connectivity matrix. add_juxta : `bool`, optional (default: True) - Whether to add the juxtaview. 
The juxtaview is constructed using `squidpy.gr.spatial_neighbors`, - and should represent the direct spatial neighbors of each cell/spot. + Whether to add the juxtaview. The juxtaview is constructed using only the nearest neighbors. + A bandwidth of 5 times the bandwidth of the paraview is used to ensure that the nearest neighbors fall within the radius. n_neighs : `int`, optional (default: 6) The number of neighbors to consider when constructing the juxtaview. verbose : `bool`, optional (default: False) Whether to print progress. - **kwargs : `dict`, optional - Additional arguments to pass to `squidpy.gr.spatial_neighbors`. Returns ------- @@ -121,14 +117,12 @@ def genericMistyData(intra, extra = intra if add_juxta: - sq = _check_if_installed('squidpy') - neighbors, _ = sq.gr.spatial_neighbors(adata=extra, - copy=True, - spatial_key=spatial_key, - set_diag=set_diag, - n_neighs=n_neighs, - **kwargs - ) + neighbors = spatial_neighbors(extra, + bandwidth=bandwidth*5, + spatial_key=spatial_key, + max_neighbours=n_neighs, + set_diag=set_diag, + inplace=False) views['juxta'] = _make_view(adata=extra, nz_threshold=nz_threshold, use_raw=extra_use_raw, layer=extra_layer, diff --git a/liana/tests/test_bivar.py b/liana/tests/test_bivar.py index 01a19eb..a03a75b 100644 --- a/liana/tests/test_bivar.py +++ b/liana/tests/test_bivar.py @@ -225,7 +225,7 @@ def test_large_adata(): ) lrdata = adata.obsm['local_scores'] np.testing.assert_almost_equal(lrdata.X.mean(), 0.00048977, decimal=4) - np.testing.assert_almost_equal(lrdata.var['morans'].mean(), 0.00030397394, decimal=4) + np.testing.assert_almost_equal(lrdata.var['morans'].mean(), 0.00012773558, decimal=4) def test_wrong_interactions(): diff --git a/liana/tests/test_misty.py b/liana/tests/test_misty.py index 7936b05..6b1c264 100644 --- a/liana/tests/test_misty.py +++ b/liana/tests/test_misty.py @@ -19,7 +19,6 @@ def test_misty_para(): cutoff=0, add_juxta=False, set_diag=False, - seed=133 ) misty(model=RandomForestModel, 
bypass_intra=False, seed=42, n_estimators=11) assert np.isin(list(misty.uns.keys()), ['target_metrics', 'interactions']).all() @@ -37,9 +36,7 @@ def test_misty_bypass(): bandwidth=10, add_juxta=True, set_diag=True, - cutoff=0, - coord_type="generic", - delaunay=True) + cutoff=0) misty(model=RandomForestModel, alphas=1, bypass_intra=True, seed=42, n_estimators=11) assert np.isin(['juxta', 'para'], misty.uns['target_metrics'].columns).all() assert ~np.isin(['intra'], misty.uns['target_metrics'].columns).all() @@ -51,7 +48,7 @@ def test_misty_bypass(): assert interactions['importances'].sum().round(10) == 22.0 np.testing.assert_almost_equal(interactions[(interactions['target']=='ligC') & (interactions['predictor']=='ligA')]['importances'].values, - np.array([0.0444664, 0.0551506]), decimal=3) + np.array([0.095, 0.07]), decimal=3) def test_misty_groups(): @@ -60,8 +57,6 @@ def test_misty_groups(): add_juxta=True, set_diag=False, cutoff=0, - coord_type="generic", - delaunay=True ) misty(model=RandomForestModel, alphas=1, @@ -82,7 +77,7 @@ def test_misty_groups(): # assert that there are self interactions = var_n * var_n interactions = misty.uns['interactions'] self_interactions = interactions[(interactions['target']==interactions['predictor'])] - # 11 vars * 4 envs * 3 views = 132; NOTE: However, I drop NAs -> to be refactored... 
+ # 11 vars * 4 envs * 3 views = 132; NOTE: However, I drop NAs assert self_interactions.shape == (44, 5) assert self_interactions[self_interactions['view']=='intra']['importances'].isna().all() @@ -110,7 +105,7 @@ def test_linear_misty(): assert misty.uns['interactions'].shape == (330, 4) actual = misty.uns['interactions']['importances'].values.mean() - np.testing.assert_almost_equal(actual, 0.4941761900911731, decimal=3) + np.testing.assert_almost_equal(actual, 0.5135328101662447, decimal=3) def test_misty_mask(): @@ -126,7 +121,7 @@ def test_misty_mask(): np.testing.assert_almost_equal(misty.uns['target_metrics']['intra_R2'].mean(), 0.4248588250759459, decimal=3) assert misty.uns['interactions'].shape == (330, 4) - np.testing.assert_almost_equal(misty.uns['interactions']['importances'].sum(), 141.05332654128952, decimal=0) + np.testing.assert_almost_equal(misty.uns['interactions']['importances'].sum(), 149.30560405771703, decimal=0) def test_misty_custom(): diff --git a/liana/utils/query_bandwidth.py b/liana/utils/query_bandwidth.py index cd5f613..71df84a 100644 --- a/liana/utils/query_bandwidth.py +++ b/liana/utils/query_bandwidth.py @@ -1,12 +1,12 @@ import numpy as np from sklearn.neighbors import BallTree -from plotnine import ggplot, aes, geom_line, geom_point, theme_bw, xlab, ylab, scale_y_continuous +from plotnine import ggplot, aes, geom_line, geom_point, theme_bw, xlab, ylab from pandas import DataFrame def query_bandwidth(coordinates: np.ndarray, start: int = 0, end: int = 500, - interval_n:int = 50, + interval_n: int = 50, reference: np.ndarray = None ): """ @@ -49,13 +49,12 @@ def query_bandwidth(coordinates: np.ndarray, num_neighbors = tree.query_radius(_reference, r=max_distance, count_only=True) # calculate the average number of neighbors - avg_nn = np.mean(num_neighbors) - df.loc[n, 'neighbours'] = avg_nn + avg_nn = np.ceil(np.median(num_neighbors)) + df.loc[n, 'neighbours'] = avg_nn - 1 p = (ggplot(df, aes(x='bandwith', y='neighbours')) + 
geom_line() + geom_point() + - scale_y_continuous(breaks=range(start, end, interval_n)) + theme_bw(base_size=16) + xlab("Bandwidth") + ylab("Number of Neighbors") diff --git a/liana/utils/spatial_neighbors.py b/liana/utils/spatial_neighbors.py index f54ebd8..e6c300f 100644 --- a/liana/utils/spatial_neighbors.py +++ b/liana/utils/spatial_neighbors.py @@ -23,7 +23,7 @@ def _linear(distance_mtx, bandwidth): @d.dedent def spatial_neighbors(adata: AnnData, bandwidth=None, - cutoff=None, + cutoff=0.1, max_neighbours=None, kernel='gaussian', set_diag=False, @@ -103,9 +103,9 @@ def spatial_neighbors(adata: AnnData, if max_neighbours is None: max_neighbours = int(adata.shape[0] / 10) - tree = NearestNeighbors(n_neighbors=max_neighbours, - algorithm='ball_tree', - metric='euclidean').fit(_reference) + tree = NearestNeighbors(n_neighbors=max_neighbours + 1, # +1 to exclude self + algorithm='ball_tree', + metric='euclidean').fit(_reference) dist = tree.kneighbors_graph(coordinates, mode='distance') # prevent float overflow @@ -114,7 +114,7 @@ def spatial_neighbors(adata: AnnData, # define zone of indifference dist.data[dist.data < zoi] = np.inf - # NOTE: dist gets converted to a connectivity matrix + # NOTE: dist gets converted to a connectivity (proximity) matrix if kernel == 'gaussian': dist.data = _gaussian(dist.data, bandwidth) elif kernel == 'misty_rbf':