Skip to content

Commit

Permalink
clean up session state
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwoer committed Oct 21, 2024
1 parent ad9f175 commit 488c7ac
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 69 deletions.
3 changes: 2 additions & 1 deletion alphastats/gui/pages/04_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,6 @@
"Download as .csv", csv, method + ".csv", "text/csv", key="download-csv"
)

# TODO this is still quite rough, should be a list, mb add a button etc..
if method == "Volcano Plot" and analysis_result is not None:
st.session_state["LLM"] = (analysis_object, parameters)
st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters)
66 changes: 30 additions & 36 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@
st.info("Import Data first")
st.stop()

if "LLM" not in st.session_state:
st.info("Create a Volcano plot first using the 'Analysis' page.")
st.stop()

volcano_plot, chosen_parameter_dict = st.session_state["LLM"]


st.markdown("### LLM Analysis")

Expand All @@ -43,7 +37,6 @@ def llm_config():
st.session_state[StateKeys.API_TYPE] = st.selectbox(
"Select LLM",
["gpt4o", "llama3.1 70b"],
# index=0 if st.session_state[StateKeys.API_TYPE] == "gpt4o" else 1,
)

if st.session_state[StateKeys.API_TYPE] == "gpt4o":
Expand All @@ -55,6 +48,13 @@ def llm_config():

st.markdown("#### Analysis")


if StateKeys.LLM_INPUT not in st.session_state:
st.info("Create a Volcano plot first using the 'Analysis' page.")
st.stop()

volcano_plot, parameter_dict = st.session_state[StateKeys.LLM_INPUT]

c1, c2 = st.columns((1, 2))

with c1:
Expand Down Expand Up @@ -86,12 +86,12 @@ def llm_config():
st.stop()

# st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism)
st.session_state[StateKeys.UPREGULATED] = [
upregulated_genes = [
key
for key in genes_of_interest_colored
if genes_of_interest_colored[key] == "up"
]
st.session_state[StateKeys.DOWNREGULATED] = [
downregulated_genes = [
key
for key in genes_of_interest_colored
if genes_of_interest_colored[key] == "down"
Expand All @@ -101,43 +101,39 @@ def llm_config():
c1, c2 = st.columns((1, 2), gap="medium")
with c1:
st.write("Upregulated genes")
display_proteins(st.session_state[StateKeys.UPREGULATED], [])
display_proteins(upregulated_genes, [])
with c2:
st.write("Downregulated genes")
display_proteins([], st.session_state[StateKeys.DOWNREGULATED])
display_proteins([], downregulated_genes)


st.subheader("Prompts generated based on gene functions")

st.session_state[StateKeys.INSTRUCTIONS] = (
subgroups = get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata)
system_message = (
f"You are an expert biologist and have extensive experience in molecular biology, medicine and biochemistry.{os.linesep}"
"A user will present you with data regarding proteins upregulated in certain cells "
"sourced from UniProt and abstracts from scientific publications. They seek your "
"expertise in understanding the connections between these proteins and their potential role "
f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable."
f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata))}."
f"The data you have has following groups and respective subgroups: {str(subgroups)}."
"Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to"
" you from a function has references to the literature (for example, PubMed), always quote the references in your response."
)
if "column" in chosen_parameter_dict and StateKeys.UPREGULATED in st.session_state:
st.session_state[StateKeys.USER_PROMPT] = (
f"We've recently identified several proteins that appear to be differently regulated in cells "
f"when comparing {chosen_parameter_dict['group1']} and {chosen_parameter_dict['group2']} in the {chosen_parameter_dict['column']} group. "
f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state[StateKeys.UPREGULATED])}.{os.linesep}{os.linesep}"
f"Here is the list of proteins that are downregulated: {', '.join(st.session_state[StateKeys.DOWNREGULATED])}.{os.linesep}{os.linesep}"
f"Help us understand the potential connections between these proteins and how they might be contributing "
f"to the differences. After that provide a high level summary"
)
user_prompt = (
f"We've recently identified several proteins that appear to be differently regulated in cells "
f"when comparing {parameter_dict['group1']} and {parameter_dict['group2']} in the {parameter_dict['column']} group. "
f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(upregulated_genes)}.{os.linesep}{os.linesep}"
f"Here is the list of proteins that are downregulated: {', '.join(downregulated_genes)}.{os.linesep}{os.linesep}"
f"Help us understand the potential connections between these proteins and how they might be contributing "
f"to the differences. After that provide a high level summary"
)

