Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor llm vii #358

Merged
merged 18 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions alphastats/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,27 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False):
)
return dimensionality_reduction.plot

def perform_dimensionality_reduction(
    self, method: str, group: Optional[str] = None, circle: bool = False
):
    """Generic wrapper for dimensionality reduction methods to be used by LLM.

    Args:
        method (str): one of "pca", "tsne", "umap".
        group (str, optional): column in metadata that should be used for coloring. Defaults to None.
        circle (bool, optional): draw circle around each group. Defaults to False.

    Returns:
        The plot produced by the selected dimensionality reduction method.

    Raises:
        ValueError: if ``method`` is not one of the supported methods.
    """

    # Dispatch table: method name -> bound plot method on this dataset.
    plot_functions = {
        "pca": self.plot_pca,
        "tsne": self.plot_tsne,
        "umap": self.plot_umap,
    }

    plot_function = plot_functions.get(method)
    if plot_function is None:
        # List the valid options so an LLM (or user) can self-correct.
        raise ValueError(
            f"Invalid method: {method}. Valid options: {', '.join(plot_functions)}."
        )

    return plot_function(group=group, circle=circle)
Comment on lines +314 to +333
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to have something similarly simple for the differential analysis longterm.


@ignore_warning(RuntimeWarning)
def plot_volcano(
self,
Expand Down
2 changes: 1 addition & 1 deletion alphastats/gui/pages/04_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def show_start_llm_button(method: str) -> None:

msg = (
"(this will overwrite the existing LLM analysis!)"
if StateKeys.LLM_INTEGRATION in st.session_state
if st.session_state.get(StateKeys.LLM_INTEGRATION, {}) != {}
else ""
)

Expand Down
75 changes: 50 additions & 25 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
display_figure,
)
from alphastats.gui.utils.llm_helper import (
display_proteins,
get_display_proteins_html,
llm_connection_test,
set_api_key,
test_llm_connection,
)
from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info
from alphastats.llm.llm_integration import LLMIntegration, Models
Expand All @@ -36,20 +36,24 @@ def llm_config():
"""Show the configuration options for the LLM analysis."""
c1, _ = st.columns((1, 2))
with c1:
model_before = st.session_state.get(StateKeys.API_TYPE, None)
current_model = st.session_state.get(StateKeys.MODEL_NAME, None)

st.session_state[StateKeys.API_TYPE] = st.selectbox(
models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]
st.session_state[StateKeys.MODEL_NAME] = st.selectbox(
"Select LLM",
[Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B],
models,
index=models.index(st.session_state.get(StateKeys.MODEL_NAME))
if current_model is not None
else 0,
Comment on lines +44 to +47
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key = StateKeys.MODEL_NAME?

)

base_url = None
if st.session_state[StateKeys.API_TYPE] in [Models.GPT4O]:
if st.session_state[StateKeys.MODEL_NAME] in [Models.GPT4O]:
api_key = st.text_input(
"Enter OpenAI API Key and press Enter", type="password"
)
set_api_key(api_key)
elif st.session_state[StateKeys.API_TYPE] in [
elif st.session_state[StateKeys.MODEL_NAME] in [
Models.OLLAMA_31_70B,
Models.OLLAMA_31_8B,
]:
Expand All @@ -58,13 +62,18 @@ def llm_config():

test_connection = st.button("Test connection")
if test_connection:
test_llm_connection(
api_type=st.session_state[StateKeys.API_TYPE],
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=base_url,
)

if model_before != st.session_state[StateKeys.API_TYPE]:
with st.spinner(f"Testing connection to {model_name}.."):
error = llm_connection_test(
model_name=st.session_state[StateKeys.MODEL_NAME],
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=base_url,
)
if error is None:
st.success(f"Connection to {model_name} successful!")
else:
st.error(f"Connection to {model_name} failed: {str(error)}")

if current_model != st.session_state[StateKeys.MODEL_NAME]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the intendend behaviour, then why is the config a st.fragment?

st.rerun(scope="app")


Expand Down Expand Up @@ -120,17 +129,24 @@ def llm_config():
c1, c2 = st.columns((1, 2), gap="medium")
with c1:
st.write("Upregulated genes")
display_proteins(upregulated_genes, [])
st.markdown(
get_display_proteins_html(upregulated_genes, True), unsafe_allow_html=True
)

with c2:
st.write("Downregulated genes")
display_proteins([], downregulated_genes)
st.markdown(
get_display_proteins_html(downregulated_genes, False),
unsafe_allow_html=True,
)


st.markdown("##### Prompts generated based on analysis input")

api_type = st.session_state[StateKeys.API_TYPE]
model_name = st.session_state[StateKeys.MODEL_NAME]
llm_integration_set_for_model = (
st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(api_type, None) is not None
st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(model_name, None)
is not None
)
with st.expander("System message", expanded=False):
system_message = st.text_area(
Expand All @@ -151,30 +167,39 @@ def llm_config():
)


st.markdown(f"##### LLM Analysis with {api_type}")
st.markdown(f"##### LLM Analysis with {model_name}")

llm_submitted = st.button(
c1, c2, _ = st.columns((0.2, 0.2, 0.6))
llm_submitted = c1.button(
"Run LLM analysis ...", disabled=llm_integration_set_for_model
)

if st.session_state[StateKeys.LLM_INTEGRATION].get(api_type) is None:
llm_reset = c2.button(
"Reset LLM analysis ...", disabled=not llm_integration_set_for_model
)
if llm_reset:
del st.session_state[StateKeys.LLM_INTEGRATION]
st.rerun()


if st.session_state[StateKeys.LLM_INTEGRATION].get(model_name) is None:
if not llm_submitted:
st.stop()

try:
llm_integration = LLMIntegration(
api_type=api_type,
model_name=model_name,
system_message=system_message,
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
gene_to_prot_id_map=gene_to_prot_id_map,
)

st.session_state[StateKeys.LLM_INTEGRATION][api_type] = llm_integration
st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration

st.success(
f"{st.session_state[StateKeys.API_TYPE]} integration initialized successfully!"
f"{st.session_state[StateKeys.MODEL_NAME]} integration initialized successfully!"
)

with st.spinner("Processing initial prompt..."):
Expand Down Expand Up @@ -221,4 +246,4 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
help="Show all messages in the chat interface.",
)

llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][api_type], show_all)
llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
55 changes: 23 additions & 32 deletions alphastats/gui/utils/llm_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,24 @@
from alphastats.llm.llm_integration import LLMIntegration


