Skip to content

Commit

Permalink
fix compute_gene_trends tests
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Nov 28, 2023
1 parent 9909bde commit 9eff16b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
12 changes: 9 additions & 3 deletions tests/presults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import pandas as pd
import palantir


def test_PResults():
# Create some dummy data
pseudotime = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
entropy = None
branch_probs = pd.DataFrame({'branch1': [0.1, 0.2, 0.3, 0.4, 0.5], 'branch2': [0.5, 0.4, 0.3, 0.2, 0.1]})
branch_probs = pd.DataFrame(
{"branch1": [0.1, 0.2, 0.3, 0.4, 0.5], "branch2": [0.5, 0.4, 0.3, 0.2, 0.1]}
)
waypoints = None

# Initialize PResults object
Expand All @@ -18,6 +21,7 @@ def test_PResults():
assert presults.waypoints is None
assert np.array_equal(presults.branch_probs, branch_probs.values)


def test_gam_fit_predict():
# Create some dummy data
x = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
Expand All @@ -28,8 +32,10 @@ def test_gam_fit_predict():
spline_order = 2

# Call the function
y_pred, stds = palantir.presults.gam_fit_predict(x, y, weights, pred_x, n_splines, spline_order)
y_pred, stds = palantir.presults.gam_fit_predict(
x, y, weights, pred_x, n_splines, spline_order
)

# Asserts to check the output
assert isinstance(y_pred, np.ndarray)
assert isinstance(stds, np.ndarray)
assert isinstance(stds, np.ndarray)
56 changes: 40 additions & 16 deletions tests/presults_compute_gene_trends.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest
import pandas as pd
import numpy as np
from anndata import AnnData
import palantir


@pytest.fixture
def mock_adata():
import pandas as pd
import numpy as np
from anndata import AnnData

n_cells = 10

# Create mock data
Expand All @@ -18,21 +18,41 @@ def mock_adata():
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]),
)

adata.obsm["branch_masks"] = pd.DataFrame(
np.random.randint(2, size=(n_cells, 2)),
columns=["branch_1", "branch_2"],
index=adata.obs_names,
).astype(bool)

return adata


@pytest.fixture
def custom_mock_adata():
n_cells = 10

# Create mock data
adata = AnnData(
X=np.random.rand(n_cells, 3),
obs=pd.DataFrame(
{"custom_time": np.random.rand(n_cells)},
index=[f"cell_{i}" for i in range(n_cells)],
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]),
)

adata.obsm["custom_masks"] = pd.DataFrame(
np.random.randint(2, size=(n_cells, 2)),
columns=["branch_1", "branch_2"],
index=adata.obs_names,
).astype(bool)

return adata


@pytest.fixture
def mock_adata_old():
import pandas as pd
import numpy as np
from anndata import AnnData

n_cells = 10

# Create mock data
Expand All @@ -46,13 +66,16 @@ def mock_adata_old():
)

# Create mock branch_masks in obsm
adata.obsm["branch_masks"] = pd.DataFrame(np.random.randint(2, size=(n_cells, 2))
adata.obsm["branch_masks"] = np.random.randint(2, size=(n_cells, 2)).astype(bool)
adata.uns["branch_masks_columns"] = ["branch_1", "branch_2"]

return adata

@pytest.mark.parametrize("adata", [mock_adata, mock_adata_old])
def test_compute_gene_trends(adata):


@pytest.mark.parametrize("adata_fixture", ["mock_adata", "mock_adata_old"])
def test_compute_gene_trends(request, adata_fixture):
adata = request.getfixturevalue(adata_fixture)

# Call the function with default keys
res = palantir.presults.compute_gene_trends(adata)

Expand All @@ -65,9 +88,11 @@ def test_compute_gene_trends(adata):
assert "gene_0" in res["branch_1"]["trends"].index
assert adata.varm["gene_trends_branch_1"].shape == (3, 500)


def test_compute_gene_trends_custom_anndata(custom_mock_adata):
# Call the function with custom keys
res = palantir.presults.compute_gene_trends(
adata,
custom_mock_adata,
masks_key="custom_masks",
pseudo_time_key="custom_time",
gene_trend_key="custom_trends",
Expand All @@ -80,5 +105,4 @@ def test_compute_gene_trends(adata):
assert isinstance(res["branch_1"], dict)
assert isinstance(res["branch_1"]["trends"], pd.DataFrame)
assert "gene_0" in res["branch_1"]["trends"].index
assert adata.varm["custom_trends_branch_1"].shape == (3, 500)

assert custom_mock_adata.varm["custom_trends_branch_1"].shape == (3, 500)

0 comments on commit 9eff16b

Please sign in to comment.