diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index d69bc7fb..3f63ba44 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -54,11 +54,13 @@ ) if method in (plotting_options := get_plotting_options(st.session_state)): - analysis_result = do_analysis(method, options_dict=plotting_options) + analysis_result, analysis_object, parameters = do_analysis( + method, options_dict=plotting_options + ) show_plot = analysis_result is not None elif method in (statistic_options := get_statistic_options(st.session_state)): - analysis_result = do_analysis( + analysis_result, *_ = do_analysis( method, options_dict=statistic_options, ) @@ -79,3 +81,7 @@ st.download_button( "Download as .csv", csv, method + ".csv", "text/csv", key="download-csv" ) + +# TODO this is still quite rough, should be a list, mb add a button etc.. +if method == "Volcano Plot" and analysis_result is not None: + st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 2d6d3cdd..644dd9c1 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -6,9 +6,6 @@ from alphastats.gui.utils.analysis_helper import ( display_figure, - gui_volcano_plot_differential_expression_analysis, - helper_compare_two_groups, - save_plot_to_session_state, ) from alphastats.gui.utils.gpt_helper import ( display_proteins, @@ -17,29 +14,13 @@ get_subgroups_for_each_group, ) from alphastats.gui.utils.ollama_utils import LLMIntegration -from alphastats.gui.utils.openai_utils import ( - set_api_key, -) -from alphastats.gui.utils.options import get_interpretation_options +from alphastats.gui.utils.openai_utils import set_api_key from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info init_session_state() sidebar_info() -def select_analysis(): - """ - select box - loads keys from option dicts - """ - method = st.selectbox( - "Analysis", - # options=["Volcano plot"], - options=list(get_interpretation_options(st.session_state).keys()), - ) - return method - - if StateKeys.DATASET not in st.session_state: st.info("Import Data first") st.stop() @@ -47,141 +28,50 @@ def select_analysis(): st.markdown("### LLM Analysis") -sidebar_info() -init_session_state() - - -# set background to white so downloaded pngs dont have grey background -styl = """ - - """ -st.markdown(styl, unsafe_allow_html=True) - -# Initialize session state variables -if StateKeys.LLM_INTEGRATION not in st.session_state: - st.session_state[StateKeys.LLM_INTEGRATION] = None -if StateKeys.API_TYPE not in st.session_state: - st.session_state[StateKeys.API_TYPE] = "gpt" - -if StateKeys.PLOT_LIST not in st.session_state: - st.session_state[StateKeys.PLOT_LIST] = [] -if StateKeys.MESSAGES not in st.session_state: - st.session_state[StateKeys.MESSAGES] = [] +@st.fragment +def llm_config(): + """Show the configuration options for the LLM analysis.""" + c1, _ = st.columns((1, 2)) + with c1: + st.session_state[StateKeys.API_TYPE] = st.selectbox( + "Select LLM", + ["gpt4o", "llama3.1 70b"], + ) -if StateKeys.PLOT_SUBMITTED_CLICKED not in st.session_state: - st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] = 0 - st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = 0 + if st.session_state[StateKeys.API_TYPE] == "gpt4o": + api_key = st.text_input("Enter OpenAI API Key", type="password") + set_api_key(api_key) + else: + st.info("Expecting Ollama API at http://localhost:11434.") -if StateKeys.LOOKUP_SUBMITTED_CLICKED not in st.session_state: - st.session_state[StateKeys.LOOKUP_SUBMITTED_CLICKED] = 0 - st.session_state[StateKeys.LOOKUP_SUBMITTED_COUNTER] = 0 -if StateKeys.GPT_SUBMITTED_CLICKED not in st.session_state: - st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] = 0 - st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = 0 +llm_config() +st.markdown("#### Analysis") -st.markdown("#### Configure LLM") -c1, _ = st.columns((1, 2)) -with c1: - st.session_state[StateKeys.API_TYPE] = st.selectbox( - "Select LLM", - ["gpt4o", "llama3.1 70b"], - index=0 if st.session_state[StateKeys.API_TYPE] == "gpt4o" else 1, - ) +if StateKeys.LLM_INPUT not in st.session_state: + st.info("Create a Volcano plot first using the 'Analysis' page.") + st.stop() - if st.session_state[StateKeys.API_TYPE] == "gpt4o": - api_key = st.text_input("Enter OpenAI API Key", type="password") - set_api_key(api_key) +volcano_plot, parameter_dict = st.session_state[StateKeys.LLM_INPUT] -st.markdown("#### Analysis") c1, c2 = st.columns((1, 2)) with c1: - method = select_analysis() - chosen_parameter_dict = helper_compare_two_groups() - - method = st.selectbox( - "Differential Analysis using:", - options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"], - ) - chosen_parameter_dict.update({"method": method}) - - # TODO streamlit doesnt allow nested columns check for updates - - labels = st.checkbox("Add label", value=True) - - draw_line = st.checkbox("Draw line", value=True) - - alpha = st.number_input( - label="alpha", min_value=0.001, max_value=0.050, value=0.050 - ) - - organism = st.number_input( - label="UniProt organism ID, for example human is 9606, R. norvegicus is 10116", - value=9606, - ) - st.session_state[StateKeys.ORGANISM] = organism - - min_fc = st.select_slider("Foldchange cutoff", range(0, 3), value=1) - - plotting_parameter_dict = { - "labels": labels, - "draw_line": draw_line, - "alpha": alpha, - "min_fc": min_fc, - } - - if method == "sam": - perm = st.number_input( - label="Number of Permutations", min_value=1, max_value=1000, value=10 - ) - fdr = st.number_input( - label="FDR cut off", min_value=0.005, max_value=0.1, value=0.050 - ) - chosen_parameter_dict.update({"perm": perm, "fdr": fdr}) - - plot_submitted = st.button("Plot") - if plot_submitted: - st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] += 1 - - -if ( - st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] - < st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] -): - st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = st.session_state[ - StateKeys.PLOT_SUBMITTED_CLICKED - ] - volcano_plot = gui_volcano_plot_differential_expression_analysis( - chosen_parameter_dict - ) - volcano_plot._update(plotting_parameter_dict) - volcano_plot._annotate_result_df() - volcano_plot._plot() - genes_of_interest_colored = volcano_plot.get_colored_labels() + # TODO move this to volcano anyway ? genes_of_interest_colored_df = volcano_plot.get_colored_labels_df() - print(genes_of_interest_colored_df) gene_names_colname = st.session_state[StateKeys.LOADER].gene_names prot_ids_colname = st.session_state[StateKeys.LOADER].index_column - st.session_state[StateKeys.PROT_ID_TO_GENE] = dict( - zip( - genes_of_interest_colored_df[prot_ids_colname].tolist(), - genes_of_interest_colored_df[gene_names_colname].tolist(), - ) - ) + # st.session_state[StateKeys.PROT_ID_TO_GENE] = dict( + # zip( + # genes_of_interest_colored_df[prot_ids_colname].tolist(), + # genes_of_interest_colored_df[gene_names_colname].tolist(), + # ) + # ) # TODO unused? st.session_state[StateKeys.GENE_TO_PROT_ID] = dict( zip( genes_of_interest_colored_df[gene_names_colname].tolist(), @@ -192,164 +82,119 @@ def select_analysis(): with c2: display_figure(volcano_plot.plot) + genes_of_interest_colored = volcano_plot.get_colored_labels() if not genes_of_interest_colored: st.text("No proteins of interest found.") st.stop() - print("genes_of_interest", genes_of_interest_colored) - save_plot_to_session_state(method, volcano_plot) - st.session_state[StateKeys.GENES_OF_INTEREST_COLORED] = genes_of_interest_colored # st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism) - st.session_state[StateKeys.UPREGULATED] = [ + upregulated_genes = [ key for key in genes_of_interest_colored if genes_of_interest_colored[key] == "up" ] - st.session_state[StateKeys.DOWNREGULATED] = [ + downregulated_genes = [ key for key in genes_of_interest_colored if genes_of_interest_colored[key] == "down" ] - st.subheader("Genes of interest") - c1, c2 = st.columns((1, 2), gap="medium") - with c1: - st.write("Upregulated genes") - display_proteins(st.session_state[StateKeys.UPREGULATED], []) - with c2: - st.write("Downregulated genes") - display_proteins([], st.session_state[StateKeys.DOWNREGULATED]) - -elif ( - st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] > 0 - and st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] - == st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] - and len(st.session_state[StateKeys.PLOT_LIST]) > 0 -): - with c2: - display_figure(st.session_state[StateKeys.PLOT_LIST][-1][1].plot) st.subheader("Genes of interest") c1, c2 = st.columns((1, 2), gap="medium") with c1: st.write("Upregulated genes") - display_proteins(st.session_state[StateKeys.UPREGULATED], []) + display_proteins(upregulated_genes, []) with c2: st.write("Downregulated genes") - display_proteins([], st.session_state[StateKeys.DOWNREGULATED]) + display_proteins([], downregulated_genes) + -st.session_state[StateKeys.INSTRUCTIONS] = ( +st.subheader("Prompts generated based on gene functions") + +subgroups = get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata) +system_message = ( f"You are an expert biologist and have extensive experience in molecular biology, medicine and biochemistry.{os.linesep}" "A user will present you with data regarding proteins upregulated in certain cells " "sourced from UniProt and abstracts from scientific publications. They seek your " "expertise in understanding the connections between these proteins and their potential role " f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable." - f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata))}." + f"The data you have has following groups and respective subgroups: {str(subgroups)}." "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to" " you from a function has references to the literature (for example, PubMed), always quote the references in your response." ) -if "column" in chosen_parameter_dict and StateKeys.UPREGULATED in st.session_state: - st.session_state[StateKeys.USER_PROMPT] = ( - f"We've recently identified several proteins that appear to be differently regulated in cells " - f"when comparing {chosen_parameter_dict['group1']} and {chosen_parameter_dict['group2']} in the {chosen_parameter_dict['column']} group. " - f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state[StateKeys.UPREGULATED])}.{os.linesep}{os.linesep}" - f"Here is the list of proteins that are downregulated: {', '.join(st.session_state[StateKeys.DOWNREGULATED])}.{os.linesep}{os.linesep}" - f"Help us understand the potential connections between these proteins and how they might be contributing " - f"to the differences. After that provide a high level summary" - ) - -if StateKeys.USER_PROMPT in st.session_state: - st.subheader("Automatically generated prompt based on gene functions:") - with st.expander("Adjust system prompt (see example below)", expanded=False): - st.session_state[StateKeys.INSTRUCTIONS] = st.text_area( - "", value=st.session_state[StateKeys.INSTRUCTIONS], height=150 - ) +user_prompt = ( + f"We've recently identified several proteins that appear to be differently regulated in cells " + f"when comparing {parameter_dict['group1']} and {parameter_dict['group2']} in the {parameter_dict['column']} group. " + f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(upregulated_genes)}.{os.linesep}{os.linesep}" + f"Here is the list of proteins that are downregulated: {', '.join(downregulated_genes)}.{os.linesep}{os.linesep}" + f"Help us understand the potential connections between these proteins and how they might be contributing " + f"to the differences. After that provide a high level summary" +) - with st.expander("Adjust user prompt", expanded=True): - st.session_state[StateKeys.USER_PROMPT] = st.text_area( - "", value=st.session_state[StateKeys.USER_PROMPT], height=200 - ) +with st.expander("System message", expanded=False): + system_message = st.text_area("", value=system_message, height=150) -gpt_submitted = st.button("Run LLM analysis") +with st.expander("User prompt", expanded=True): + user_prompt = st.text_area("", value=user_prompt, height=200) -if gpt_submitted and StateKeys.USER_PROMPT not in st.session_state: - st.warning("Please enter a user prompt first") - st.stop() +llm_submitted = st.button("Run LLM analysis") -if gpt_submitted: - st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] += 1 # creating new assistant only once TODO: add a button to create new assistant -if ( - st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] - > st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] -): - if st.session_state[StateKeys.API_TYPE] == "gpt4o": - set_api_key() +if StateKeys.LLM_INTEGRATION not in st.session_state: + if not llm_submitted: + st.stop() try: if st.session_state[StateKeys.API_TYPE] == "gpt4o": - st.session_state[StateKeys.LLM_INTEGRATION] = LLMIntegration( + llm = LLMIntegration( api_type="gpt", api_key=st.session_state[StateKeys.OPENAI_API_KEY], dataset=st.session_state[StateKeys.DATASET], metadata=st.session_state[StateKeys.DATASET].metadata, ) else: - st.session_state[StateKeys.LLM_INTEGRATION] = LLMIntegration( + llm = LLMIntegration( api_type="ollama", base_url=os.getenv("OLLAMA_BASE_URL", None), dataset=st.session_state[StateKeys.DATASET], metadata=st.session_state[StateKeys.DATASET].metadata, ) + + # Set instructions and update tools + llm.tools = [ + *get_general_assistant_functions(), + *get_assistant_functions( + gene_to_prot_id_dict=st.session_state[StateKeys.GENE_TO_PROT_ID], + metadata=st.session_state[StateKeys.DATASET].metadata, + subgroups_for_each_group=get_subgroups_for_each_group( + st.session_state[StateKeys.DATASET].metadata + ), + ), + ] + + st.session_state[StateKeys.ARTIFACTS] = {} + llm.messages = [{"role": "system", "content": system_message}] + + st.session_state[StateKeys.LLM_INTEGRATION] = llm st.success( f"{st.session_state[StateKeys.API_TYPE].upper()} integration initialized successfully!" ) + + response = llm.chat_completion(user_prompt) + except AuthenticationError: st.warning( "Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX" ) st.stop() -if ( - StateKeys.LLM_INTEGRATION not in st.session_state - or not st.session_state[StateKeys.LLM_INTEGRATION] -): - st.warning("Please initialize the model first") - st.stop() -llm = st.session_state[StateKeys.LLM_INTEGRATION] - -# Set instructions and update tools -llm.tools = [ - *get_general_assistant_functions(), - *get_assistant_functions( - gene_to_prot_id_dict=st.session_state[StateKeys.GENE_TO_PROT_ID], - metadata=st.session_state[StateKeys.DATASET].metadata, - subgroups_for_each_group=get_subgroups_for_each_group( - st.session_state[StateKeys.DATASET].metadata - ), - ), -] - -if StateKeys.ARTIFACTS not in st.session_state: - st.session_state[StateKeys.ARTIFACTS] = {} - -if ( - st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] - < st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] -): - st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = st.session_state[ - StateKeys.GPT_SUBMITTED_CLICKED - ] - st.session_state[StateKeys.ARTIFACTS] = {} - llm.messages = [ - {"role": "system", "content": st.session_state[StateKeys.INSTRUCTIONS]} - ] - response = llm.chat_completion(st.session_state[StateKeys.USER_PROMPT]) +@st.fragment +def llm_chat(): + """The chat interface for the LLM analysis.""" + llm = st.session_state[StateKeys.LLM_INTEGRATION] -if st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] > 0: - if prompt := st.chat_input("Say something"): - response = llm.chat_completion(prompt) for num, role_content_dict in enumerate(st.session_state[StateKeys.MESSAGES]): if role_content_dict["role"] == "tool" or role_content_dict["role"] == "system": continue @@ -363,3 +208,10 @@ def select_analysis(): st.dataframe(artefact) elif "plotly" in str(type(artefact)): st.plotly_chart(artefact) + + if prompt := st.chat_input("Say something"): + llm.chat_completion(prompt) + st.rerun(scope="fragment") + + +llm_chat() diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index e54a10da..39cc147c 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -1,5 +1,5 @@ import io -from typing import Any, Dict +from typing import Any, Dict, Optional, Tuple import pandas as pd import streamlit as st @@ -143,9 +143,11 @@ def gui_volcano_plot_differential_expression_analysis( return volcano_plot -def gui_volcano_plot(): - """ - Draw Volcano Plot using the VolcanoPlot class +def gui_volcano_plot() -> Tuple[Optional[Any], Optional[Any], Optional[Dict]]: + """Draw Volcano Plot using the VolcanoPlot class. + + Returns a tuple(figure, analysis_object, parameters) where figure is the plot, + analysis_object is the underlying object, parameters is a dictionary of the parameters used. """ chosen_parameter_dict = helper_compare_two_groups() method = st.selectbox( @@ -192,11 +194,22 @@ def gui_volcano_plot(): volcano_plot._update(plotting_parameter_dict) volcano_plot._annotate_result_df() volcano_plot._plot() - return volcano_plot.plot + return volcano_plot.plot, volcano_plot, chosen_parameter_dict + + return None, None, None + +def do_analysis( + method: str, options_dict: Dict[str, Any] +) -> Tuple[Optional[Any], Optional[Any], Dict[str, Any]]: + """Extract plotting options and display. -def do_analysis(method: str, options_dict: Dict[str, Any]) -> Any: - """Extract plotting options and display.""" + Returns a tuple(figure, analysis_object, parameters) where figure is the plot, + analysis_object is the underlying object, parameters is a dictionary of the parameters used. + + Currently, analysis_object is only not-None for Volcano Plot. + # TODO unify the API of all analysis methods + """ method_dict = options_dict.get(method) @@ -225,7 +238,9 @@ def do_analysis(method: str, options_dict: Dict[str, Any]) -> Any: if submitted: with st.spinner("Calculating..."): - return method_dict["function"](**parameters) + return method_dict["function"](**parameters), None, parameters + + return None, None, {} # TODO try to cover all those by st_general() diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index 0f457b4e..aeda658c 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -1,11 +1,12 @@ import json import time +from pathlib import Path from typing import List, Optional import openai import streamlit as st -from gui.utils.options import get_plotting_options +from alphastats.gui.utils.options import get_plotting_options from alphastats.gui.utils.ui_helper import StateKeys try: @@ -177,27 +178,24 @@ def set_api_key(api_key: str = None) -> None: """ if api_key: st.session_state[StateKeys.OPENAI_API_KEY] = api_key - # TODO we should not write secrets to disk without user consent - # secret_path = Path("./.streamlit/secrets.toml") - # secret_path.parent.mkdir(parents=True, exist_ok=True) - # with open(secret_path, "w") as f: - # f.write(f'openai_api_key = "{api_key}"') - # openai.OpenAI.api_key = st.session_state[StateKeys.OPENAI_API_KEY"] - # return - elif StateKeys.OPENAI_API_KEY in st.session_state: + + if StateKeys.OPENAI_API_KEY in st.session_state: api_key = st.session_state[StateKeys.OPENAI_API_KEY] + st.info(f"OpenAI API key set: {api_key[:3]}***{api_key[-3:]}") else: try: - api_key = st.secrets["openai_api_key"] - except FileNotFoundError: - st.info( - "Please enter an OpenAI key or provide it in a secrets.toml file in the " - "alphastats/gui/.streamlit directory like " - "`openai_api_key = `" - ) + if Path("./.streamlit/secrets.toml").exists(): + api_key = st.secrets["openai_api_key"] + st.info("OpenAI API key loaded from secrets.toml.") + else: + st.info( + "Please enter an OpenAI key or provide it in a secrets.toml file in the " + "alphastats/gui/.streamlit directory like " + "`openai_api_key = `" + ) except KeyError: - st.write("OpenAI API key not found in secrets.") + st.error("OpenAI API key not found in secrets.toml .") except Exception as e: - st.write(f"Error loading OpenAI API key: {e}.") + st.error(f"Error loading OpenAI API key: {e}.") openai.OpenAI.api_key = api_key diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index 30bbbdfd..86059963 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -110,22 +110,9 @@ class StateKeys: # LLM OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret API_TYPE = "api_type" - LLM_INTEGRATION = "llm_integration" - - PLOT_SUBMITTED_CLICKED = "plot_submitted_clicked" - PLOT_SUBMITTED_COUNTER = "plot_submitted_counter" - LOOKUP_SUBMITTED_CLICKED = "lookup_submitted_clicked" - LOOKUP_SUBMITTED_COUNTER = "lookup_submitted_counter" + LLM_INPUT = "llm_input" - GPT_SUBMITTED_CLICKED = "gpt_submitted_clicked" - GPT_SUBMITTED_COUNTER = "gpt_submitted_counter" - - INSTRUCTIONS = "instructions" - USER_PROMPT = "user_prompt" + LLM_INTEGRATION = "llm_integration" MESSAGES = "messages" ARTIFACTS = "artifacts" - PROT_ID_TO_GENE = "prot_id_to_gene" - GENES_OF_INTEREST_COLORED = "genes_of_interest_colored" - UPREGULATED = "upregulated" - DOWNREGULATED = "downregulated"