diff --git a/alphastats/DataSet.py b/alphastats/DataSet.py index 271ee1d7..cb0b4d65 100644 --- a/alphastats/DataSet.py +++ b/alphastats/DataSet.py @@ -64,19 +64,20 @@ def __init__( """ self._check_loader(loader=loader) + self._data_harmonizer = DataHarmonizer(loader, sample_column) + # fill data from loader - self.rawinput: pd.DataFrame = DataHarmonizer(loader).get_harmonized_rawinput( + self.rawinput: pd.DataFrame = self._data_harmonizer.get_harmonized_rawinput( loader.rawinput ) self.filter_columns: List[str] = loader.filter_columns - self.software: str = loader.software - self._intensity_column: Union[str, list] = ( loader._extract_sample_names( - metadata=self.metadata, sample_column=self.sample + metadata=self.metadata, sample_column=sample_column ) - if loader == "Generic" + if loader + == "Generic" # TODO is this ever the case? not rather instanceof(loader, GenericLoader)? else loader.intensity_column ) @@ -86,14 +87,13 @@ def __init__( rawinput=self.rawinput, intensity_column=self._intensity_column, metadata_path_or_df=metadata_path_or_df, - sample_column=sample_column, + data_harmonizer=self._data_harmonizer, ) - rawmat, mat, metadata, sample, preprocessing_info = self._get_init_dataset() + rawmat, mat, metadata, preprocessing_info = self._get_init_dataset() self.rawmat: pd.DataFrame = rawmat self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info self._gene_name_to_protein_id_map = ( @@ -115,11 +115,11 @@ def __init__( def _get_init_dataset( self, - ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, str, Dict]: + ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, Dict]: """Get the initial data structure for the DataSet.""" rawmat, mat = self._dataset_factory.create_matrix_from_rawinput() - metadata, sample = self._dataset_factory.create_metadata(mat) + metadata = self._dataset_factory.create_metadata(mat) preprocessing_info = Preprocess.init_preprocessing_info( 
num_samples=mat.shape[0], @@ -128,7 +128,7 @@ def _get_init_dataset( filter_columns=self.filter_columns, ) - return rawmat, mat, metadata, sample, preprocessing_info + return rawmat, mat, metadata, preprocessing_info def _check_loader(self, loader): """Checks if the Loader is from class AlphaPeptLoader, MaxQuantLoader, DIANNLoader, FragPipeLoader @@ -157,7 +157,6 @@ def _get_preprocess(self) -> Preprocess: return Preprocess( self.filter_columns, self.rawinput, - self.sample, self.metadata, self.preprocessing_info, self.mat, @@ -194,7 +193,6 @@ def reset_preprocessing(self): self.rawmat, self.mat, self.metadata, - self.sample, self.preprocessing_info, ) = self._get_init_dataset() @@ -207,7 +205,6 @@ def _get_statistics(self) -> Statistics: return Statistics( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, ) @@ -232,8 +229,8 @@ def diff_expression_analysis( def tukey_test(self, protein_id: str, group: str) -> pd.DataFrame: """A wrapper for tukey_test.tukey_test(), see documentation there.""" - df = self.mat[[protein_id]].reset_index().rename(columns={"index": self.sample}) - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = self.mat[[protein_id]].reset_index().rename(columns={"index": Cols.SAMPLE}) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) return tukey_test( df, @@ -265,7 +262,6 @@ def plot_pca(self, group: Optional[str] = None, circle: bool = False): dimensionality_reduction = DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, circle=circle, @@ -293,7 +289,6 @@ def plot_tsne( dimensionality_reduction = DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, method="tsne", @@ -317,7 +312,6 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False): dimensionality_reduction = 
DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, method="umap", @@ -398,7 +392,6 @@ def plot_volcano( mat=self.mat, rawinput=self.rawinput, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group1=group1, group2=group2, @@ -482,7 +475,6 @@ def plot_intensity( intensity_plot = IntensityPlot( mat=self.mat, metadata=self.metadata, - sample=self.sample, intensity_column=self._intensity_column, preprocessing_info=self.preprocessing_info, protein_id=protein_id, @@ -519,7 +511,6 @@ def plot_clustermap( clustermap = ClusterMap( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, label_bar=label_bar, only_significant=only_significant, @@ -542,7 +533,6 @@ def _get_plot(self) -> Plot: self.mat, self.rawmat, self.metadata, - self.sample, self.preprocessing_info, ) diff --git a/alphastats/DataSet_Plot.py b/alphastats/DataSet_Plot.py index b132f2a7..9c16a027 100644 --- a/alphastats/DataSet_Plot.py +++ b/alphastats/DataSet_Plot.py @@ -7,6 +7,7 @@ import scipy import seaborn as sns +from alphastats.keys import Cols from alphastats.plots.PlotUtils import PlotUtils from alphastats.utils import check_for_missing_values @@ -50,13 +51,11 @@ def __init__( mat: pd.DataFrame, rawmat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, ): self.mat: pd.DataFrame = mat self.rawmat: pd.DataFrame = rawmat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info def plot_correlation_matrix(self, method: str = "pearson"): # TODO unused @@ -95,15 +94,15 @@ def plot_sampledistribution( # create long df matrix = self.mat if not use_raw else self.rawmat df = matrix.unstack().reset_index() - df.rename(columns={"level_1": self.sample, 0: "Intensity"}, inplace=True) + df.rename(columns={"level_1": Cols.SAMPLE, 0: "Intensity"}, 
inplace=True) if color is not None: - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) if method == "violin": fig = px.violin( df, - x=self.sample, + x=Cols.SAMPLE, y="Intensity", color=color, template="simple_white+alphastats_colors", @@ -112,7 +111,7 @@ def plot_sampledistribution( elif method == "box": fig = px.box( df, - x=self.sample, + x=Cols.SAMPLE, y="Intensity", color=color, template="simple_white+alphastats_colors", diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index 42d10724..1b7338ed 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -46,7 +46,6 @@ def __init__( self, filter_columns: List[str], rawinput: pd.DataFrame, - sample: str, metadata: pd.DataFrame, preprocessing_info: Dict, mat: pd.DataFrame, @@ -54,7 +53,6 @@ def __init__( self.filter_columns = filter_columns self.rawinput = rawinput - self.sample = sample self.metadata = metadata self.preprocessing_info = preprocessing_info @@ -88,17 +86,17 @@ def init_preprocessing_info( def _remove_samples(self, sample_list: list): # exclude samples for analysis self.mat = self.mat.drop(sample_list) - self.metadata = self.metadata[~self.metadata[self.sample].isin(sample_list)] + self.metadata = self.metadata[~self.metadata[Cols.SAMPLE].isin(sample_list)] @staticmethod def subset( - mat: pd.DataFrame, metadata: pd.DataFrame, sample: str, preprocessing_info: Dict + mat: pd.DataFrame, metadata: pd.DataFrame, preprocessing_info: Dict ) -> pd.DataFrame: """Filter matrix so only samples that are described in metadata are also found in matrix.""" preprocessing_info.update( {PreprocessingStateKeys.NUM_SAMPLES: metadata.shape[0]} ) - return mat[mat.index.isin(metadata[sample].tolist())] + return mat[mat.index.isin(metadata[Cols.SAMPLE].tolist())] def _remove_na_values(self, cut_off): if ( @@ -350,7 +348,7 @@ def batch_correction(self, batch: str) -> pd.DataFrame: from 
combat.pycombat import pycombat data = self.mat.transpose() - series_of_batches = self.metadata.set_index(self.sample).reindex( + series_of_batches = self.metadata.set_index(Cols.SAMPLE).reindex( data.columns.to_list() )[batch] @@ -418,6 +416,8 @@ def preprocess( ]: raise ValueError(f"Invalid keyword argument: {k}") + # TODO this is a stateful method as we change self.mat, self.metadata and self.processing_info + # refactor such that it does not change self.mat etc but just return the latest result if remove_contaminations: self._filter() @@ -425,9 +425,7 @@ def preprocess( self._remove_samples(sample_list=remove_samples) if subset: - self.mat = self.subset( - self.mat, self.metadata, self.sample, self.preprocessing_info - ) + self.mat = self.subset(self.mat, self.metadata, self.preprocessing_info) if data_completeness > 0: self._remove_na_values(cut_off=data_completeness) diff --git a/alphastats/DataSet_Statistics.py b/alphastats/DataSet_Statistics.py index e1300e52..e292a99d 100644 --- a/alphastats/DataSet_Statistics.py +++ b/alphastats/DataSet_Statistics.py @@ -4,6 +4,7 @@ import pandas as pd import pingouin +from alphastats.keys import Cols from alphastats.statistics.Anova import Anova from alphastats.statistics.DifferentialExpressionAnalysis import ( DifferentialExpressionAnalysis, @@ -17,12 +18,10 @@ def __init__( *, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, ): self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info @ignore_warning(RuntimeWarning) @@ -60,7 +59,6 @@ def diff_expression_analysis( df = DifferentialExpressionAnalysis( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group1=group1, group2=group2, @@ -89,7 +87,6 @@ def anova(self, column: str, protein_ids="all", tukey: bool = True) -> pd.DataFr return Anova( mat=self.mat, metadata=self.metadata, - 
sample=self.sample, column=column, protein_ids=protein_ids, tukey=tukey, @@ -119,8 +116,8 @@ def ancova( * ``'p-unc'``: Uncorrected p-values * ``'np2'``: Partial eta-squared """ - df = self.mat[protein_id].reset_index().rename(columns={"index": self.sample}) - df = self.metadata.merge(df, how="inner", on=[self.sample]) + df = self.mat[protein_id].reset_index().rename(columns={"index": Cols.SAMPLE}) + df = self.metadata.merge(df, how="inner", on=[Cols.SAMPLE]) ancova_df = pingouin.ancova(df, dv=protein_id, covar=covar, between=between) return ancova_df diff --git a/alphastats/dataset_factory.py b/alphastats/dataset_factory.py index 555bdcfe..c7c75d48 100644 --- a/alphastats/dataset_factory.py +++ b/alphastats/dataset_factory.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +from alphastats.dataset_harmonizer import DataHarmonizer from alphastats.keys import Cols @@ -17,12 +18,12 @@ def __init__( rawinput: pd.DataFrame, intensity_column: Union[List[str], str], metadata_path_or_df: Union[str, pd.DataFrame], - sample_column: str, + data_harmonizer: DataHarmonizer, ): self.rawinput: pd.DataFrame = rawinput - self.sample_column: str = sample_column self.intensity_column: Union[List[str], str] = intensity_column self.metadata_path_or_df: Union[str, pd.DataFrame] = metadata_path_or_df + self._data_harmonizer = data_harmonizer def create_matrix_from_rawinput(self) -> Tuple[pd.DataFrame, pd.DataFrame]: """Creates a matrix: features (Proteins) as columns, samples as rows.""" @@ -58,28 +59,27 @@ def _check_matrix_values(mat: pd.DataFrame) -> None: if np.isinf(mat).values.sum() > 0: logging.warning("Data contains infinite values.") - def create_metadata(self, mat: pd.DataFrame) -> Tuple[pd.DataFrame, str]: + def create_metadata(self, mat: pd.DataFrame) -> pd.DataFrame: """Create metadata DataFrame from metadata file or DataFrame.""" if self.metadata_path_or_df is not None: - sample = self.sample_column metadata = self._load_metadata(file_path=self.metadata_path_or_df) 
- metadata = self._remove_missing_samples_from_metadata(mat, metadata, sample) + metadata = self._data_harmonizer.get_harmonized_metadata(metadata) + metadata = self._remove_missing_samples_from_metadata(mat, metadata) else: - sample = "sample" - metadata = pd.DataFrame({"sample": list(mat.index)}) + metadata = pd.DataFrame({Cols.SAMPLE: list(mat.index)}) - return metadata, sample + return metadata def _remove_missing_samples_from_metadata( - self, mat: pd.DataFrame, metadata: pd.DataFrame, sample + self, mat: pd.DataFrame, metadata: pd.DataFrame ) -> pd.DataFrame: """Remove samples from metadata that are not in the protein data.""" samples_matrix = mat.index.to_list() - samples_metadata = metadata[sample].to_list() + samples_metadata = metadata[Cols.SAMPLE].to_list() misc_samples = list(set(samples_metadata) - set(samples_matrix)) if len(misc_samples) > 0: - metadata = metadata[~metadata[sample].isin(misc_samples)] + metadata = metadata[~metadata[Cols.SAMPLE].isin(misc_samples)] logging.warning( f"{misc_samples} are not described in the protein data and" "are removed from the metadata." 
@@ -116,11 +116,6 @@ def _load_metadata( ) return None - if df is not None and self.sample_column not in df.columns: - logging.error( - f"sample_column: {self.sample_column} not found in {file_path}" - ) - # check whether sample labeling matches protein data # warnings.warn("WARNING: Sample names do not match sample labelling in protein data") df.columns = df.columns.astype(str) diff --git a/alphastats/dataset_harmonizer.py b/alphastats/dataset_harmonizer.py index 8bdc5ee8..c63af08b 100644 --- a/alphastats/dataset_harmonizer.py +++ b/alphastats/dataset_harmonizer.py @@ -1,5 +1,7 @@ """Harmonize the input data to a common format.""" +from typing import Dict, Optional + import pandas as pd from alphastats import BaseLoader @@ -9,21 +11,49 @@ class DataHarmonizer: """Harmonize input data to a common format.""" - def __init__(self, loader: BaseLoader): - self._rename_dict = { + def __init__(self, loader: BaseLoader, sample_column: Optional[str] = None): + _rawinput_rename_dict = { loader.index_column: Cols.INDEX, - loader.gene_names_column: Cols.GENE_NAMES, } + if loader.gene_names_column is not None: + _rawinput_rename_dict[loader.gene_names_column] = Cols.GENE_NAMES + + self._rawinput_rename_dict = _rawinput_rename_dict + + self._metadata_rename_dict = ( + { + sample_column: Cols.SAMPLE, + } + if sample_column is not None + else {} + ) def get_harmonized_rawinput(self, rawinput: pd.DataFrame) -> pd.DataFrame: """Harmonize the rawinput data to a common format.""" - for target_name in self._rename_dict.values(): - if target_name in rawinput.columns: + return self._get_harmonized_data( + rawinput, + self._rawinput_rename_dict, + ) + + def get_harmonized_metadata(self, metadata: pd.DataFrame) -> pd.DataFrame: + """Harmonize the metadata to a common format.""" + return self._get_harmonized_data( + metadata, + self._metadata_rename_dict, + ) + + @staticmethod + def _get_harmonized_data( + input_df: pd.DataFrame, rename_dict: Dict[str, str] + ) -> pd.DataFrame: + 
"""Harmonize data to a common format.""" + for target_name in rename_dict.values(): + if target_name in input_df.columns: raise ValueError( - f"Column name {target_name} already exists in rawinput. Please rename the column." + f"Column name '{target_name}' already exists. Please rename the column in your input data." ) - return rawinput.rename( - columns=self._rename_dict, - errors="ignore", + return input_df.rename( + columns=rename_dict, + errors="raise", ) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index cf88de9d..8aa3f27c 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -1,8 +1,5 @@ -from typing import List - import streamlit as st -from alphastats import BaseLoader from alphastats.DataSet import DataSet from alphastats.gui.utils.import_helper import ( load_example_data, @@ -22,15 +19,9 @@ def _finalize_data_loading( - loader: BaseLoader, - metadata_columns: List[str], dataset: DataSet, ) -> None: """Finalize the data loading process.""" - st.session_state[StateKeys.LOADER] = ( - loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. 
- ) - st.session_state[StateKeys.METADATA_COLUMNS] = metadata_columns st.session_state[StateKeys.DATASET] = dataset sidebar_info() @@ -56,9 +47,9 @@ def _finalize_data_loading( if c2.button("Start new Session with example DataSet", key="_load_example_data"): empty_session_state() init_session_state() - loader, metadata_columns, dataset = load_example_data() + dataset = load_example_data() - _finalize_data_loading(loader, metadata_columns, dataset) + _finalize_data_loading(dataset) st.stop() @@ -124,26 +115,23 @@ def _finalize_data_loading( "Upload metadata file with information about your samples", ) -if metadatafile_upload is None: - st.stop() - -metadatafile_df = uploaded_file_to_df(metadatafile_upload) +metadatafile_df = None +if metadatafile_upload is not None: + metadatafile_df = uploaded_file_to_df(metadatafile_upload) -sample_column = show_select_sample_column_for_metadata( - metadatafile_df, software, loader -) + sample_column = show_select_sample_column_for_metadata( + metadatafile_df, software, loader + ) # ########## Create dataset st.markdown("##### 4. 
Create DataSet") dataset = None -metadata_columns = [] c1, c2 = st.columns(2) if c2.button("Create DataSet without metadata"): dataset = DataSet(loader=loader) - metadata_columns = ["sample"] if c1.button( "Create DataSet with metadata", @@ -160,8 +148,7 @@ def _finalize_data_loading( metadata_path_or_df=metadatafile_df, sample_column=sample_column, ) - metadata_columns = metadatafile_df.columns.to_list() if dataset is not None: st.info("DataSet has been created.") - _finalize_data_loading(loader, metadata_columns, dataset) + _finalize_data_loading(dataset) diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/03_Preprocessing.py index b0c4d407..90652813 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/03_Preprocessing.py @@ -1,7 +1,6 @@ import streamlit as st from alphastats.gui.utils.preprocessing_helper import ( - PREPROCESSING_STEPS, configure_preprocessing, display_preprocessing_info, draw_workflow, @@ -14,12 +13,6 @@ init_session_state() sidebar_info() -if StateKeys.WORKFLOW not in st.session_state: - st.session_state[StateKeys.WORKFLOW] = [ - PREPROCESSING_STEPS.REMOVE_CONTAMINATIONS, - PREPROCESSING_STEPS.SUBSET, - PREPROCESSING_STEPS.LOG2_TRANSFORM, - ] st.markdown("### Preprocessing") c1, c2 = st.columns([1, 1]) @@ -40,18 +33,16 @@ st.info("Import data first to configure and run preprocessing") else: + dataset = st.session_state[StateKeys.DATASET] + c11, c12 = st.columns([1, 1]) if c11.button("Run preprocessing", key="_run_preprocessing"): - run_preprocessing(settings, st.session_state[StateKeys.DATASET]) + run_preprocessing(settings, dataset) # TODO show more info about the preprocessing steps - display_preprocessing_info( - st.session_state[StateKeys.DATASET].preprocessing_info - ) + display_preprocessing_info(dataset.preprocessing_info) if c12.button("Reset all Preprocessing steps", key="_reset_preprocessing"): - reset_preprocessing(st.session_state[StateKeys.DATASET]) - display_preprocessing_info( 
- st.session_state[StateKeys.DATASET].preprocessing_info - ) + reset_preprocessing(dataset) + display_preprocessing_info(dataset.preprocessing_info) # TODO: Add comparison plot of intensity distribution before and after preprocessing diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index fd324ab0..2c8e521d 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -5,6 +5,7 @@ import streamlit as st from alphastats.gui.utils.ui_helper import StateKeys, convert_df +from alphastats.keys import Cols from alphastats.plots.VolcanoPlot import VolcanoPlot @@ -165,7 +166,6 @@ def gui_volcano_plot() -> Tuple[Optional[Any], Optional[Any], Optional[Dict]]: mat=dataset.mat, rawinput=dataset.rawinput, metadata=dataset.metadata, - sample=dataset.sample, preprocessing_info=dataset.preprocessing_info, **parameters, ) @@ -268,20 +268,12 @@ def helper_compare_two_groups(): else: group1 = st.multiselect( "Group 1 samples:", - options=dataset.metadata[ - st.session_state[StateKeys.DATASET].sample - ].to_list(), + options=dataset.metadata[Cols.SAMPLE].to_list(), ) group2 = st.multiselect( "Group 2 samples:", - options=list( - reversed( - dataset.metadata[ - st.session_state[StateKeys.DATASET].sample - ].to_list() - ) - ), + options=list(reversed(dataset.metadata[Cols.SAMPLE].to_list())), ) intersection_list = list(set(group1).intersection(set(group2))) diff --git a/alphastats/gui/utils/import_helper.py b/alphastats/gui/utils/import_helper.py index fffcf52c..407c85c8 100644 --- a/alphastats/gui/utils/import_helper.py +++ b/alphastats/gui/utils/import_helper.py @@ -9,6 +9,7 @@ from alphastats.DataSet import DataSet from alphastats.gui.utils.options import SOFTWARE_OPTIONS +from alphastats.keys import Cols from alphastats.loader.MaxQuantLoader import BaseLoader, MaxQuantLoader @@ -96,13 +97,6 @@ def load_example_data(): folder_to_load = os.path.join(_parent_directory, "sample_data") filepath = 
os.path.join(folder_to_load, "proteinGroups.txt") - metadatapath = os.path.join(folder_to_load, "metadata.xlsx") - - loader = MaxQuantLoader(file=filepath) - # TODO why is this done twice? - dataset = DataSet( - loader=loader, metadata_path_or_df=metadatapath, sample_column="sample" - ) metadatapath = ( os.path.join(_parent_directory, "sample_data", "metadata.xlsx") .replace("pages/", "") @@ -116,15 +110,14 @@ def load_example_data(): dataset.metadata = dataset.metadata[ [ - "sample", + Cols.SAMPLE, "disease", "Drug therapy (procedure) (416608005)", "Lipid-lowering therapy (134350008)", ] ] dataset.preprocess(subset=True) - metadata_columns = dataset.metadata.columns.to_list() - return loader, metadata_columns, dataset + return dataset def _check_softwarefile_df(df: pd.DataFrame, software: str) -> None: diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index 989c3d0e..2a8110c2 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -8,14 +8,17 @@ from alphastats.loader.mzTabLoader import mzTabLoader +# TODO get rid of the options dict: the calls to the functions should be done directly +# idea: per plot, have a `PlotWidget` class that knows what parameters to display and then calls the function def get_plotting_options(state): dataset = state[StateKeys.DATASET] + metadata_options = [None] + dataset.metadata.columns.to_list() plotting_options = { "Sampledistribution Plot": { "settings": { "method": {"options": ["violin", "box"], "label": "Plot layout"}, "color": { - "options": [None] + state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Color according to", }, }, @@ -32,7 +35,7 @@ def get_plotting_options(state): "label": "Plot layout", }, "group": { - "options": [None] + state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Color according to", }, }, @@ -41,7 +44,7 @@ def get_plotting_options(state): "PCA Plot": { "settings": { "group": { - "options": [None] + 
state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Color according to", }, "circle": {"label": "Circle"}, @@ -51,7 +54,7 @@ def get_plotting_options(state): "UMAP Plot": { "settings": { "group": { - "options": [None] + state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Color according to", }, "circle": {"label": "Circle"}, @@ -61,7 +64,7 @@ def get_plotting_options(state): "t-SNE Plot": { "settings": { "group": { - "options": [None] + state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Color according to", }, "circle": {"label": "Circle"}, @@ -80,6 +83,7 @@ def get_plotting_options(state): def get_statistic_options(state): dataset = state[StateKeys.DATASET] + metadata_options = dataset.metadata.columns.to_list() statistic_options = { "Differential Expression Analysis - T-test": { "between_two_groups": True, @@ -96,7 +100,7 @@ def get_statistic_options(state): "label": "ProteinID/ProteinGroup", }, "group": { - "options": state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "A metadata variable to calculate pairwise tukey", }, }, @@ -105,7 +109,7 @@ def get_statistic_options(state): "ANOVA": { "settings": { "column": { - "options": state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "A variable from the metadata to calculate ANOVA", }, "protein_ids": { @@ -123,11 +127,11 @@ def get_statistic_options(state): "label": "Color according to", }, "covar": { - "options": state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Name(s) of column(s) in metadata with the covariate.", }, "between": { - "options": state[StateKeys.METADATA_COLUMNS], + "options": metadata_options, "label": "Name of the column in the metadata with the between factor.", }, }, diff --git a/alphastats/gui/utils/preprocessing_helper.py b/alphastats/gui/utils/preprocessing_helper.py index dbfd81c0..512f6917 100644 --- a/alphastats/gui/utils/preprocessing_helper.py +++ 
b/alphastats/gui/utils/preprocessing_helper.py @@ -5,6 +5,7 @@ from st_cytoscape import cytoscape from alphastats.DataSet import DataSet +from alphastats.keys import Cols CYTOSCAPE_STYLESHEET = [ { @@ -190,7 +191,7 @@ def configure_preprocessing(dataset): # TODO: value of this widget does not persist across dataset reset (likely because the metadata is reset) remove_samples = st.multiselect( "Remove samples from analysis", - options=dataset.metadata[dataset.sample].to_list(), + options=dataset.metadata[Cols.SAMPLE].to_list(), ) remove_samples = remove_samples if len(remove_samples) != 0 else None diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index 0b3170a5..23387c49 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -5,6 +5,7 @@ import streamlit as st from alphastats import __version__ +from alphastats.gui.utils.preprocessing_helper import PREPROCESSING_STEPS # TODO add logo above the options when issue is closed # https://github.com/streamlit/streamlit/issues/4984 @@ -86,6 +87,13 @@ def init_session_state() -> None: if StateKeys.ORGANISM not in st.session_state: st.session_state[StateKeys.ORGANISM] = 9606 # human + if StateKeys.WORKFLOW not in st.session_state: + st.session_state[StateKeys.WORKFLOW] = [ + PREPROCESSING_STEPS.REMOVE_CONTAMINATIONS, + PREPROCESSING_STEPS.SUBSET, + PREPROCESSING_STEPS.LOG2_TRANSFORM, + ] + if StateKeys.PLOT_LIST not in st.session_state: st.session_state[StateKeys.PLOT_LIST] = [] @@ -96,13 +104,11 @@ def init_session_state() -> None: class StateKeys: ## 02_Data Import # on 1st run - ORGANISM = "organism" + ORGANISM = "organism" # TODO this is essentially a constant USER_SESSION_ID = "user_session_id" - LOADER = "loader" # on sample run (function load_sample_data), removed on new session click DATASET = "dataset" # functions upload_metadatafile - METADATA_COLUMNS = "metadata_columns" WORKFLOW = "workflow" PLOT_LIST = "plot_list" diff --git a/alphastats/keys.py 
b/alphastats/keys.py index da9d02e4..f6b82429 100644 --- a/alphastats/keys.py +++ b/alphastats/keys.py @@ -7,3 +7,5 @@ class Cols: INDEX = "index_" GENE_NAMES = "gene_names_" + + SAMPLE = "sample_" diff --git a/alphastats/loader/SpectronautLoader.py b/alphastats/loader/SpectronautLoader.py index b9244e36..67112fdb 100644 --- a/alphastats/loader/SpectronautLoader.py +++ b/alphastats/loader/SpectronautLoader.py @@ -88,7 +88,7 @@ def _reshape_long_to_wide(self): other proteomics softwares use a wide format (column for each sample) reshape to a wider format """ - self.rawinput["sample"] = ( + self.rawinput["tmp_sample"] = ( self.rawinput[self.sample_column] + SPECTRONAUT_COLUMN_DELIM + self.intensity_column @@ -98,9 +98,11 @@ def _reshape_long_to_wide(self): indexing_columns.append(self.gene_names_column) df = self.rawinput.pivot( - columns="sample", index=indexing_columns, values=self.intensity_column + columns="tmp_sample", index=indexing_columns, values=self.intensity_column ) df.reset_index(inplace=True) + # get rid of tmp_sample again, which can cause troubles when working with indices downstream + df.rename_axis(columns=None, inplace=True) return df diff --git a/alphastats/multicova/multicova.py b/alphastats/multicova/multicova.py index e0f7d159..0bcfe20f 100644 --- a/alphastats/multicova/multicova.py +++ b/alphastats/multicova/multicova.py @@ -15,6 +15,8 @@ from sklearn.preprocessing import StandardScaler from statsmodels.stats.multitest import multipletests +from alphastats.keys import Cols + # code taken from Isabel Bludau - multicova @@ -638,13 +640,12 @@ def full_regression_analysis( quant_data, annotation, covariates, - sample_column="sample_name", n_permutations=4, fdr=0.05, s0=0.05, seed=42, ): - data_cols = annotation[sample_column].values + data_cols = annotation[Cols.SAMPLE].values quant_data = quant_data.dropna().reset_index(drop=True) y = quant_data[data_cols].to_numpy().astype("float") # @ToDo make sure that columns are sorted correctly!!! 
@@ -745,7 +746,6 @@ def evaluate_seed_and_perm( covariates, perms, seeds, - sample_column="sample_name", fdr=0.05, s0=0.05, ): @@ -763,7 +763,6 @@ def evaluate_seed_and_perm( annotation=annotation, covariates=covariates, n_permutations=resDF.permutations[i], - sample_column=sample_column, fdr=fdr, s0=s0, seed=resDF.seed[i], @@ -797,7 +796,6 @@ def evaluate_s0s( annotation, covariates, s0s, - sample_column="sample_name", n_permutations=5, seed=42, fdr=0.01, @@ -810,7 +808,6 @@ def evaluate_s0s( quant_data=quant_data, annotation=annotation, covariates=covariates, - sample_column=sample_column, n_permutations=n_permutations, fdr=fdr, s0=resDF.s0[i], diff --git a/alphastats/plots/ClusterMap.py b/alphastats/plots/ClusterMap.py index 395695cd..5b74f586 100644 --- a/alphastats/plots/ClusterMap.py +++ b/alphastats/plots/ClusterMap.py @@ -15,7 +15,6 @@ def __init__( *, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, label_bar, only_significant, @@ -24,13 +23,11 @@ def __init__( ): self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info self._statistics = Statistics( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, ) @@ -48,9 +45,9 @@ def _prepare_df(self): if self.group is not None and self.subgroups is not None: metadata_df = self.metadata[ - self.metadata[self.group].isin(self.subgroups + [self.sample]) + self.metadata[self.group].isin(self.subgroups + [Cols.SAMPLE]) ] - samples = metadata_df[self.sample] + samples = metadata_df[Cols.SAMPLE] df = df.filter(items=samples, axis=0) else: diff --git a/alphastats/plots/DimensionalityReduction.py b/alphastats/plots/DimensionalityReduction.py index f8e73fec..ae43750f 100644 --- a/alphastats/plots/DimensionalityReduction.py +++ b/alphastats/plots/DimensionalityReduction.py @@ -7,6 +7,7 @@ import sklearn from alphastats.DataSet_Preprocess import 
Preprocess +from alphastats.keys import Cols from alphastats.plots.PlotUtils import PlotUtils, plotly_object # make own alphastats theme @@ -37,7 +38,6 @@ def __init__( *, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, group: Optional[str], circle: bool, @@ -46,7 +46,6 @@ def __init__( ) -> None: self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info self.method = method @@ -96,12 +95,10 @@ def _prepare_df(self): # TODO This is only needed in the DimensionalityReduction class and only if the step was not run during preprocessing. # idea: replace the step in DimensionalityReduction with something like: # mat = self.data.mat.loc[sample_names,:] after creating sample_names. - mat = Preprocess.subset( - self.mat, self.metadata, self.sample, self.preprocessing_info - ) + mat = Preprocess.subset(self.mat, self.metadata, self.preprocessing_info) self.metadata[self.group] = self.metadata[self.group].apply(str) group_color = self.metadata[self.group] - sample_names = self.metadata[self.sample].to_list() + sample_names = self.metadata[Cols.SAMPLE].to_list() else: mat = self.mat @@ -145,7 +142,7 @@ def _umap(self): def _plot(self, sample_names, group_color): components = pd.DataFrame(self.components) - components[self.sample] = sample_names + components[Cols.SAMPLE] = sample_names fig = px.scatter( components, @@ -153,7 +150,7 @@ def _plot(self, sample_names, group_color): y=1, labels=self.labels, color=group_color, - hover_data=[components[self.sample]], + hover_data=[components[Cols.SAMPLE]], template="simple_white+alphastats_colors", ) diff --git a/alphastats/plots/IntensityPlot.py b/alphastats/plots/IntensityPlot.py index cadc7d88..c7abb11e 100644 --- a/alphastats/plots/IntensityPlot.py +++ b/alphastats/plots/IntensityPlot.py @@ -8,6 +8,7 @@ import plotly.graph_objects as go import scipy +from alphastats.keys import Cols from 
alphastats.plots.PlotUtils import PlotUtils, plotly_object plotly.io.templates["alphastats_colors"] = plotly.graph_objects.layout.Template( @@ -37,7 +38,6 @@ def __init__( *, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, intensity_column: str, preprocessing_info: Dict, protein_id, @@ -49,7 +49,6 @@ def __init__( ) -> None: self.mat = mat self.metadata = metadata - self.sample = sample self.intensity_column = intensity_column self.preprocessing_info = preprocessing_info @@ -140,9 +139,9 @@ def _prepare_data(self): df = ( self.mat[[self.protein_id]] .reset_index() - .rename(columns={"index": self.sample}) + .rename(columns={"index": Cols.SAMPLE}) ) - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) if self.subgroups is not None: df = df[df[self.group].isin(self.subgroups)] diff --git a/alphastats/plots/VolcanoPlot.py b/alphastats/plots/VolcanoPlot.py index 8f8808f2..578d9557 100644 --- a/alphastats/plots/VolcanoPlot.py +++ b/alphastats/plots/VolcanoPlot.py @@ -49,7 +49,6 @@ def __init__( mat: pd.DataFrame, rawinput: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, group1: Union[List[str], str], group2: Union[List[str], str], @@ -68,7 +67,6 @@ def __init__( self.mat: pd.DataFrame = mat self.rawinput = rawinput self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info self.method = method @@ -83,9 +81,7 @@ def __init__( self.color_list = color_list if isinstance(group1, list) and isinstance(group2, list): - self.metadata, self.column = add_metadata_column( - metadata, sample, group1, group2 - ) + self.metadata, self.column = add_metadata_column(metadata, group1, group2) self.group1, self.group2 = "group1", "group2" else: self.metadata, self.column = metadata, column @@ -99,7 +95,6 @@ def __init__( self._statistics = Statistics( mat=self.mat, metadata=self.metadata, - sample=self.sample, 
preprocessing_info=self.preprocessing_info, ) @@ -139,7 +134,6 @@ def _perform_differential_expression_analysis( res, tlim_ttest = DifferentialExpressionAnalysis( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group1=self.group1, group2=self.group2, @@ -164,14 +158,14 @@ def _sam_calculate_fdr_line(self): n_x=len( list( self.metadata[self.metadata[self.column] == self.group1][ - self.sample + Cols.SAMPLE ] ) ), n_y=len( list( self.metadata[self.metadata[self.column] == self.group2][ - self.sample + Cols.SAMPLE ] ) ), @@ -193,11 +187,11 @@ def _anova(self) -> Tuple[pd.DataFrame, str]: ) group1_samples = self.metadata[self.metadata[self.column] == self.group1][ - self.sample + Cols.SAMPLE ].tolist() group2_samples = self.metadata[self.metadata[self.column] == self.group2][ - self.sample + Cols.SAMPLE ].tolist() mat_transpose = self.mat.transpose() diff --git a/alphastats/statistics/Anova.py b/alphastats/statistics/Anova.py index 45073e9d..12e5f2a0 100644 --- a/alphastats/statistics/Anova.py +++ b/alphastats/statistics/Anova.py @@ -13,14 +13,12 @@ def __init__( self, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, column: str, protein_ids: Union[str, List[str]], tukey: bool, ): self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample # TODO move these to perform()? 
self.column: str = column @@ -58,7 +56,7 @@ def _prepare_data(self): self.all_groups = [] for sub in subgroup: group_list = self.metadata[self.metadata[self.column] == sub][ - self.sample + Cols.SAMPLE ].tolist() self.all_groups.append(group_list) @@ -69,9 +67,9 @@ def _create_tukey_df(self, anova_df: pd.DataFrame) -> pd.DataFrame: df = ( self.mat[self.protein_ids_list] .reset_index() - .rename(columns={"index": self.sample}) + .rename(columns={"index": Cols.SAMPLE}) ) - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) tukey_df_list = [] for protein_id in tqdm(self.protein_ids_list): tukey_df_list.append( diff --git a/alphastats/statistics/DifferentialExpressionAnalysis.py b/alphastats/statistics/DifferentialExpressionAnalysis.py index 312275c5..6464b5ef 100644 --- a/alphastats/statistics/DifferentialExpressionAnalysis.py +++ b/alphastats/statistics/DifferentialExpressionAnalysis.py @@ -18,7 +18,6 @@ def __init__( self, mat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, group1: Union[str, list], group2: Union[str, list], @@ -30,7 +29,6 @@ def __init__( ): self.mat = mat - self.sample = sample self.preprocessing_info = preprocessing_info self.method = method @@ -38,9 +36,7 @@ def __init__( self.fdr = fdr if isinstance(group1, list) and isinstance(group2, list): - self.metadata, self.column = add_metadata_column( - metadata, sample, group1, group2 - ) + self.metadata, self.column = add_metadata_column(metadata, group1, group2) self.group1, self.group2 = "group1", "group2" else: self.metadata, self.column = metadata, column @@ -57,7 +53,7 @@ def _prepare_anndata(self): group_samples = self.metadata[ (self.metadata[self.column] == self.group1) | (self.metadata[self.column] == self.group2) - ][self.sample].tolist() + ][Cols.SAMPLE].tolist() # reduce matrix reduced_matrix = self.mat.loc[group_samples] @@ -66,8 +62,8 @@ def _prepare_anndata(self): list_to_sort = 
reduced_matrix.index.to_list() # reduce metadata obs_metadata = ( - self.metadata[self.metadata[self.sample].isin(group_samples)] - .set_index(self.sample) + self.metadata[self.metadata[Cols.SAMPLE].isin(group_samples)] + .set_index(Cols.SAMPLE) .loc[list_to_sort] ) @@ -98,10 +94,10 @@ def sam(self) -> Tuple[pd.DataFrame, float]: res_ttest, tlim_ttest = multicova.perform_ttest_analysis( transposed, c1=list( - self.metadata[self.metadata[self.column] == self.group1][self.sample] + self.metadata[self.metadata[self.column] == self.group1][Cols.SAMPLE] ), c2=list( - self.metadata[self.metadata[self.column] == self.group2][self.sample] + self.metadata[self.metadata[self.column] == self.group2][Cols.SAMPLE] ), s0=0.05, n_perm=self.perm, @@ -149,10 +145,10 @@ def _welch_ttest(self) -> pd.DataFrame: def _generic_ttest(self, test_fun: Callable) -> pd.DataFrame: group1_samples = self.metadata[self.metadata[self.column] == self.group1][ - self.sample + Cols.SAMPLE ].tolist() group2_samples = self.metadata[self.metadata[self.column] == self.group2][ - self.sample + Cols.SAMPLE ].tolist() # calculate fold change (if its is not logarithimic normalized) mat_transpose = self.mat.transpose() diff --git a/alphastats/statistics/MultiCovaAnalysis.py b/alphastats/statistics/MultiCovaAnalysis.py index f2690996..c47e5c24 100644 --- a/alphastats/statistics/MultiCovaAnalysis.py +++ b/alphastats/statistics/MultiCovaAnalysis.py @@ -33,7 +33,7 @@ def __init__( self._prepare_matrix() def _subset_metadata(self): - columns_to_keep = self.covariates + [self.dataset.sample] + columns_to_keep = self.covariates + [Cols.SAMPLE] if self.subset is not None: # dict structure {"column_name": ["group1", "group2"]} subset_column = list(self.subset.keys())[0] @@ -101,7 +101,7 @@ def _prepare_matrix(self): transposed = self.dataset.mat.transpose() transposed[Cols.INDEX] = transposed.index transposed = transposed.reset_index(drop=True) - self.transposed = 
transposed[self.metadata[self.dataset.sample].to_list()] + self.transposed = transposed[self.metadata[Cols.SAMPLE].to_list()] def _plot_volcano_regression(self, res_real, variable): sig_col = res_real.filter(regex=variable + "_" + "FDR").columns[0] @@ -130,7 +130,6 @@ def calculate(self): quant_data=self.transposed, annotation=self.metadata, covariates=self.covariates, - sample_column=self.dataset.sample, n_permutations=self.n_permutations, fdr=self.fdr, s0=self.s0, diff --git a/alphastats/statistics/StatisticUtils.py b/alphastats/statistics/StatisticUtils.py index f40adfa9..5cfa6729 100644 --- a/alphastats/statistics/StatisticUtils.py +++ b/alphastats/statistics/StatisticUtils.py @@ -1,6 +1,8 @@ import numpy as np import pandas as pd +from alphastats.keys import Cols + def calculate_foldchange( mat_transpose: pd.DataFrame, @@ -20,22 +22,20 @@ def calculate_foldchange( return fc -def add_metadata_column( - metadata: pd.DataFrame, sample: str, group1_list: list, group2_list: list -): +def add_metadata_column(metadata: pd.DataFrame, group1_list: list, group2_list: list): # create new column in metadata with defined groups - sample_names = metadata[sample].to_list() + sample_names = metadata[Cols.SAMPLE].to_list() misc_samples = list(set(group1_list + group2_list) - set(sample_names)) if len(misc_samples) > 0: raise ValueError(f"Sample names: {misc_samples} are not described in Metadata.") column = "_comparison_column" - conditons = [ - metadata[sample].isin(group1_list), - metadata[sample].isin(group2_list), + conditions = [ + metadata[Cols.SAMPLE].isin(group1_list), + metadata[Cols.SAMPLE].isin(group2_list), ] choices = ["group1", "group2"] - metadata[column] = np.select(conditons, choices, default=np.nan) + metadata[column] = np.select(conditions, choices, default=np.nan) return metadata, column diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 568ad5b9..7e5915fa 100644 --- a/tests/gui/test_02_import_data.py +++ 
b/tests/gui/test_02_import_data.py @@ -45,8 +45,8 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): assert not at.exception - assert at.session_state[StateKeys.METADATA_COLUMNS] == [ - "sample", + assert at.session_state[StateKeys.DATASET].metadata.columns.to_list() == [ + "sample_", "disease", "Drug therapy (procedure) (416608005)", "Lipid-lowering therapy (134350008)", @@ -55,10 +55,6 @@ str(type(at.session_state[StateKeys.DATASET])) == "<class 'alphastats.DataSet.DataSet'>" ) - assert ( - str(type(at.session_state[StateKeys.LOADER])) == "<class 'alphastats.loader.MaxQuantLoader.MaxQuantLoader'>" - ) @patch("streamlit.file_uploader") @@ -112,8 +108,3 @@ def test_page_02_loads_maxquant_testfiles( assert dataset._intensity_column == "LFQ intensity [sample]" assert dataset.rawmat.shape == (312, 2611) assert dataset.software == "MaxQuant" - assert dataset.sample == "sample" - assert ( - str(type(at.session_state[StateKeys.LOADER])) == "<class 'alphastats.loader.MaxQuantLoader.MaxQuantLoader'>" - ) diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index e42881d9..3cd65ac5 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -66,14 +66,6 @@ def test_load_metadata(self): self.assertIsInstance(self.obj.metadata, pd.DataFrame) self.assertFalse(self.obj.metadata.empty) - @patch("logging.Logger.error") - def test_load_metadata_missing_sample_column(self, mock): - # is error raised when name of sample column is missing - path = self.metadata_path - self.obj._dataset_factory.sample_column = "wrong_sample_column" - self.obj._dataset_factory._load_metadata(file_path=path) - mock.assert_called_once() - @patch("logging.Logger.warning") def test_load_metadata_warning(self, mock): # is dataframe None and is warning produced diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 75d86481..5a233233 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -2,10 +2,7 @@ import unittest from unittest.mock import MagicMock, patch -import streamlit as st - from alphastats.DataSet import DataSet -from 
alphastats.gui.utils.ui_helper import StateKeys from alphastats.llm.uniprot_utils import extract_data, get_uniprot_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader @@ -25,7 +22,6 @@ def setUp(self): self.matrix_dim = (312, 2596) self.matrix_dim_filtered = (312, 2397) self.comparison_column = "disease" - st.session_state[StateKeys.METADATA_COLUMNS] = [self.comparison_column] class TestGetUniProtData(unittest.TestCase):