def get_display_proteins_html(protein_ids: List[str], is_upregulated: bool) -> str:
    """
    Get HTML code for displaying a list of proteins, color according to expression.

    Args:
        protein_ids (list[str]): a list of proteins.
        is_upregulated (bool): whether the proteins are up- or down-regulated.

    Returns:
        str: an HTML unordered list of UniProt search links, colored green
            for upregulated proteins and red for downregulated ones.
    """

    uniprot_url = "https://www.uniprot.org/uniprotkb?query="

    color = "green" if is_upregulated else "red"
    protein_ids_html = "".join(
        f'<a href = {uniprot_url + protein}><li style="color: {color};">{protein}</li></a>'
        for protein in protein_ids
    )

    return f"<ul>{protein_ids_html}</ul>"


def set_api_key(api_key: str = None) -> None:
Expand Down Expand Up @@ -71,21 +63,20 @@ def set_api_key(api_key: str = None) -> None:
st.session_state[StateKeys.OPENAI_API_KEY] = api_key


def llm_connection_test(
    model_name: str,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Optional[str]:
    """Test the connection to the LLM API, return None in case of success, error message otherwise."""
    try:
        # Tools are not needed for a simple round-trip check, so skip loading them.
        llm_integration = LLMIntegration(
            model_name, base_url=base_url, api_key=api_key, load_tools=False
        )
        llm_integration.chat_completion(
            "This is a test. Simply respond 'yes' if you got this message."
        )
    except Exception as exception:
        # Any failure (auth, network, model name) is reported as a plain message.
        return str(exception)
    return None
2 changes: 1 addition & 1 deletion alphastats/gui/utils/ui_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class StateKeys:

# LLM
OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret
API_TYPE = "api_type"
MODEL_NAME = "model_name"

LLM_INPUT = "llm_input"

Expand Down
37 changes: 15 additions & 22 deletions alphastats/llm/llm_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

import pandas as pd

from alphastats.plots.DimensionalityReduction import DimensionalityReduction
from alphastats.DataSet import DataSet
from alphastats.llm.enrichment_analysis import get_enrichment_data
from alphastats.llm.uniprot_utils import get_gene_function

GENERAL_FUNCTION_MAPPING = {
"get_gene_function": get_gene_function,
"get_enrichment_data": get_enrichment_data,
}


def get_general_assistant_functions() -> List[Dict]:
Expand All @@ -17,7 +24,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
"name": "get_gene_function",
"name": get_gene_function.__name__,
"description": "Get the gene function and description by UniProt lookup of gene identifier/name",
"parameters": {
"type": "object",
Expand All @@ -34,7 +41,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
"name": "get_enrichment_data",
"name": get_enrichment_data.__name__,
"description": "Get enrichment data for a list of differentially expressed genes",
"parameters": {
"type": "object",
Expand Down Expand Up @@ -83,12 +90,12 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_intensity",
"name": DataSet.plot_intensity.__name__,
"description": "Create an intensity plot based on protein data and analytical methods.",
"parameters": {
"type": "object",
"properties": {
"gene_name": { # this will be mapped to "protein_id" when calling the function
"protein_id": { # LLM will provide gene_name, mapping to protein_id is done when calling the function
"type": "string",
"enum": gene_names,
"description": "Identifier for the gene of interest",
Expand Down Expand Up @@ -125,7 +132,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "perform_dimensionality_reduction",
"name": DataSet.perform_dimensionality_reduction.__name__,
"description": "Perform dimensionality reduction on a given dataset and generate a plot.",
"parameters": {
"type": "object",
Expand All @@ -152,7 +159,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_sampledistribution",
"name": DataSet.plot_sampledistribution.__name__,
"description": "Generates a histogram plot for each sample in the dataset matrix.",
"parameters": {
"type": "object",
Expand All @@ -175,7 +182,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
"name": "plot_volcano",
"name": DataSet.plot_volcano.__name__,
"description": "Generates a volcano plot based on two subgroups of the same group",
"parameters": {
"type": "object",
Expand Down Expand Up @@ -235,17 +242,3 @@ def get_assistant_functions(
},
# {"type": "code_interpreter"},
]


def perform_dimensionality_reduction(dataset, group, method, circle, **kwargs):
dr = DimensionalityReduction(
mat=dataset.mat,
metadate=dataset.metadata,
sample=dataset.sample,
preprocessing_info=dataset.preprocessing_info,
group=group,
circle=circle,
method=method,
**kwargs,
)
return dr.plot
Loading
Loading