Skip to content

Commit

Permalink
Merge pull request #374 from MannLabs/refactor_analysis_review_comments
Browse files Browse the repository at this point in the history
add changes/comments requested by code review
  • Loading branch information
mschwoer authored Nov 18, 2024
2 parents f8f3ce2 + 8009d03 commit 257d05c
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 43 deletions.
6 changes: 5 additions & 1 deletion alphastats/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __init__(
"""
self._check_loader(loader=loader)

self._data_harmonizer = DataHarmonizer(loader, sample_column)
self._data_harmonizer = DataHarmonizer(
loader, sample_column
) # TODO should be moved to the loaders

# fill data from loader
self.rawinput: pd.DataFrame = self._data_harmonizer.get_harmonized_rawinput(
Expand Down Expand Up @@ -110,6 +112,8 @@ def __init__(
if Cols.GENE_NAMES in self.rawinput.columns
else {}
)
# TODO This is not necessarily unique, and should ideally raise an error in some of our test-data sets that
# contain isoform ids. E.g. TPM1 occurs 5 times in testfiles/maxquant/proteinGroups.txt with different base Protein IDs.

print("DataSet has been created.")

Expand Down
3 changes: 3 additions & 0 deletions alphastats/DataSet_Plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def plot_sampledistribution(
# create long df
matrix = self.mat if not use_raw else self.rawmat
df = matrix.unstack().reset_index()
# TODO replace intensity either with the more generic term abundance,
# or use what was actually the original name.
# Intensity or LFQ intensity, or even SILAC ratio makes a bit of a difference
df.rename(columns={"level_1": Cols.SAMPLE, 0: "Intensity"}, inplace=True)

if color is not None:
Expand Down
41 changes: 19 additions & 22 deletions alphastats/gui/gui.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import os
import sys

from streamlit.web import cli as stcli


def run():
Expand All @@ -20,22 +17,22 @@ def run():
)

# TODO why are we starting the app a second time here?
_this_file = os.path.abspath(__file__)
_this_directory = os.path.dirname(_this_file)

file_path = os.path.join(_this_directory, "AlphaPeptStats.py")

HOME = os.path.expanduser("~")
ST_PATH = os.path.join(HOME, ".streamlit")

for folder in [ST_PATH]:
if not os.path.isdir(folder):
os.mkdir(folder)

print(f"Starting AlphaPeptStats from {file_path}")

args = ["streamlit", "run", file_path, "--global.developmentMode=false"]

sys.argv = args

sys.exit(stcli.main())
# _this_file = os.path.abspath(__file__)
# _this_directory = os.path.dirname(_this_file)
#
# file_path = os.path.join(_this_directory, "AlphaPeptStats.py")
#
# HOME = os.path.expanduser("~")
# ST_PATH = os.path.join(HOME, ".streamlit")
#
# for folder in [ST_PATH]:
# if not os.path.isdir(folder):
# os.mkdir(folder)
#
# print(f"Starting AlphaPeptStats from {file_path}")
#
# args = ["streamlit", "run", file_path, "--global.developmentMode=false"]
#
# sys.argv = args
#
# sys.exit(stcli.main())
3 changes: 1 addition & 2 deletions alphastats/gui/pages/04_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
"""
st.markdown(styl, unsafe_allow_html=True)

# TODO put everything in the session state for a given parameter set?
# or is caching functionality the way to go here?
# TODO use caching functionality for all analysis (not: plot creation)

if StateKeys.DATASET not in st.session_state:
st.info("Import data first.")
Expand Down
5 changes: 3 additions & 2 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
"""The chat interface for the LLM analysis."""

# TODO dump to file -> static file name, plus button to do so
# how to deal with binaries? base64 encode?
# "import chat" functionality?
# Ideas: save chat as txt, without encoding objects, just put a replacement string.
# Offer bulk download of a zip with all figures (via plotly download as svg).
# Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).

# no. tokens spent
for message in llm_integration.get_print_view(show_all=show_all):
Expand Down
7 changes: 4 additions & 3 deletions alphastats/gui/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
class PlottingOptions(metaclass=ConstantsClass):
"""Keys for the plotting options, the order determines order in UI."""

VOLCANO_PLOT = "Volcano Plot"
PCA_PLOT = "PCA Plot"
UMAP_PLOT = "UMAP Plot"
TSNE_PLOT = "t-SNE Plot"
VOLCANO_PLOT = "Volcano Plot"
SAMPLE_DISTRIBUTION_PLOT = "Sampledistribution Plot"
INTENSITY_PLOT = "Intensity Plot"
CLUSTERMAP = "Clustermap"
Expand Down Expand Up @@ -55,7 +55,7 @@ def do_analysis(
) -> Tuple[
Union[PlotlyObject, pd.DataFrame], Optional[VolcanoPlot], Dict[str, Any]
]:
"""Perform the analysis after an optional check for NaNs.
"""Perform the analysis after some upfront method-dependent checks (e.g. for NaNs).
Returns a tuple(analysis, analysis_object, parameters) where 'analysis' is the plot or dataframe,
'analysis_object' is the underlying object, 'parameters' is a dictionary of the parameters used.
Expand All @@ -78,7 +78,7 @@ def _do_analysis(

def _nan_check(self) -> None: # noqa: B027
"""Raise ValueError for methods that do not tolerate NaNs if there are any."""
if not self._works_with_nans and self._dataset.mat.isnull().values.any():
if not self._works_with_nans and self._dataset.mat.isna().values.any():
raise ValueError("This analysis does not work with NaN values.")

def _pre_analysis_check(self) -> None: # noqa: B027
Expand Down Expand Up @@ -301,6 +301,7 @@ def show_widget(self):
"Foldchange cutoff", range(0, 3), value=1
)

# TODO: The sam fdr cutoff should be mutually exclusive with alpha
if method == "sam":
parameters["perm"] = st.number_input(
label="Number of Permutations", min_value=1, max_value=1000, value=10
Expand Down
10 changes: 7 additions & 3 deletions alphastats/gui/utils/analysis_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ def _display(
def display_figure(plot: PlotlyObject) -> None:
"""Display plotly or seaborn figure."""
try:
st.plotly_chart(plot.update_layout(plot_bgcolor="white"))
st.plotly_chart(plot)
except Exception:
st.pyplot(plot)


def _show_buttons_download_figure(analysis_result: PlotlyObject, name: str) -> None:
"""Show buttons to download figure as .pdf or .svg."""
# TODO We have to check for all scatter plotly figures, which renderer they use.
# Default is webgl, which is good for browser performance, but looks horrendous in svg download
# rerendering with svg as renderer could be a method of PlotlyObject to invoke prior to saving as svg
_show_button_download_figure(analysis_result, name, "pdf")
_show_button_download_figure(analysis_result, name, "svg")

Expand All @@ -109,7 +112,7 @@ def _show_button_download_figure(

try: # plotly
plot.write_image(file=buffer, format=file_format)
except AttributeError: # TODO figure out what else "plot" can be
except AttributeError: # seaborn
plot.savefig(buffer, format=file_format)

st.download_button(
Expand All @@ -119,6 +122,7 @@ def _show_button_download_figure(
)


# TODO: use pandas stylers, rather than changing the data
def _display_df(df: pd.DataFrame) -> None:
"""Display a dataframe."""
mask = df.applymap(type) != bool # noqa: E721
Expand All @@ -132,7 +136,7 @@ def _show_button_download_analysis_and_preprocessing_info(
parameters: Dict,
name: str,
):
"""Download analysis info (= analysis and preprocessing parameters and ) as .csv."""
"""Download analysis info (= analysis and preprocessing parameters) as .csv."""
parameters_pretty = {
f"analysis_parameter__{k}": "None" if v is None else v
for k, v in parameters.items()
Expand Down
9 changes: 6 additions & 3 deletions alphastats/gui/utils/overview_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import streamlit as st

from alphastats.DataSet import DataSet
from alphastats.gui.utils.ui_helper import StateKeys
from alphastats.gui.utils.ui_helper import StateKeys, show_button_download_df


# @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash
Expand All @@ -24,13 +24,16 @@ def display_matrix():
st.markdown("**DataFrame used for analysis** *preview*")

# TODO why not use the actual matrix here?
mat = st.session_state[StateKeys.DATASET].mat
df = pd.DataFrame(
st.session_state[StateKeys.DATASET].mat.values,
index=st.session_state[StateKeys.DATASET].mat.index.to_list(),
mat.values,
index=mat.index.to_list(),
).head(10)

st.dataframe(df)

show_button_download_df(mat, file_name="analysis_matrix")


def display_loaded_dataset(dataset: DataSet) -> None:
st.markdown(f"*Preview:* Raw data from {dataset.software}")
Expand Down
7 changes: 4 additions & 3 deletions alphastats/llm/llm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class Models:
"""

GPT4O = "gpt-4o"
OLLAMA_31_8B = "llama3.1:8b"
OLLAMA_31_70B = "llama3.1:70b" # for testing only
OLLAMA_31_70B = "llama3.1:70b"
OLLAMA_31_8B = "llama3.1:8b" # for testing only


class LLMIntegration:
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
self._model = model_name

if model_name in [Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]:
url = f"{base_url}/v1"
url = f"{base_url}/v1" # TODO: make this configurable per model
self._client = OpenAI(base_url=url, api_key="ollama")
elif model_name in [Models.GPT4O]:
self._client = OpenAI(api_key=api_key)
Expand Down Expand Up @@ -150,6 +150,7 @@ def _truncate_conversation_history(self, max_tokens: int = 100000):
max_tokens : int, optional
The maximum number of tokens to keep in history, by default 100000
"""
# TODO: avoid important messages being removed (e.g. facts about genes)
total_tokens = sum(
len(message["content"].split()) for message in self._messages
)
Expand Down
4 changes: 2 additions & 2 deletions alphastats/plots/ClusterMap.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def _prepare_df(self):
significant_proteins = anova_df[anova_df["ANOVA_pvalue"] < 0.05][
Cols.INDEX
].to_list()
df = df[significant_proteins] # TODO bug?
df = df[significant_proteins]

if self.label_bar is not None:
self._create_label_bar(metadata_df)

self.prepared_df = self.mat.loc[:, (self.mat != 0).any(axis=0)].transpose()
self.prepared_df = df.transpose()

def _plot(self):
fig = sns.clustermap(self.prepared_df, col_colors=self.label_bar)
Expand Down
3 changes: 1 addition & 2 deletions alphastats/plots/DimensionalityReduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import plotly.express as px
import plotly.graph_objects as go
import sklearn
from sklearn.manifold._t_sne import TSNE

from alphastats.DataSet_Preprocess import Preprocess
from alphastats.keys import Cols
Expand Down Expand Up @@ -120,7 +119,7 @@ def _pca(self):
}

def _tsne(self, **kwargs):
tsne = TSNE(n_components=2, verbose=1, **kwargs)
tsne = sklearn.manifold._t_sne.TSNE(n_components=2, verbose=1, **kwargs)
self.components = tsne.fit_transform(self.prepared_df)
self.labels = {
"0": "Dimension 1",
Expand Down
1 change: 1 addition & 0 deletions alphastats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def wrapper(*args):
return wrapper


# TODO: replace with https://pandas.pydata.org/docs/reference/api/pandas.Index.has_duplicates.html#pandas.Index.has_duplicates
def find_duplicates_in_list(input_list: list) -> list:
"""Find duplicates in a list.
Expand Down

0 comments on commit 257d05c

Please sign in to comment.