diff --git a/tests/presults.py b/tests/presults.py index cc69af66..903d68f3 100644 --- a/tests/presults.py +++ b/tests/presults.py @@ -2,11 +2,14 @@ import pandas as pd import palantir + def test_PResults(): # Create some dummy data pseudotime = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) entropy = None - branch_probs = pd.DataFrame({'branch1': [0.1, 0.2, 0.3, 0.4, 0.5], 'branch2': [0.5, 0.4, 0.3, 0.2, 0.1]}) + branch_probs = pd.DataFrame( + {"branch1": [0.1, 0.2, 0.3, 0.4, 0.5], "branch2": [0.5, 0.4, 0.3, 0.2, 0.1]} + ) waypoints = None # Initialize PResults object @@ -18,6 +21,7 @@ def test_PResults(): assert presults.waypoints is None assert np.array_equal(presults.branch_probs, branch_probs.values) + def test_gam_fit_predict(): # Create some dummy data x = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) @@ -28,8 +32,10 @@ def test_gam_fit_predict(): spline_order = 2 # Call the function - y_pred, stds = palantir.presults.gam_fit_predict(x, y, weights, pred_x, n_splines, spline_order) + y_pred, stds = palantir.presults.gam_fit_predict( + x, y, weights, pred_x, n_splines, spline_order + ) # Asserts to check the output assert isinstance(y_pred, np.ndarray) - assert isinstance(stds, np.ndarray) \ No newline at end of file + assert isinstance(stds, np.ndarray) diff --git a/tests/presults_compute_gene_trends.py b/tests/presults_compute_gene_trends.py index 845f2c99..5dbc2ef7 100644 --- a/tests/presults_compute_gene_trends.py +++ b/tests/presults_compute_gene_trends.py @@ -1,12 +1,12 @@ import pytest +import pandas as pd +import numpy as np +from anndata import AnnData import palantir + @pytest.fixture def mock_adata(): - import pandas as pd - import numpy as np - from anndata import AnnData - n_cells = 10 # Create mock data @@ -18,21 +18,41 @@ def mock_adata(): ), var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]), ) - + adata.obsm["branch_masks"] = pd.DataFrame( np.random.randint(2, size=(n_cells, 2)), columns=["branch_1", "branch_2"], index=adata.obs_names, + ).astype(bool) + + return adata + + +@pytest.fixture +def custom_mock_adata(): + n_cells = 10 + + # Create mock data + adata = AnnData( + X=np.random.rand(n_cells, 3), + obs=pd.DataFrame( + {"custom_time": np.random.rand(n_cells)}, + index=[f"cell_{i}" for i in range(n_cells)], + ), + var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]), ) + adata.obsm["custom_masks"] = pd.DataFrame( + np.random.randint(2, size=(n_cells, 2)), + columns=["branch_1", "branch_2"], + index=adata.obs_names, + ).astype(bool) + return adata + @pytest.fixture def mock_adata_old(): - import pandas as pd - import numpy as np - from anndata import AnnData - n_cells = 10 # Create mock data @@ -46,13 +66,16 @@ def mock_adata_old(): ) # Create mock branch_masks in obsm - adata.obsm["branch_masks"] = pd.DataFrame(np.random.randint(2, size=(n_cells, 2)) + adata.obsm["branch_masks"] = np.random.randint(2, size=(n_cells, 2)).astype(bool) adata.uns["branch_masks_columns"] = ["branch_1", "branch_2"] return adata - -@pytest.mark.parametrize("adata", [mock_adata, mock_adata_old]) -def test_compute_gene_trends(adata): + + +@pytest.mark.parametrize("adata_fixture", ["mock_adata", "mock_adata_old"]) +def test_compute_gene_trends(request, adata_fixture): + adata = request.getfixturevalue(adata_fixture) + # Call the function with default keys res = palantir.presults.compute_gene_trends(adata) @@ -65,9 +88,11 @@ def test_compute_gene_trends(adata): assert "gene_0" in res["branch_1"]["trends"].index assert adata.varm["gene_trends_branch_1"].shape == (3, 500) + +def test_compute_gene_trends_custom_anndata(custom_mock_adata): # Call the function with custom keys res = palantir.presults.compute_gene_trends( - adata, + custom_mock_adata, masks_key="custom_masks", pseudo_time_key="custom_time", gene_trend_key="custom_trends", @@ -80,5 +105,4 @@ def test_compute_gene_trends(adata): assert isinstance(res["branch_1"], dict) assert isinstance(res["branch_1"]["trends"], pd.DataFrame) assert "gene_0" in res["branch_1"]["trends"].index - assert adata.varm["custom_trends_branch_1"].shape == (3, 500) - + assert custom_mock_adata.varm["custom_trends_branch_1"].shape == (3, 500)