if StateKeys.USER_PROMPT in st.session_state:
st.subheader("Automatically generated prompt based on gene functions:")
with st.expander("System prompt", expanded=True):
st.session_state[StateKeys.INSTRUCTIONS] = st.text_area(
"", value=st.session_state[StateKeys.INSTRUCTIONS], height=150
)
with st.expander("System message", expanded=False):
system_message = st.text_area("", value=system_message, height=150)

with st.expander("User prompt", expanded=True):
st.session_state[StateKeys.USER_PROMPT] = st.text_area(
"", value=st.session_state[StateKeys.USER_PROMPT], height=200
)
with st.expander("User prompt", expanded=True):
user_prompt = st.text_area("", value=user_prompt, height=200)

llm_submitted = st.button("Run LLM analysis")

Expand Down Expand Up @@ -176,16 +172,14 @@ def llm_config():
]

st.session_state[StateKeys.ARTIFACTS] = {}
llm.messages = [
{"role": "system", "content": st.session_state[StateKeys.INSTRUCTIONS]}
]
llm.messages = [{"role": "system", "content": system_message}]

st.session_state[StateKeys.LLM_INTEGRATION] = llm
st.success(
f"{st.session_state[StateKeys.API_TYPE].upper()} integration initialized successfully!"
)

response = llm.chat_completion(st.session_state[StateKeys.USER_PROMPT])
response = llm.chat_completion(user_prompt)

except AuthenticationError:
st.warning(
Expand Down
29 changes: 12 additions & 17 deletions alphastats/gui/utils/openai_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import time
from pathlib import Path
from typing import List, Optional

import openai
import streamlit as st
from gui.utils.options import get_plotting_options

from alphastats.gui.utils.options import get_plotting_options
from alphastats.gui.utils.ui_helper import StateKeys

try:
Expand Down Expand Up @@ -177,27 +178,21 @@ def set_api_key(api_key: str = None) -> None:
"""
if api_key:
st.session_state[StateKeys.OPENAI_API_KEY] = api_key
# TODO we should not write secrets to disk without user consent
# secret_path = Path("./.streamlit/secrets.toml")
# secret_path.parent.mkdir(parents=True, exist_ok=True)
# with open(secret_path, "w") as f:
# f.write(f'openai_api_key = "{api_key}"')
# openai.OpenAI.api_key = st.session_state[StateKeys.OPENAI_API_KEY"]
# return
elif StateKeys.OPENAI_API_KEY in st.session_state:
api_key = st.session_state[StateKeys.OPENAI_API_KEY]
else:
try:
api_key = st.secrets["openai_api_key"]
except FileNotFoundError:
st.info(
"Please enter an OpenAI key or provide it in a secrets.toml file in the "
"alphastats/gui/.streamlit directory like "
"`openai_api_key = <key>`"
)
if Path("./.streamlit/secrets.toml").exists():
api_key = st.secrets["openai_api_key"]
else:
st.info(
"Please enter an OpenAI key or provide it in a secrets.toml file in the "
"alphastats/gui/.streamlit directory like "
"`openai_api_key = <key>`"
)
except KeyError:
st.write("OpenAI API key not found in secrets.")
st.error("OpenAI API key not found in secrets.toml .")
except Exception as e:
st.write(f"Error loading OpenAI API key: {e}.")
st.error(f"Error loading OpenAI API key: {e}.")

openai.OpenAI.api_key = api_key
17 changes: 2 additions & 15 deletions alphastats/gui/utils/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,9 @@ class StateKeys:
# LLM
OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret
API_TYPE = "api_type"
LLM_INTEGRATION = "llm_integration"

PLOT_SUBMITTED_CLICKED = "plot_submitted_clicked"
PLOT_SUBMITTED_COUNTER = "plot_submitted_counter"

LOOKUP_SUBMITTED_CLICKED = "lookup_submitted_clicked"
LOOKUP_SUBMITTED_COUNTER = "lookup_submitted_counter"
LLM_INPUT = "llm_input"

GPT_SUBMITTED_CLICKED = "gpt_submitted_clicked"
GPT_SUBMITTED_COUNTER = "gpt_submitted_counter"

INSTRUCTIONS = "instructions"
USER_PROMPT = "user_prompt"
LLM_INTEGRATION = "llm_integration"
MESSAGES = "messages"
ARTIFACTS = "artifacts"
PROT_ID_TO_GENE = "prot_id_to_gene"
GENES_OF_INTEREST_COLORED = "genes_of_interest_colored"
UPREGULATED = "upregulated"
DOWNREGULATED = "downregulated"

0 comments on commit 488c7ac

Please sign in to comment.