diff --git a/tests/plot.py b/tests/plot.py index e204b6bd..46b1b920 100644 --- a/tests/plot.py +++ b/tests/plot.py @@ -8,6 +8,7 @@ from matplotlib.markers import MarkerStyle from palantir.plot import ( + density_2d, plot_molecules_per_cell_and_gene, cell_types, highlight_cells_on_umap, @@ -89,13 +90,21 @@ def mock_cells(): def mock_gene_trends(): return { "Branch_1": { - "trends": pd.DataFrame({"0.0": [0.2, 0.3], "1.0": [0.4, 0.5]}, index=["Gene1", "Gene2"]), - "std": pd.DataFrame({"0.0": [0.02, 0.03], "1.0": [0.04, 0.05]}, index=["Gene1", "Gene2"]) + "trends": pd.DataFrame( + {"0.0": [0.2, 0.3], "1.0": [0.4, 0.5]}, index=["Gene1", "Gene2"] + ), + "std": pd.DataFrame( + {"0.0": [0.02, 0.03], "1.0": [0.04, 0.05]}, index=["Gene1", "Gene2"] + ), }, "Branch_2": { - "trends": pd.DataFrame({"0.0": [0.1, 0.2], "1.0": [0.2, 0.3]}, index=["Gene1", "Gene2"]), - "std": pd.DataFrame({"0.0": [0.01, 0.02], "1.0": [0.02, 0.03]}, index=["Gene1", "Gene2"]) - } + "trends": pd.DataFrame( + {"0.0": [0.1, 0.2], "1.0": [0.2, 0.3]}, index=["Gene1", "Gene2"] + ), + "std": pd.DataFrame( + {"0.0": [0.01, 0.02], "1.0": [0.02, 0.03]}, index=["Gene1", "Gene2"] + ), + }, } @@ -118,23 +127,35 @@ def mock_anndata(mock_umap_df): np.random.randint(2, size=(100, 3)), columns=["a", "b", "c"], index=mock_umap_df.index, - dtype=bool + dtype=bool, ) - for branch in ['a', 'b', 'c']: - adata.uns[f'gene_trends_{branch}_pseudotime'] = np.linspace(0, 1, 10) - adata.varm[f'gene_trends_{branch}'] = pd.DataFrame( + for branch in ["a", "b", "c"]: + adata.uns[f"gene_trends_{branch}_pseudotime"] = np.linspace(0, 1, 10) + adata.varm[f"gene_trends_{branch}"] = pd.DataFrame( np.random.rand(5, 10), index=adata.var_names, - columns=adata.uns[f'gene_trends_{branch}_pseudotime'], + columns=adata.uns[f"gene_trends_{branch}_pseudotime"], ) - adata.var['clusters'] = pd.Series( - ['A', 'A', 'B', 'B', 'B'], + adata.var["clusters"] = pd.Series( + ["A", "A", "B", "B", "B"], index=adata.var_names, ) - adata.var['gene_score'] = np.random.rand(5) + adata.var["gene_score"] = np.random.rand(5) return adata +def test_density_2d(): + # Test with random data + x = np.random.rand(100) + y = np.random.rand(100) + x_out, y_out, z_out = density_2d(x, y) + + # Validate output shape and types + assert x_out.shape == x.shape + assert y_out.shape == y.shape + assert z_out.shape == x.shape + + def test_plot_molecules_per_cell_and_gene(): # Create synthetic data data = np.random.rand(100, 20) @@ -188,43 +209,47 @@ def test_cell_types_n_cols(mock_tsne, mock_clusters): assert ncols == 1, "Number of columns should be 1" -# Test highlight_cells_on_umap -def test_highlight_cells_on_umap(mock_umap_df, mock_anndata): - # Define cells to highlight - highlight_cells_dict = {"cell_1": "A", "cell_3": "B"} +def test_highlight_cells_on_umap(mock_anndata, mock_umap_df): + # Test KeyError + with pytest.raises(KeyError): + highlight_cells_on_umap( + mock_anndata, ["cell_1"], embedding_basis="unknown_basis" + ) - # Test with DataFrame - fig, ax = highlight_cells_on_umap(mock_umap_df, highlight_cells_dict) - assert isinstance(fig, plt.Figure), "Output should include a matplotlib Figure" - assert ax.collections, "Should have scatter plots" + # Test TypeError for data + with pytest.raises(TypeError): + highlight_cells_on_umap("InvalidType", ["cell_1"]) - # Test with AnnData - fig, ax = highlight_cells_on_umap(mock_anndata, highlight_cells_dict) - assert isinstance(fig, plt.Figure), "Output should include a matplotlib Figure" - assert ax.collections, "Should have scatter plots" + # Test TypeError for cells + with pytest.raises(TypeError): + highlight_cells_on_umap(mock_anndata, 123) - # Test annotation_offset - fig, ax = highlight_cells_on_umap( - mock_umap_df, highlight_cells_dict, annotation_offset=0.05 - ) - assert isinstance(fig, plt.Figure), "Output should include a matplotlib Figure" + # Test normal use case with AnnData + fig, ax = highlight_cells_on_umap(mock_anndata, ["cell_1", "cell_2"]) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) - # Test size of highlighted points + # Test normal use case with DataFrame + fig, ax = highlight_cells_on_umap(mock_umap_df, ["cell_1", "cell_2"]) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) + + # Test with different types for cells parameter fig, ax = highlight_cells_on_umap( - mock_umap_df, highlight_cells_dict, s_highlighted=20 + mock_anndata, {"cell_1": "label1", "cell_2": "label2"} ) - assert np.any( - [p.get_sizes()[0] == 20 for p in ax.collections] - ), "Highlighted scatter point size should be 20" + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) - # Test errors - with pytest.raises(KeyError): - highlight_cells_on_umap( - mock_anndata, highlight_cells_dict, embedding_basis="X_invalid" - ) + mask = np.array([True if i < 2 else False for i in range(100)]) + fig, ax = highlight_cells_on_umap(mock_anndata, mask) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) - with pytest.raises(TypeError): - highlight_cells_on_umap(mock_anndata, 3) # Invalid 'cells' argument + cell_series = pd.Series({"cell_1": "label1", "cell_2": "label2"}) + fig, ax = highlight_cells_on_umap(mock_anndata, cell_series) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) # Test plot_tsne_by_cell_sizes @@ -253,7 +278,7 @@ def test_plot_tsne_by_cell_sizes(mock_data, mock_tsne): def test_plot_gene_expression(mock_gene_data, mock_tsne): genes = ["gene_0", "gene_1"] - fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes) + fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes, plot_scale=True) assert isinstance(fig, plt.Figure) @@ -365,6 +390,7 @@ def test_plot_terminal_state_probs_custom_args(mock_anndata, mock_cells): ax = fig.axes[0] # Assuming first subplot holds the first bar plot assert ax.patches[0].get_linewidth() == 2.0 + # Test if the function uses the correct keys and raises appropriate errors def test_plot_branch_selection_keys(mock_anndata): # This will depend on how your mock_anndata is structured @@ -377,17 +403,22 @@ def test_plot_branch_selection_keys(mock_anndata): with pytest.raises(KeyError): plot_branch_selection(mock_anndata, embedding_basis="invalid_basis") + # Test the scatter custom arguments def test_plot_branch_selection_custom_args(mock_anndata): - fig = plot_branch_selection(mock_anndata, marker='x', alpha=0.5) - ax1, ax2 = fig.axes[0], fig.axes[1] # Assuming the first two axes correspond to the first fate - + fig = plot_branch_selection(mock_anndata, marker="x", alpha=0.5) + ax1, ax2 = ( + fig.axes[0], + fig.axes[1], + ) # Assuming the first two axes correspond to the first fate + # Extract the scatter plots, assuming that the plot with custom markers is the last one scatter1, scatter2 = ax1.collections[-1], ax2.collections[-1] - + alpha1 = scatter1.get_alpha() assert alpha1 == 0.5 + # Test 1: Basic functionality def test_plot_gene_trends_legacy_basic(mock_gene_trends): fig = plot_gene_trends_legacy(mock_gene_trends) @@ -396,6 +427,7 @@ def test_plot_gene_trends_legacy_basic(mock_gene_trends): assert len(axes) == 2 # Perform additional checks on axes content if needed + # Test 2: Custom gene list def test_plot_gene_trends_legacy_custom_genes(mock_gene_trends): fig = plot_gene_trends_legacy(mock_gene_trends, genes=["Gene1"]) @@ -405,6 +437,7 @@ def test_plot_gene_trends_legacy_custom_genes(mock_gene_trends): # Check if the title of the subplot matches the custom gene assert axes[0].get_title() == "Gene1" + # Test 3: Color consistency def test_plot_gene_trends_legacy_color_consistency(mock_gene_trends): fig = plot_gene_trends_legacy(mock_gene_trends) @@ -414,18 +447,21 @@ def test_plot_gene_trends_legacy_color_consistency(mock_gene_trends): # Check if the colors are consistent across different genes assert colors_1 == colors_2 + # Test 1: Basic Functionality with AnnData def test_plot_gene_trends_basic_anndata(mock_anndata): fig = plot_gene_trends(mock_anndata) axes = fig.axes assert len(axes) == mock_anndata.n_vars + # Test 2: Basic Functionality with Dictionary def test_plot_gene_trends_basic_dict(mock_gene_trends): fig = plot_gene_trends(mock_gene_trends) axes = fig.axes assert len(axes) == 2 # Mock data contains 2 genes + # Test 3: Custom Genes def test_plot_gene_trends_custom_genes(mock_anndata): fig = plot_gene_trends(mock_anndata, genes=["gene_1"]) @@ -433,145 +469,183 @@ def test_plot_gene_trends_custom_genes(mock_anndata): assert len(axes) == 1 assert axes[0].get_title() == "gene_1" + # Test 4: Custom Branch Names def test_plot_gene_trends_custom_branch_names(mock_anndata): fig = plot_gene_trends(mock_anndata, branch_names=["a", "b"]) axes = fig.axes assert len(axes) == mock_anndata.n_vars + # Test 5: Error Handling - Invalid Data Type def test_plot_gene_trends_invalid_data_type(): with pytest.raises(ValueError): plot_gene_trends("invalid_data_type") + # Test 6: Error Handling - Missing Key def test_plot_gene_trends_missing_key(mock_anndata): with pytest.raises(KeyError): - plot_gene_trends(mock_anndata, gene_trend_key="missing_key", branch_names="missing_branch") + plot_gene_trends( + mock_anndata, gene_trend_key="missing_key", branch_names="missing_branch" + ) + @pytest.mark.parametrize("wrong_type", [123, True, 1.23, "unknown_key"]) def test_plot_stats_key_errors(mock_anndata, wrong_type): with pytest.raises(KeyError): plot_stats(mock_anndata, x=wrong_type, y="palantir_pseudotime") + def test_plot_stats_basic(mock_anndata): fig, ax = plot_stats(mock_anndata, x="palantir_pseudotime", y="palantir_entropy") assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + def test_plot_stats_optional_parameters(mock_anndata): - fig, ax = plot_stats(mock_anndata, x="palantir_pseudotime", y="palantir_entropy", color='palantir_entropy') + fig, ax = plot_stats( + mock_anndata, + x="palantir_pseudotime", + y="palantir_entropy", + color="palantir_entropy", + ) + def test_plot_stats_masking(mock_anndata): # Create a condition here that you want to mask - mask_condition = mock_anndata.obs['palantir_pseudotime'] > 0.5 - mock_anndata.obsm['branch_masks'] = mask_condition - fig, ax = plot_stats(mock_anndata, x="palantir_pseudotime", y="palantir_entropy", masks_key='branch_masks') + mask_condition = mock_anndata.obs["palantir_pseudotime"] > 0.5 + mock_anndata.obsm["branch_masks"] = mask_condition + fig, ax = plot_stats( + mock_anndata, + x="palantir_pseudotime", + y="palantir_entropy", + masks_key="branch_masks", + ) + @pytest.mark.parametrize( "branch_name, position, pseudo_time_key, should_fail", - [("a", "gene_1", "palantir_pseudotime", False), - (123, "gene_1", "palantir_pseudotime", True), - ("b", "gene_1", 123, True)] + [ + ("a", "gene_1", "palantir_pseudotime", False), + (123, "gene_1", "palantir_pseudotime", True), + ("b", "gene_1", 123, True), + ], ) -def test_plot_branch_input_validation(mock_anndata, branch_name, position, pseudo_time_key, should_fail): +def test_plot_branch_input_validation( + mock_anndata, branch_name, position, pseudo_time_key, should_fail +): if should_fail: with pytest.raises((TypeError, ValueError)): - plot_branch(mock_anndata, branch_name, position, pseudo_time_key=pseudo_time_key) + plot_branch( + mock_anndata, branch_name, position, pseudo_time_key=pseudo_time_key + ) else: - plot_branch(mock_anndata, branch_name, position, pseudo_time_key=pseudo_time_key) + plot_branch( + mock_anndata, branch_name, position, pseudo_time_key=pseudo_time_key + ) plt.close() + def test_plot_branch_functionality(mock_anndata): fig, ax = plot_branch(mock_anndata, "a", "gene_1") assert ax.get_xlabel() == "Pseudotime" + def test_plot_trend_type_validation(mock_anndata): with pytest.raises(TypeError): plot_trend("string_instead_of_anndata", "a", "gene_1") with pytest.raises(TypeError): plot_trend(mock_anndata, 123, "gene_1") + def test_plot_trend_value_validation(mock_anndata): with pytest.raises((ValueError, KeyError)): plot_trend(mock_anndata, "nonexistent_branch", "gene_1") + def test_plot_trend_plotting(mock_anndata): fig, ax = plot_trend(mock_anndata, "a", "gene_1") assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + def test_plot_gene_trend_heatmaps(mock_anndata): - fig = plot_gene_trend_heatmaps(mock_anndata, genes=['gene_1', 'gene_2'], scaling="z-score") - + fig = plot_gene_trend_heatmaps( + mock_anndata, genes=["gene_1", "gene_2"], scaling="z-score" + ) + # Test returned type assert isinstance(fig, plt.Figure) - + # Test number of subplots (should be same as number of branches) - assert len(fig.axes) == len(mock_anndata.obsm['branch_masks'].columns) * 2 - + assert len(fig.axes) == len(mock_anndata.obsm["branch_masks"].columns) * 2 + plt.close(fig) + def test_plot_gene_trend_clusters(mock_anndata): # Test with AnnData object fig = plot_gene_trend_clusters(mock_anndata, branch_name="a", clusters="clusters") assert isinstance(fig, plt.Figure) - + # Verify number of subplots - unique_clusters = mock_anndata.var['clusters'].unique() + unique_clusters = mock_anndata.var["clusters"].unique() expected_subplots = len(unique_clusters) assert len(fig.axes) == expected_subplots - + # Test DataFrame input - trends_df = mock_anndata.varm['gene_trends_a'] - clusters_series = mock_anndata.var['clusters'] + trends_df = mock_anndata.varm["gene_trends_a"] + clusters_series = mock_anndata.var["clusters"] fig_df = plot_gene_trend_clusters(trends_df, clusters=clusters_series) - + assert isinstance(fig_df, plt.Figure) assert len(fig_df.axes) == expected_subplots - + plt.close(fig) plt.close(fig_df) + def test_gene_score_histogram(mock_anndata): # Test with minimum required parameters - fig = gene_score_histogram(mock_anndata, 'gene_score') + fig = gene_score_histogram(mock_anndata, "gene_score") assert isinstance(fig, plt.Figure) plt.close(fig) - + # Test with optional parameters fig = gene_score_histogram( mock_anndata, - 'gene_score', - genes=['gene_0', 'gene_1'], + "gene_score", + genes=["gene_0", "gene_1"], bins=50, quantile=0.9, ) assert isinstance(fig, plt.Figure) plt.close(fig) - + # Test with None quantile fig = gene_score_histogram( mock_anndata, - 'gene_score', + "gene_score", quantile=None, ) assert isinstance(fig, plt.Figure) plt.close(fig) + def test_gene_score_histogram_errors(mock_anndata): # Test with invalid AnnData with pytest.raises(ValueError): - gene_score_histogram(None, 'gene_score') - + gene_score_histogram(None, "gene_score") + # Test with invalid score_key with pytest.raises(ValueError): - gene_score_histogram(mock_anndata, 'invalid_key') - + gene_score_histogram(mock_anndata, "invalid_key") + # Test with invalid gene with pytest.raises(ValueError): - gene_score_histogram(mock_anndata, 'gene_score', genes=['invalid_gene']) - + gene_score_histogram(mock_anndata, "gene_score", genes=["invalid_gene"]) + # Test with invalid quantile with pytest.raises(ValueError): - gene_score_histogram(mock_anndata, 'gene_score', quantile=1.5) + gene_score_histogram(mock_anndata, "gene_score", quantile=1.5)