diff --git a/tests/unit/anndata/test_anndata_factory.py b/tests/unit/anndata/test_anndata_factory.py index 24cf25b3..004cd067 100644 --- a/tests/unit/anndata/test_anndata_factory.py +++ b/tests/unit/anndata/test_anndata_factory.py @@ -1,5 +1,7 @@ """Unit tests for the AnnDataFactory class.""" +from unittest.mock import patch + import numpy as np import pandas as pd import pytest @@ -8,15 +10,21 @@ from alphabase.psm_reader.keys import PsmDfCols -def test_initialization_with_missing_columns(): - """Test that an error is raised when the input DataFrame is missing required columns.""" - psm_df = pd.DataFrame( +def _get_test_psm_df(): + """Return a test PSM DataFrame.""" + return pd.DataFrame( { - PsmDfCols.RAW_NAME: ["raw1", "raw2"], - PsmDfCols.PROTEINS: ["protein1", "protein2"], + PsmDfCols.RAW_NAME: ["raw1", "raw1", "raw2"], + PsmDfCols.PROTEINS: ["protein1", "protein2", "protein1"], + PsmDfCols.INTENSITY: [100, 200, 300], } ) + +def test_initialization_with_missing_columns(): + """Test that an error is raised when the input DataFrame is missing required columns.""" + psm_df = _get_test_psm_df().drop(columns=[PsmDfCols.INTENSITY]) + with pytest.raises(ValueError, match="Missing required columns: \['intensity'\]"): # when AnnDataFactory(psm_df) @@ -24,13 +32,7 @@ def test_initialization_with_missing_columns(): def test_create_anndata_with_valid_dataframe(): """Test that an AnnData object is created correctly from a valid input DataFrame.""" - psm_df = pd.DataFrame( - { - PsmDfCols.RAW_NAME: ["raw1", "raw1", "raw2"], - PsmDfCols.PROTEINS: ["protein1", "protein2", "protein1"], - PsmDfCols.INTENSITY: [100, 200, 300], - } - ) + psm_df = _get_test_psm_df() factory = AnnDataFactory(psm_df) # when @@ -39,7 +41,9 @@ def test_create_anndata_with_valid_dataframe(): assert adata.shape == (2, 2) assert adata.obs_names.tolist() == ["raw1", "raw2"] assert adata.var_names.tolist() == ["protein1", "protein2"] - assert np.array_equal(adata.X, np.array([[100, 200], [300, np.nan]])) + assert np.array_equal( + adata.X, np.array([[100, 200], [300, np.nan]]), equal_nan=True + ) def test_create_anndata_with_missing_intensity_values(): @@ -97,3 +101,20 @@ def test_create_anndata_with_empty_dataframe(): adata = factory.create_anndata() assert adata.shape == (0, 0) + + +@patch("alphabase.psm_reader.psm_reader.psm_reader_provider.get_reader") +def test_from_files(mock_reader): + mock_reader.return_value.load.return_value = _get_test_psm_df() + + factory = AnnDataFactory.from_files(["file1", "file2"], reader_type="maxquant") + + # when + adata = factory.create_anndata() + + assert adata.shape == (2, 2) + assert adata.obs_names.tolist() == ["raw1", "raw2"] + assert adata.var_names.tolist() == ["protein1", "protein2"] + assert np.array_equal( + adata.X, np.array([[100, 200], [300, np.nan]]), equal_nan=True + )