Skip to content

Commit

Permalink
refactor and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwoer committed Nov 25, 2024
1 parent c8cfb36 commit f6bb454
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions tests/unit/anndata/test_anndata_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for the AnnDataFactory class."""

from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest
Expand All @@ -8,29 +10,29 @@
from alphabase.psm_reader.keys import PsmDfCols


def test_initialization_with_missing_columns():
"""Test that an error is raised when the input DataFrame is missing required columns."""
psm_df = pd.DataFrame(
def _get_test_psm_df():
"""Return a test PSM DataFrame."""
return pd.DataFrame(
{
PsmDfCols.RAW_NAME: ["raw1", "raw2"],
PsmDfCols.PROTEINS: ["protein1", "protein2"],
PsmDfCols.RAW_NAME: ["raw1", "raw1", "raw2"],
PsmDfCols.PROTEINS: ["protein1", "protein2", "protein1"],
PsmDfCols.INTENSITY: [100, 200, 300],
}
)


def test_initialization_with_missing_columns():
"""Test that an error is raised when the input DataFrame is missing required columns."""
psm_df = _get_test_psm_df().drop(columns=[PsmDfCols.INTENSITY])

with pytest.raises(ValueError, match="Missing required columns: \['intensity'\]"):
# when
AnnDataFactory(psm_df)


def test_create_anndata_with_valid_dataframe():
"""Test that an AnnData object is created correctly from a valid input DataFrame."""
psm_df = pd.DataFrame(
{
PsmDfCols.RAW_NAME: ["raw1", "raw1", "raw2"],
PsmDfCols.PROTEINS: ["protein1", "protein2", "protein1"],
PsmDfCols.INTENSITY: [100, 200, 300],
}
)
psm_df = _get_test_psm_df()
factory = AnnDataFactory(psm_df)

# when
Expand All @@ -39,7 +41,9 @@ def test_create_anndata_with_valid_dataframe():
assert adata.shape == (2, 2)
assert adata.obs_names.tolist() == ["raw1", "raw2"]
assert adata.var_names.tolist() == ["protein1", "protein2"]
assert np.array_equal(adata.X, np.array([[100, 200], [300, np.nan]]))
assert np.array_equal(
adata.X, np.array([[100, 200], [300, np.nan]]), equal_nan=True
)


def test_create_anndata_with_missing_intensity_values():
Expand Down Expand Up @@ -97,3 +101,20 @@ def test_create_anndata_with_empty_dataframe():
adata = factory.create_anndata()

assert adata.shape == (0, 0)


@patch("alphabase.psm_reader.psm_reader.psm_reader_provider.get_reader")
def test_from_files(mock_reader):
mock_reader.return_value.load.return_value = _get_test_psm_df()

factory = AnnDataFactory.from_files(["file1", "file2"], reader_type="maxquant")

# when
adata = factory.create_anndata()

assert adata.shape == (2, 2)
assert adata.obs_names.tolist() == ["raw1", "raw2"]
assert adata.var_names.tolist() == ["protein1", "protein2"]
assert np.array_equal(
adata.X, np.array([[100, 200], [300, np.nan]]), equal_nan=True
)

0 comments on commit f6bb454

Please sign in to comment.