
Commit

Merge pull request #358 from MannLabs/refactor_llm_VII
Refactor llm vii
mschwoer authored Nov 8, 2024
2 parents ff26e4f + 4d1670c commit 85a9409
Showing 12 changed files with 970 additions and 146 deletions.
21 changes: 21 additions & 0 deletions alphastats/DataSet.py
@@ -311,6 +311,27 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False):
)
return dimensionality_reduction.plot

def perform_dimensionality_reduction(
self, method: str, group: Optional[str] = None, circle: bool = False
):
"""Generic wrapper for dimensionality reduction methods to be used by LLM.
Args:
method (str): "pca", "tsne", "umap"
group (str, optional): column in metadata that should be used for coloring. Defaults to None.
circle (bool, optional): draw circle around each group. Defaults to False.
"""

result = {
"pca": self.plot_pca,
"tsne": self.plot_tsne,
"umap": self.plot_umap,
}.get(method)
if result is None:
raise ValueError(f"Invalid method: {method}")

return result(group=group, circle=circle)

@ignore_warning(RuntimeWarning)
def plot_volcano(
self,
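For context on the DataSet.py change above, a minimal usage sketch of the new wrapper (not part of the diff; the DataSet instance `ds` and the metadata column "disease" are illustrative assumptions):

```python
# Illustrative only: `ds` is an existing alphastats.DataSet and "disease" is a
# hypothetical metadata column used for coloring.
fig = ds.perform_dimensionality_reduction(method="umap", group="disease", circle=False)

# Unknown methods raise, matching the dict lookup in the wrapper:
try:
    ds.perform_dimensionality_reduction(method="lda")
except ValueError as err:
    print(err)  # Invalid method: lda
```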
2 changes: 1 addition & 1 deletion alphastats/gui/pages/04_Analysis.py
@@ -89,7 +89,7 @@ def show_start_llm_button(method: str) -> None:

msg = (
"(this will overwrite the existing LLM analysis!)"
if StateKeys.LLM_INTEGRATION in st.session_state
if st.session_state.get(StateKeys.LLM_INTEGRATION, {}) != {}
else ""
)

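The changed condition matters when the session-state key exists but holds an empty dict; a small standalone sketch (the key string and values are illustrative stand-ins for StateKeys.LLM_INTEGRATION and st.session_state):

```python
# Key present, but no LLM analysis stored yet.
session_state = {"llm_integration": {}}

print("llm_integration" in session_state)              # True  -> old check shows the warning
print(session_state.get("llm_integration", {}) != {})  # False -> new check stays silent
```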
75 changes: 50 additions & 25 deletions alphastats/gui/pages/05_LLM.py
@@ -8,9 +8,9 @@
display_figure,
)
from alphastats.gui.utils.llm_helper import (
display_proteins,
get_display_proteins_html,
llm_connection_test,
set_api_key,
test_llm_connection,
)
from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info
from alphastats.llm.llm_integration import LLMIntegration, Models
@@ -36,20 +36,24 @@ def llm_config():
"""Show the configuration options for the LLM analysis."""
c1, _ = st.columns((1, 2))
with c1:
model_before = st.session_state.get(StateKeys.API_TYPE, None)
current_model = st.session_state.get(StateKeys.MODEL_NAME, None)

st.session_state[StateKeys.API_TYPE] = st.selectbox(
models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]
st.session_state[StateKeys.MODEL_NAME] = st.selectbox(
"Select LLM",
[Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B],
models,
index=models.index(st.session_state.get(StateKeys.MODEL_NAME))
if current_model is not None
else 0,
)

base_url = None
if st.session_state[StateKeys.API_TYPE] in [Models.GPT4O]:
if st.session_state[StateKeys.MODEL_NAME] in [Models.GPT4O]:
api_key = st.text_input(
"Enter OpenAI API Key and press Enter", type="password"
)
set_api_key(api_key)
elif st.session_state[StateKeys.API_TYPE] in [
elif st.session_state[StateKeys.MODEL_NAME] in [
Models.OLLAMA_31_70B,
Models.OLLAMA_31_8B,
]:
@@ -58,13 +62,18 @@ def llm_config():

test_connection = st.button("Test connection")
if test_connection:
test_llm_connection(
api_type=st.session_state[StateKeys.API_TYPE],
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=base_url,
)

if model_before != st.session_state[StateKeys.API_TYPE]:
with st.spinner(f"Testing connection to {model_name}.."):
error = llm_connection_test(
model_name=st.session_state[StateKeys.MODEL_NAME],
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=base_url,
)
if error is None:
st.success(f"Connection to {model_name} successful!")
else:
st.error(f"Connection to {model_name} failed: {str(error)}")

if current_model != st.session_state[StateKeys.MODEL_NAME]:
st.rerun(scope="app")


@@ -120,17 +129,24 @@ def llm_config():
c1, c2 = st.columns((1, 2), gap="medium")
with c1:
st.write("Upregulated genes")
display_proteins(upregulated_genes, [])
st.markdown(
get_display_proteins_html(upregulated_genes, True), unsafe_allow_html=True
)

with c2:
st.write("Downregulated genes")
display_proteins([], downregulated_genes)
st.markdown(
get_display_proteins_html(downregulated_genes, False),
unsafe_allow_html=True,
)


st.markdown("##### Prompts generated based on analysis input")

api_type = st.session_state[StateKeys.API_TYPE]
model_name = st.session_state[StateKeys.MODEL_NAME]
llm_integration_set_for_model = (
st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(api_type, None) is not None
st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(model_name, None)
is not None
)
with st.expander("System message", expanded=False):
system_message = st.text_area(
@@ -151,30 +167,39 @@
)


st.markdown(f"##### LLM Analysis with {api_type}")
st.markdown(f"##### LLM Analysis with {model_name}")

llm_submitted = st.button(
c1, c2, _ = st.columns((0.2, 0.2, 0.6))
llm_submitted = c1.button(
"Run LLM analysis ...", disabled=llm_integration_set_for_model
)

if st.session_state[StateKeys.LLM_INTEGRATION].get(api_type) is None:
llm_reset = c2.button(
"Reset LLM analysis ...", disabled=not llm_integration_set_for_model
)
if llm_reset:
del st.session_state[StateKeys.LLM_INTEGRATION]
st.rerun()


if st.session_state[StateKeys.LLM_INTEGRATION].get(model_name) is None:
if not llm_submitted:
st.stop()

try:
llm_integration = LLMIntegration(
api_type=api_type,
model_name=model_name,
system_message=system_message,
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
gene_to_prot_id_map=gene_to_prot_id_map,
)

st.session_state[StateKeys.LLM_INTEGRATION][api_type] = llm_integration
st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration

st.success(
f"{st.session_state[StateKeys.API_TYPE]} integration initialized successfully!"
f"{st.session_state[StateKeys.MODEL_NAME]} integration initialized successfully!"
)

with st.spinner("Processing initial prompt..."):
@@ -221,4 +246,4 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
help="Show all messages in the chat interface.",
)

llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][api_type], show_all)
llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
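A condensed sketch of the per-model caching this page now does: integrations are keyed by model name instead of API type (the dict below stands in for st.session_state[StateKeys.LLM_INTEGRATION]; the "gpt-4o" string is an assumption, not taken from this diff):

```python
# Stand-in for st.session_state[StateKeys.LLM_INTEGRATION].
llm_integrations: dict = {}
model_name = "gpt-4o"  # would come from the model selectbox

if llm_integrations.get(model_name) is None:
    # initialize once per model; stands in for LLMIntegration(model_name=..., ...)
    llm_integrations[model_name] = f"LLMIntegration for {model_name}"

llm_integration = llm_integrations[model_name]  # reused on later reruns
```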
55 changes: 23 additions & 32 deletions alphastats/gui/utils/llm_helper.py
@@ -7,32 +7,24 @@
from alphastats.llm.llm_integration import LLMIntegration


def display_proteins(overexpressed: List[str], underexpressed: List[str]) -> None:
def get_display_proteins_html(protein_ids: List[str], is_upregulated: True) -> str:
"""
Display a list of overexpressed and underexpressed proteins in a Streamlit app.
Get HTML code for displaying a list of proteins, color according to expression.
Args:
overexpressed (list[str]): A list of overexpressed proteins.
underexpressed (list[str]): A list of underexpressed proteins.
protein_ids (list[str]): a list of proteins.
is_upregulated (bool): whether the proteins are up- or down-regulated.
"""

# Start with the overexpressed proteins
link = "https://www.uniprot.org/uniprotkb?query="
overexpressed_html = "".join(
f'<a href = {link + protein}><li style="color: green;">{protein}</li></a>'
for protein in overexpressed
)
# Continue with the underexpressed proteins
underexpressed_html = "".join(
f'<a href = {link + protein}><li style="color: red;">{protein}</li></a>'
for protein in underexpressed
)
uniprot_url = "https://www.uniprot.org/uniprotkb?query="

# Combine both lists into one HTML string
full_html = f"<ul>{overexpressed_html}{underexpressed_html}</ul>"
color = "green" if is_upregulated else "red"
protein_ids_html = "".join(
f'<a href = {uniprot_url + protein}><li style="color: {color};">{protein}</li></a>'
for protein in protein_ids
)

# Display in Streamlit
st.markdown(full_html, unsafe_allow_html=True)
return f"<ul>{protein_ids_html}</ul>"


def set_api_key(api_key: str = None) -> None:
@@ -71,21 +63,20 @@ def set_api_key(api_key: str = None) -> None:
st.session_state[StateKeys.OPENAI_API_KEY] = api_key


def test_llm_connection(
api_type: str,
def llm_connection_test(
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
):
"""Test the connection to the LLM API."""
) -> Optional[str]:
"""Test the connection to the LLM API, return None in case of success, error message otherwise."""
try:
llm = LLMIntegration(api_type, base_url=base_url, api_key=api_key)

with st.spinner(f"Testing connection to {api_type} ..."):
llm.chat_completion("Hello, this is a test!")

st.success(f"Connection to {api_type} successful!")
return True
llm = LLMIntegration(
model_name, base_url=base_url, api_key=api_key, load_tools=False
)
llm.chat_completion(
"This is a test. Simply respond 'yes' if you got this message."
)
return None

except Exception as e:
st.error(f"❌ Connection to {api_type} failed: {e}")
return False
return str(e)
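How a caller might use the two refactored helpers, mirroring the 05_LLM.py changes above (a sketch only: the gene names are made up, and `model_name`, `api_key`, and `base_url` are assumed to be in scope; rendering and error display now stay on the Streamlit side):

```python
# Illustrative caller-side usage of the refactored helpers.
html = get_display_proteins_html(["TP53", "EGFR"], True)
st.markdown(html, unsafe_allow_html=True)

error = llm_connection_test(model_name, api_key=api_key, base_url=base_url)
if error is None:
    st.success(f"Connection to {model_name} successful!")
else:
    st.error(f"Connection to {model_name} failed: {error}")
```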
2 changes: 1 addition & 1 deletion alphastats/gui/utils/ui_helper.py
@@ -108,7 +108,7 @@ class StateKeys:

# LLM
OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret
API_TYPE = "api_type"
MODEL_NAME = "model_name"

LLM_INPUT = "llm_input"

37 changes: 15 additions & 22 deletions alphastats/llm/llm_functions.py
@@ -4,7 +4,14 @@

import pandas as pd

from alphastats.plots.DimensionalityReduction import DimensionalityReduction
from alphastats.DataSet import DataSet
from alphastats.llm.enrichment_analysis import get_enrichment_data
from alphastats.llm.uniprot_utils import get_gene_function

GENERAL_FUNCTION_MAPPING = {
"get_gene_function": get_gene_function,
"get_enrichment_data": get_enrichment_data,
}


def get_general_assistant_functions() -> List[Dict]:
@@ -17,7 +24,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
"name": "get_gene_function",
"name": get_gene_function.__name__,
"description": "Get the gene function and description by UniProt lookup of gene identifier/name",
"parameters": {
"type": "object",
@@ -34,7 +41,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
"name": "get_enrichment_data",
"name": get_enrichment_data.__name__,
"description": "Get enrichment data for a list of differentially expressed genes",
"parameters": {
"type": "object",
@@ -83,12 +90,12 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_intensity",
"name": DataSet.plot_intensity.__name__,
"description": "Create an intensity plot based on protein data and analytical methods.",
"parameters": {
"type": "object",
"properties": {
"gene_name": { # this will be mapped to "protein_id" when calling the function
"protein_id": { # LLM will provide gene_name, mapping to protein_id is done when calling the function
"type": "string",
"enum": gene_names,
"description": "Identifier for the gene of interest",
@@ -125,7 +132,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "perform_dimensionality_reduction",
"name": DataSet.perform_dimensionality_reduction.__name__,
"description": "Perform dimensionality reduction on a given dataset and generate a plot.",
"parameters": {
"type": "object",
@@ -152,7 +159,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_sampledistribution",
"name": DataSet.plot_sampledistribution.__name__,
"description": "Generates a histogram plot for each sample in the dataset matrix.",
"parameters": {
"type": "object",
@@ -175,7 +182,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_volcano",
"name": DataSet.plot_volcano.__name__,
"description": "Generates a volcano plot based on two subgroups of the same group",
"parameters": {
"type": "object",
@@ -235,17 +242,3 @@ def get_assistant_functions(
},
# {"type": "code_interpreter"},
]


def perform_dimensionality_reduction(dataset, group, method, circle, **kwargs):
dr = DimensionalityReduction(
mat=dataset.mat,
metadate=dataset.metadata,
sample=dataset.sample,
preprocessing_info=dataset.preprocessing_info,
group=group,
circle=circle,
method=method,
**kwargs,
)
return dr.plot
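A hedged sketch of how the new GENERAL_FUNCTION_MAPPING might be used to dispatch a tool call returned by the chat API (the dispatch helper and the argument name in the example are illustrative assumptions, not part of this diff):

```python
import json

def dispatch_general_function(name: str, arguments_json: str):
    """Look up a tool by the name the LLM returned and call it with its JSON arguments."""
    func = GENERAL_FUNCTION_MAPPING.get(name)
    if func is None:
        raise ValueError(f"Unknown tool: {name}")
    return func(**json.loads(arguments_json))

# e.g. (the "gene_name" argument is an assumption about get_gene_function's signature):
# dispatch_general_function("get_gene_function", '{"gene_name": "CD40"}')
```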