diff --git a/alphastats/DataSet.py b/alphastats/DataSet.py
index b845a72e..32a0d6ae 100644
--- a/alphastats/DataSet.py
+++ b/alphastats/DataSet.py
@@ -311,6 +311,27 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False):
)
return dimensionality_reduction.plot
+ def perform_dimensionality_reduction(
+ self, method: str, group: Optional[str] = None, circle: bool = False
+ ):
+ """Generic wrapper for dimensionality reduction methods to be used by LLM.
+
+ Args:
+ method (str): "pca", "tsne", "umap"
+ group (str, optional): column in metadata that should be used for coloring. Defaults to None.
+ circle (bool, optional): draw circle around each group. Defaults to False.
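+
+        Example (illustrative): ds.perform_dimensionality_reduction("pca", group="disease")
+        dispatches to ds.plot_pca(group="disease"), where ds is a DataSet instance and
+        "disease" is a metadata column.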
+ """
+
+        plot_function = {
+            "pca": self.plot_pca,
+            "tsne": self.plot_tsne,
+            "umap": self.plot_umap,
+        }.get(method)
+        if plot_function is None:
+            raise ValueError(f"Invalid method: {method}")
+
+        return plot_function(group=group, circle=circle)
+
@ignore_warning(RuntimeWarning)
def plot_volcano(
self,
diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py
index a6f70d0c..8188d28f 100644
--- a/alphastats/gui/pages/04_Analysis.py
+++ b/alphastats/gui/pages/04_Analysis.py
@@ -89,7 +89,7 @@ def show_start_llm_button(method: str) -> None:
msg = (
"(this will overwrite the existing LLM analysis!)"
- if StateKeys.LLM_INTEGRATION in st.session_state
+ if st.session_state.get(StateKeys.LLM_INTEGRATION, {}) != {}
else ""
)
diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py
index 6dcf2919..0e337a2e 100644
--- a/alphastats/gui/pages/05_LLM.py
+++ b/alphastats/gui/pages/05_LLM.py
@@ -8,9 +8,9 @@
display_figure,
)
from alphastats.gui.utils.llm_helper import (
- display_proteins,
+ get_display_proteins_html,
+ llm_connection_test,
set_api_key,
- test_llm_connection,
)
from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info
from alphastats.llm.llm_integration import LLMIntegration, Models
@@ -36,20 +36,24 @@ def llm_config():
"""Show the configuration options for the LLM analysis."""
c1, _ = st.columns((1, 2))
with c1:
- model_before = st.session_state.get(StateKeys.API_TYPE, None)
+ current_model = st.session_state.get(StateKeys.MODEL_NAME, None)
- st.session_state[StateKeys.API_TYPE] = st.selectbox(
+ models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]
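+        # The first model in this list is the default selection when none is stored in the session state yet.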
+ st.session_state[StateKeys.MODEL_NAME] = st.selectbox(
"Select LLM",
- [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B],
+ models,
+ index=models.index(st.session_state.get(StateKeys.MODEL_NAME))
+ if current_model is not None
+ else 0,
)
base_url = None
- if st.session_state[StateKeys.API_TYPE] in [Models.GPT4O]:
+ if st.session_state[StateKeys.MODEL_NAME] in [Models.GPT4O]:
api_key = st.text_input(
"Enter OpenAI API Key and press Enter", type="password"
)
set_api_key(api_key)
- elif st.session_state[StateKeys.API_TYPE] in [
+ elif st.session_state[StateKeys.MODEL_NAME] in [
Models.OLLAMA_31_70B,
Models.OLLAMA_31_8B,
]:
@@ -58,13 +62,18 @@ def llm_config():
test_connection = st.button("Test connection")
if test_connection:
- test_llm_connection(
- api_type=st.session_state[StateKeys.API_TYPE],
- api_key=st.session_state[StateKeys.OPENAI_API_KEY],
- base_url=base_url,
- )
-
- if model_before != st.session_state[StateKeys.API_TYPE]:
+        model_name = st.session_state[StateKeys.MODEL_NAME]
+        with st.spinner(f"Testing connection to {model_name}.."):
+ error = llm_connection_test(
+ model_name=st.session_state[StateKeys.MODEL_NAME],
+ api_key=st.session_state[StateKeys.OPENAI_API_KEY],
+ base_url=base_url,
+ )
+ if error is None:
+ st.success(f"Connection to {model_name} successful!")
+ else:
+ st.error(f"Connection to {model_name} failed: {str(error)}")
+
+ if current_model != st.session_state[StateKeys.MODEL_NAME]:
st.rerun(scope="app")
@@ -120,17 +129,24 @@ def llm_config():
c1, c2 = st.columns((1, 2), gap="medium")
with c1:
st.write("Upregulated genes")
- display_proteins(upregulated_genes, [])
+ st.markdown(
+ get_display_proteins_html(upregulated_genes, True), unsafe_allow_html=True
+ )
+
with c2:
st.write("Downregulated genes")
- display_proteins([], downregulated_genes)
+ st.markdown(
+ get_display_proteins_html(downregulated_genes, False),
+ unsafe_allow_html=True,
+ )
st.markdown("##### Prompts generated based on analysis input")
-api_type = st.session_state[StateKeys.API_TYPE]
+model_name = st.session_state[StateKeys.MODEL_NAME]
llm_integration_set_for_model = (
- st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(api_type, None) is not None
+ st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(model_name, None)
+ is not None
)
with st.expander("System message", expanded=False):
system_message = st.text_area(
@@ -151,19 +167,28 @@ def llm_config():
)
-st.markdown(f"##### LLM Analysis with {api_type}")
+st.markdown(f"##### LLM Analysis with {model_name}")
-llm_submitted = st.button(
+c1, c2, _ = st.columns((0.2, 0.2, 0.6))
+llm_submitted = c1.button(
"Run LLM analysis ...", disabled=llm_integration_set_for_model
)
-if st.session_state[StateKeys.LLM_INTEGRATION].get(api_type) is None:
+llm_reset = c2.button(
+ "Reset LLM analysis ...", disabled=not llm_integration_set_for_model
+)
+if llm_reset:
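+    # Note: this clears the stored integrations for all models, not just the selected one.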
+ del st.session_state[StateKeys.LLM_INTEGRATION]
+ st.rerun()
+
+
+if st.session_state[StateKeys.LLM_INTEGRATION].get(model_name) is None:
if not llm_submitted:
st.stop()
try:
llm_integration = LLMIntegration(
- api_type=api_type,
+ model_name=model_name,
system_message=system_message,
api_key=st.session_state[StateKeys.OPENAI_API_KEY],
base_url=OLLAMA_BASE_URL,
@@ -171,10 +196,10 @@ def llm_config():
gene_to_prot_id_map=gene_to_prot_id_map,
)
- st.session_state[StateKeys.LLM_INTEGRATION][api_type] = llm_integration
+ st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration
st.success(
- f"{st.session_state[StateKeys.API_TYPE]} integration initialized successfully!"
+ f"{st.session_state[StateKeys.MODEL_NAME]} integration initialized successfully!"
)
with st.spinner("Processing initial prompt..."):
@@ -221,4 +246,4 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
help="Show all messages in the chat interface.",
)
-llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][api_type], show_all)
+llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
diff --git a/alphastats/gui/utils/llm_helper.py b/alphastats/gui/utils/llm_helper.py
index 33474a16..028abe20 100644
--- a/alphastats/gui/utils/llm_helper.py
+++ b/alphastats/gui/utils/llm_helper.py
@@ -7,32 +7,24 @@
from alphastats.llm.llm_integration import LLMIntegration
-def display_proteins(overexpressed: List[str], underexpressed: List[str]) -> None:
+def get_display_proteins_html(protein_ids: List[str], is_upregulated: bool) -> str:
"""
- Display a list of overexpressed and underexpressed proteins in a Streamlit app.
+ Get HTML code for displaying a list of proteins, color according to expression.
Args:
- overexpressed (list[str]): A list of overexpressed proteins.
- underexpressed (list[str]): A list of underexpressed proteins.
+ protein_ids (list[str]): a list of proteins.
+ is_upregulated (bool): whether the proteins are up- or down-regulated.
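+
+    Returns:
+        str: HTML for an unordered list of UniProt-linked protein entries, colored by regulation.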
"""
-    # Start with the overexpressed proteins
-    link = "https://www.uniprot.org/uniprotkb?query="
-    overexpressed_html = "".join(
-        f'<a href="{link}{protein}"><li style="color: green;">{protein}</li></a>'
-        for protein in overexpressed
-    )
-    # Continue with the underexpressed proteins
-    underexpressed_html = "".join(
-        f'<a href="{link}{protein}"><li style="color: red;">{protein}</li></a>'
-        for protein in underexpressed
-    )
+    uniprot_url = "https://www.uniprot.org/uniprotkb?query="
-    # Combine both lists into one HTML string
-    full_html = f"<ul>{overexpressed_html}{underexpressed_html}</ul>"
+    color = "green" if is_upregulated else "red"
+    protein_ids_html = "".join(
+        f'<a href="{uniprot_url}{protein}"><li style="color: {color};">{protein}</li></a>'
+        for protein in protein_ids
+    )
-    # Display in Streamlit
-    st.markdown(full_html, unsafe_allow_html=True)
+    return f"<ul>{protein_ids_html}</ul>"
def set_api_key(api_key: str = None) -> None:
@@ -71,21 +63,20 @@ def set_api_key(api_key: str = None) -> None:
st.session_state[StateKeys.OPENAI_API_KEY] = api_key
-def test_llm_connection(
- api_type: str,
+def llm_connection_test(
+ model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
-):
- """Test the connection to the LLM API."""
+) -> Optional[str]:
+ """Test the connection to the LLM API, return None in case of success, error message otherwise."""
try:
- llm = LLMIntegration(api_type, base_url=base_url, api_key=api_key)
-
- with st.spinner(f"Testing connection to {api_type} ..."):
- llm.chat_completion("Hello, this is a test!")
-
- st.success(f"Connection to {api_type} successful!")
- return True
+ llm = LLMIntegration(
+ model_name, base_url=base_url, api_key=api_key, load_tools=False
+ )
+ llm.chat_completion(
+ "This is a test. Simply respond 'yes' if you got this message."
+ )
+ return None
except Exception as e:
- st.error(f"❌ Connection to {api_type} failed: {e}")
- return False
+ return str(e)
diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py
index 38e0027b..0b3170a5 100644
--- a/alphastats/gui/utils/ui_helper.py
+++ b/alphastats/gui/utils/ui_helper.py
@@ -108,7 +108,7 @@ class StateKeys:
# LLM
OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret
- API_TYPE = "api_type"
+ MODEL_NAME = "model_name"
LLM_INPUT = "llm_input"
diff --git a/alphastats/llm/llm_functions.py b/alphastats/llm/llm_functions.py
index f6738438..90cebb6b 100644
--- a/alphastats/llm/llm_functions.py
+++ b/alphastats/llm/llm_functions.py
@@ -4,7 +4,14 @@
import pandas as pd
-from alphastats.plots.DimensionalityReduction import DimensionalityReduction
+from alphastats.DataSet import DataSet
+from alphastats.llm.enrichment_analysis import get_enrichment_data
+from alphastats.llm.uniprot_utils import get_gene_function
+
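+# Maps the tool names exposed to the LLM to Python callables that do not require a DataSet;
+# DataSet-bound tools are looked up on the DataSet instance at call time (see LLMIntegration._execute_function).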
+GENERAL_FUNCTION_MAPPING = {
+ "get_gene_function": get_gene_function,
+ "get_enrichment_data": get_enrichment_data,
+}
def get_general_assistant_functions() -> List[Dict]:
@@ -17,7 +24,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
- "name": "get_gene_function",
+ "name": get_gene_function.__name__,
"description": "Get the gene function and description by UniProt lookup of gene identifier/name",
"parameters": {
"type": "object",
@@ -34,7 +41,7 @@ def get_general_assistant_functions() -> List[Dict]:
{
"type": "function",
"function": {
- "name": "get_enrichment_data",
+ "name": get_enrichment_data.__name__,
"description": "Get enrichment data for a list of differentially expressed genes",
"parameters": {
"type": "object",
@@ -83,12 +90,12 @@ def get_assistant_functions(
{
"type": "function",
"function": {
- "name": "plot_intensity",
+ "name": DataSet.plot_intensity.__name__,
"description": "Create an intensity plot based on protein data and analytical methods.",
"parameters": {
"type": "object",
"properties": {
- "gene_name": { # this will be mapped to "protein_id" when calling the function
+ "protein_id": { # LLM will provide gene_name, mapping to protein_id is done when calling the function
"type": "string",
"enum": gene_names,
"description": "Identifier for the gene of interest",
@@ -125,7 +132,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
- "name": "perform_dimensionality_reduction",
+ "name": DataSet.perform_dimensionality_reduction.__name__,
"description": "Perform dimensionality reduction on a given dataset and generate a plot.",
"parameters": {
"type": "object",
@@ -152,7 +159,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
- "name": "plot_sampledistribution",
+ "name": DataSet.plot_sampledistribution.__name__,
"description": "Generates a histogram plot for each sample in the dataset matrix.",
"parameters": {
"type": "object",
@@ -175,7 +182,7 @@ def get_assistant_functions(
{
"type": "function",
"function": {
- "name": "plot_volcano",
+ "name": DataSet.plot_volcano.__name__,
"description": "Generates a volcano plot based on two subgroups of the same group",
"parameters": {
"type": "object",
@@ -235,17 +242,3 @@ def get_assistant_functions(
},
# {"type": "code_interpreter"},
]
-
-
-def perform_dimensionality_reduction(dataset, group, method, circle, **kwargs):
- dr = DimensionalityReduction(
- mat=dataset.mat,
- metadate=dataset.metadata,
- sample=dataset.sample,
- preprocessing_info=dataset.preprocessing_info,
- group=group,
- circle=circle,
- method=method,
- **kwargs,
- )
- return dr.plot
diff --git a/alphastats/llm/llm_integration.py b/alphastats/llm/llm_integration.py
index 60392cb2..c866e7f3 100644
--- a/alphastats/llm/llm_integration.py
+++ b/alphastats/llm/llm_integration.py
@@ -11,23 +11,26 @@
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from alphastats.DataSet import DataSet
-from alphastats.llm.enrichment_analysis import get_enrichment_data
from alphastats.llm.llm_functions import (
+ GENERAL_FUNCTION_MAPPING,
get_assistant_functions,
get_general_assistant_functions,
- perform_dimensionality_reduction,
)
from alphastats.llm.llm_utils import (
get_protein_id_for_gene_name,
get_subgroups_for_each_group,
)
from alphastats.llm.prompts import get_tool_call_message
-from alphastats.llm.uniprot_utils import get_gene_function
logger = logging.getLogger(__name__)
class Models:
+ """Names of the available models.
+
+ Note that this will be directly passed to the OpenAI client.
+ """
+
GPT4O = "gpt-4o"
OLLAMA_31_8B = "llama3.1:8b"
OLLAMA_31_70B = "llama3.1:70b" # for testing only
@@ -41,7 +44,7 @@ class LLMIntegration:
Parameters
----------
-    api_type : str
-        The type of API to use, will be forwarded to the client.
+    model_name : str
+        The name of the model to use, will be forwarded to the client.
system_message : str
The system message that should be given to the model.
@@ -57,34 +60,34 @@ class LLMIntegration:
def __init__(
self,
- api_type: str,
+ model_name: str,
*,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
system_message: str = None,
+ load_tools: bool = True,
dataset: Optional[DataSet] = None,
gene_to_prot_id_map: Optional[Dict[str, str]] = None,
):
- self._model = api_type
+ self._model = model_name
- if api_type in [Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]:
+ if model_name in [Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]:
url = f"{base_url}/v1"
self._client = OpenAI(base_url=url, api_key="ollama")
- elif api_type in [Models.GPT4O]:
+ elif model_name in [Models.GPT4O]:
self._client = OpenAI(api_key=api_key)
else:
- raise ValueError(f"Invalid API type: {api_type}")
+ raise ValueError(f"Invalid model name: {model_name}")
self._dataset = dataset
self._metadata = None if dataset is None else dataset.metadata
self._gene_to_prot_id_map = gene_to_prot_id_map
- self._tools = self._get_tools()
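+        # load_tools=False skips building the tool/function definitions, e.g. for a plain connection test.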
+ self._tools = self._get_tools() if load_tools else None
+ self._artifacts = {}
self._messages = [] # the conversation history used for the LLM, could be truncated at some point.
self._all_messages = [] # full conversation history for display
- self._artifacts = {}
-
if system_message is not None:
self._append_message("system", system_message)
@@ -190,47 +193,36 @@ def _execute_function(
Any
The result of the function execution
"""
- try:
- # first try to find the function in the non-Dataset functions
- if (
- function := {
- "get_gene_function": get_gene_function,
- "get_enrichment_data": get_enrichment_data,
- }.get(function_name)
- ) is not None:
- return function(**function_args)
-
- # special treatment for these functions
- elif function_name == "perform_dimensionality_reduction":
- # TODO add API in dataset
- perform_dimensionality_reduction(self._dataset, **function_args)
-
- elif function_name == "plot_intensity":
- # TODO move this logic to dataset
- gene_name = function_args.pop("gene_name")
- protein_id = get_protein_id_for_gene_name(
- gene_name, self._gene_to_prot_id_map
- )
- function_args["protein_id"] = protein_id
+ # first try to find the function in the non-Dataset functions
+ if (function := GENERAL_FUNCTION_MAPPING.get(function_name)) is not None:
+ return function(**function_args)
+
+ # special treatment for this function
+ elif function_name == "plot_intensity":
+ # TODO move this logic to dataset
+        gene_name = function_args.pop(
+            "protein_id"
+        )  # not a typo: the LLM passes a gene name in the "protein_id" argument
+ protein_id = get_protein_id_for_gene_name(
+ gene_name, self._gene_to_prot_id_map
+ )
+ function_args["protein_id"] = protein_id
- return self._dataset.plot_intensity(**function_args)
+ return self._dataset.plot_intensity(**function_args)
- # fallback: try to find the function in the Dataset functions
- else:
- function = getattr(
- self._dataset,
- function_name.split(".")[-1],
- None, # TODO why split?
- )
- if function:
- return function(**function_args)
-
- raise ValueError(
- f"Function {function_name} not implemented or dataset not available"
+ # look up the function in the DataSet class
+ else:
+ function = getattr(
+ self._dataset,
+ function_name.split(".")[-1],
+ None, # TODO why split?
)
+ if function:
+ return function(**function_args)
- except Exception as e:
- return f"Error executing {function_name}: {str(e)}"
+ raise ValueError(
+ f"Function {function_name} not implemented or dataset not available"
+ )
def _handle_function_calls(
self,
@@ -260,7 +252,11 @@ def _handle_function_calls(
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
- function_result = self._execute_function(function_name, function_args)
+ try:
+ function_result = self._execute_function(function_name, function_args)
+ except Exception as e:
+ function_result = f"Error executing {function_name}: {str(e)}"
+
artifact_id = f"{function_name}_{tool_call.id}"
new_artifacts[artifact_id] = function_result
@@ -272,15 +268,20 @@ def _handle_function_calls(
post_artifact_message_idx = len(self._all_messages)
self._artifacts[post_artifact_message_idx] = new_artifacts.values()
- logger.info(f"Calling 'chat.completions.create' {self._messages[-1]=} ..")
- response = self._client.chat.completions.create(
+ response = self._chat_completion_create()
+
+ return self._parse_model_response(response)
+
+ def _chat_completion_create(self) -> ChatCompletion:
+ """Create a chat completion based on the current conversation history."""
+ logger.info(f"Calling 'chat.completions.create' {self._messages[-1]} ..")
+ result = self._client.chat.completions.create(
model=self._model,
messages=self._messages,
tools=self._tools,
)
logger.info(".. done")
-
- return self._parse_model_response(response)
+ return result
def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
"""Get a structured view of the conversation history for display purposes."""
@@ -301,9 +302,7 @@ def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
)
return print_view
- def chat_completion(
- self, prompt: str, role: str = "user"
- ) -> Tuple[str, Dict[str, Any]]:
+ def chat_completion(self, prompt: str, role: str = "user") -> None:
"""
Generate a chat completion based on the given prompt and manage any resulting artifacts.
@@ -322,13 +321,7 @@ def chat_completion(
self._append_message(role, prompt)
try:
- logger.info(f"Calling 'chat.completions.create' {self._messages[-1]} ..")
- response = self._client.chat.completions.create(
- model=self._model,
- messages=self._messages,
- tools=self._tools,
- )
- logger.info(".. done")
+ response = self._chat_completion_create()
content, tool_calls = self._parse_model_response(response)
@@ -345,7 +338,6 @@ def chat_completion(
except ArithmeticError as e:
error_message = f"Error in chat completion: {str(e)}"
self._append_message("system", error_message)
- return error_message, {}
# TODO this seems to be for notebooks?
# we need some "export mode" where everything is shown
diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/llm/test_llm_functions.py b/tests/llm/test_llm_functions.py
new file mode 100644
index 00000000..c0424fa8
--- /dev/null
+++ b/tests/llm/test_llm_functions.py
@@ -0,0 +1,95 @@
+"""Test that the function definitions in the LLM match the actual functions:
+
+- all parameters defined for the LLM are present in the function definitions
+- all non-default parameters in the function definitions are required for the LLM
+"""
+
+import inspect
+from typing import Callable, Dict
+
+import pandas as pd
+
+from alphastats.DataSet import DataSet
+from alphastats.llm.llm_functions import (
+ GENERAL_FUNCTION_MAPPING,
+ get_assistant_functions,
+ get_general_assistant_functions,
+)
+
+
+def _get_method_parameters(method: Callable) -> Dict:
+ """Get the parameters of a method as a dictionary, excluding "self"."""
+ signature = inspect.signature(method)
+
+ params = dict(signature.parameters)
+ if "self" in params:
+ del params["self"]
+
+ return params
+
+
+def _get_class_methods(cls):
+ """Get all methods of a class as a dictionary, excluding "self"."""
+ return {
+ member[0]: member[1]
+ for member in inspect.getmembers(cls, predicate=inspect.isfunction)
+ if member[0] != "self"
+ }
+
+
+def assert_parameters(method_definition: Callable, llm_function_dict_):
+ """Assert that the parameters of a method match the parameters defined in the dict used by the LLM."""
+
+ # suffix '_' denotes LLM-related variables
+ parameters = _get_method_parameters(method_definition)
+ parameters_without_default = [
+ param
+ for param in parameters
+ if parameters[param].default == inspect.Parameter.empty
+ ]
+
+ parameters_dict_ = llm_function_dict_["function"]["parameters"]
+ parameters_ = parameters_dict_["properties"].keys()
+
+ # are all in parameters_ available in the function?
+ assert set(parameters_).issubset(set(parameters))
+
+ # are all the parameters w/o default in the function filled in parameters_?
+ assert set(parameters_without_default).issubset(set(parameters_))
+
+ # are all required parameters marked as 'required'?
+ assert set(parameters_without_default).issubset(set(parameters_dict_["required"]))
+
+
+def test_general_assistant_functions():
+ """Test that the general assistant functions in the LLM match the actual functions."""
+ # suffix '_' denotes LLM-related variables
+ assistant_functions_dict = get_general_assistant_functions()
+
+ for llm_function_dict_ in assistant_functions_dict:
+ name_ = llm_function_dict_["function"]["name"]
+
+ method_definition = GENERAL_FUNCTION_MAPPING.get(name_, None)
+
+ if method_definition is None:
+ raise ValueError(f"Function not found in test: {name_}")
+
+ assert_parameters(method_definition, llm_function_dict_)
+
+
+def test_assistant_functions():
+ """Test that the assistant functions in the LLM match the actual functions."""
+ # suffix '_' denotes LLM-related variables
+ assistant_functions_dict = get_assistant_functions({}, pd.DataFrame(), {})
+
+ all_dataset_methods = _get_class_methods(DataSet)
+
+ for llm_function_dict_ in assistant_functions_dict:
+ name_ = llm_function_dict_["function"]["name"]
+
+ method_definition = all_dataset_methods.get(name_, None)
+
+ if method_definition is None:
+ raise ValueError(f"Function not found in test: {name_}")
+
+ assert_parameters(method_definition, llm_function_dict_)
diff --git a/tests/llm/test_llm_helper.py b/tests/llm/test_llm_helper.py
new file mode 100644
index 00000000..5f09534e
--- /dev/null
+++ b/tests/llm/test_llm_helper.py
@@ -0,0 +1,136 @@
+from unittest.mock import patch
+
+import pytest
+
+from alphastats.gui.utils.llm_helper import (
+ get_display_proteins_html,
+ llm_connection_test,
+ set_api_key,
+)
+from alphastats.gui.utils.ui_helper import StateKeys
+
+
+@pytest.fixture
+def mock_streamlit():
+ """Fixture to mock streamlit module."""
+ with patch("streamlit.info") as mock_info, patch(
+ "streamlit.error"
+ ) as mock_error, patch("streamlit.success") as mock_success, patch(
+ "streamlit.session_state", {}
+ ) as mock_session_state:
+ yield {
+ "info": mock_info,
+ "error": mock_error,
+ "success": mock_success,
+ "session_state": mock_session_state,
+ }
+
+
+def test_display_proteins_upregulated(mock_streamlit):
+ """Test displaying upregulated proteins."""
+ protein_ids = ["P12345", "Q67890"]
+ result = get_display_proteins_html(protein_ids, is_upregulated=True)
+
+    expected_html = (
+        '<ul><a href="https://www.uniprot.org/uniprotkb?query=P12345">'
+        '<li style="color: green;">P12345</li></a>'
+        '<a href="https://www.uniprot.org/uniprotkb?query=Q67890">'
+        '<li style="color: green;">Q67890</li></a></ul>'
+    )
+
+ assert result == expected_html
+
+
+def test_display_proteins_downregulated(mock_streamlit):
+ """Test displaying downregulated proteins."""
+ protein_ids = ["P12345"]
+ result = get_display_proteins_html(protein_ids, is_upregulated=False)
+
+    expected_html = (
+        '<ul><a href="https://www.uniprot.org/uniprotkb?query=P12345">'
+        '<li style="color: red;">P12345</li></a></ul>'
+    )
+
+ assert result == expected_html
+
+
+def test_display_proteins_empty_list(mock_streamlit):
+ """Test displaying empty protein list."""
+    assert get_display_proteins_html([], is_upregulated=True) == "<ul></ul>"
+
+
+@pytest.mark.parametrize(
+ "api_key,expected_message",
+ [
+ ("abc123xyz", "OpenAI API key set: abc***xyz"),
+ (
+ None,
+ "Please enter an OpenAI key or provide it in a secrets.toml file in the alphastats/gui/.streamlit directory like `openai_api_key = `",
+ ),
+ ],
+)
+def test_set_api_key_direct(mock_streamlit, api_key, expected_message):
+ """Test setting API key directly."""
+ set_api_key(api_key)
+
+ if api_key:
+ mock_streamlit["info"].assert_called_once_with(expected_message)
+ assert mock_streamlit["session_state"][StateKeys.OPENAI_API_KEY] == api_key
+ else:
+ mock_streamlit["info"].assert_called_with(expected_message)
+
+
+@patch("streamlit.secrets")
+@patch("pathlib.Path.exists")
+def test_set_api_key_from_secrets(mock_exists, mock_st_secrets, mock_streamlit):
+ """Test loading API key from secrets.toml."""
+ mock_exists.return_value = True
+
+ mock_st_secrets.__getitem__.return_value = (
+ "test_secret_key" # pragma: allowlist secret
+ )
+
+ set_api_key()
+
+ mock_streamlit["info"].assert_called_with(
+ "OpenAI API key loaded from secrets.toml."
+ )
+ assert (
+ mock_streamlit["session_state"][StateKeys.OPENAI_API_KEY]
+ == "test_secret_key" # pragma: allowlist secret
+ )
+
+
+@patch("pathlib.Path.exists")
+def test_set_api_key_missing_secrets(mock_exists, mock_streamlit):
+ """Test handling missing secrets.toml."""
+ mock_exists.return_value = False
+
+ set_api_key()
+
+ mock_streamlit["info"].assert_called_with(
+ "Please enter an OpenAI key or provide it in a secrets.toml file in the "
+ "alphastats/gui/.streamlit directory like `openai_api_key = `"
+ )
+
+
+@patch("alphastats.gui.utils.llm_helper.LLMIntegration")
+def test_llm_connection_test_success(mock_llm):
+ """Test successful LLM connection."""
+ assert llm_connection_test("some_model") is None
+
+ mock_llm.assert_called_once_with(
+ "some_model", base_url=None, api_key=None, load_tools=False
+ )
+
+
+@patch("alphastats.gui.utils.llm_helper.LLMIntegration")
+def test_llm_connection_test_failure(mock_llm, mock_streamlit):
+ """Test failed LLM connection."""
+ mock_llm.return_value.chat_completion.side_effect = ValueError("API Error")
+
+ assert llm_connection_test("some_model") == "API Error"
+
+ mock_llm.assert_called_once_with(
+ "some_model", base_url=None, api_key=None, load_tools=False
+ )
diff --git a/tests/llm/test_llm_integration.py b/tests/llm/test_llm_integration.py
new file mode 100644
index 00000000..9074bf3c
--- /dev/null
+++ b/tests/llm/test_llm_integration.py
@@ -0,0 +1,471 @@
+from unittest.mock import Mock, patch
+
+import pandas as pd
+import pytest
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionMessage,
+ ChatCompletionMessageToolCall,
+)
+from openai.types.chat.chat_completion_message_tool_call import Function
+
+from alphastats.llm.llm_integration import LLMIntegration, Models
+
+
+@pytest.fixture
+def mock_openai_client():
+ with patch("alphastats.llm.llm_integration.OpenAI") as mock_client:
+ yield mock_client
+
+
+@pytest.fixture
+def llm_integration(mock_openai_client):
+ """Fixture providing a basic LLM instance with test configuration"""
+ dataset = Mock()
+ dataset.plot_intensity = Mock(return_value="Plot created")
+ dataset.custom_function = Mock(return_value="Dataset function called")
+ dataset.metadata = pd.DataFrame({"group1": ["A", "B"], "group2": ["C", "D"]})
+ return LLMIntegration(
+ model_name=Models.GPT4O,
+ api_key="test-key", # pragma: allowlist secret
+ system_message="Test system message",
+ dataset=dataset,
+ gene_to_prot_id_map={"GENE1": "PROT1", "GENE2": "PROT2"},
+ )
+
+
+@pytest.fixture
+def llm_with_conversation(llm_integration):
+ """Setup LLM with a sample conversation history"""
+ # Add various message types to conversation history
+ llm_integration._all_messages = [
+ {"role": "system", "content": "System message"},
+ {"role": "user", "content": "User message 1"},
+ {"role": "assistant", "content": "Assistant message 1"},
+ {
+ "role": "assistant",
+ "content": "Assistant with tool calls",
+ "tool_calls": [
+ {"id": "123", "type": "function", "function": {"name": "test"}}
+ ],
+ },
+ {"role": "tool", "content": "Tool response"},
+ {"role": "user", "content": "User message 2"},
+ {"role": "assistant", "content": "Assistant message 2"},
+ ]
+
+ # Add some artifacts
+ llm_integration._artifacts = {
+ 2: ["Artifact for message 2"],
+ 4: ["Tool artifact 1", "Tool artifact 2"],
+ 6: ["Artifact for message 6"],
+ }
+
+ return llm_integration
+
+
+@pytest.fixture
+def mock_chat_completion():
+ """Fixture providing a mock successful chat completion"""
+ mock_response = Mock(spec=ChatCompletion)
+ mock_response.choices = [
+ Mock(
+ message=ChatCompletionMessage(
+ role="assistant", content="Test response", tool_calls=None
+ )
+ )
+ ]
+ return mock_response
+
+
+@pytest.fixture
+def mock_tool_call_completion():
+ """Fixture providing a mock completion with tool calls"""
+ mock_response = Mock(spec=ChatCompletion)
+ mock_response.choices = [
+ Mock(
+ message=ChatCompletionMessage(
+ role="assistant",
+ content="",
+ tool_calls=[
+ ChatCompletionMessageToolCall(
+ id="test-id",
+ type="function",
+ function={"name": "test_function", "arguments": "{}"},
+ )
+ ],
+ )
+ )
+ ]
+ return mock_response
+
+
+@pytest.fixture
+def mock_general_function_mapping():
+ def mock_general_function(param1, param2):
+ """An example for a function."""
+ return f"General function called with {param1} and {param2}"
+
+ return {"test_general_function": mock_general_function}
+
+
+def test_initialization_gpt4(mock_openai_client):
+ """Test initialization with GPT-4 configuration"""
+ LLMIntegration(
+ model_name=Models.GPT4O,
+ api_key="test-key", # pragma: allowlist secret
+ )
+
+ mock_openai_client.assert_called_once_with(
+ api_key="test-key" # pragma: allowlist secret
+ )
+
+
+def test_initialization_ollama(mock_openai_client):
+ """Test initialization with Ollama configuration"""
+ LLMIntegration(
+ model_name=Models.OLLAMA_31_8B,
+ base_url="http://localhost:11434",
+ )
+
+ mock_openai_client.assert_called_once_with(
+ base_url="http://localhost:11434/v1",
+ api_key="ollama", # pragma: allowlist secret
+ )
+
+
+def test_initialization_invalid_model():
+ """Test initialization with invalid model type"""
+ with pytest.raises(ValueError, match="Invalid model name"):
+ LLMIntegration(model_name="invalid-model")
+
+
+def test_append_message(llm_integration):
+ """Test message appending functionality"""
+ llm_integration._append_message("user", "Test message")
+
+ assert len(llm_integration._messages) == 2 # Including system message
+ assert len(llm_integration._all_messages) == 2
+ assert llm_integration._messages[-1] == {"role": "user", "content": "Test message"}
+
+
+def test_append_message_with_tool_calls(llm_integration):
+ """Test message appending with tool calls"""
+ tool_calls = [
+ ChatCompletionMessageToolCall(
+ id="test-id",
+ type="function",
+ function={"name": "test_function", "arguments": "{}"},
+ )
+ ]
+
+ llm_integration._append_message("assistant", "Test message", tool_calls=tool_calls)
+
+ assert llm_integration._messages[-1]["tool_calls"] == tool_calls
+
+
+@pytest.mark.parametrize(
+ "num_messages,message_length,max_tokens,expected_messages",
+ [
+ (5, 100, 200, 2), # Should truncate to 2 messages
+ (3, 50, 1000, 3), # Should keep all messages
+ (10, 20, 100, 5), # Should truncate to 5 messages
+ ],
+)
+def test_truncate_conversation_history(
+ llm_integration, num_messages, message_length, max_tokens, expected_messages
+):
+ """Test conversation history truncation with different scenarios"""
+ # Add multiple messages
+ message_content = "Test " * message_length
+ for _ in range(num_messages):
+ llm_integration._append_message("user", message_content)
+
+ llm_integration._truncate_conversation_history(max_tokens=max_tokens)
+
+ # Adding 1 to account for the initial system message
+ assert len(llm_integration._messages) <= expected_messages + 1
+
+
+def test_chat_completion_success(llm_integration, mock_chat_completion):
+ """Test successful chat completion"""
+ llm_integration._client.chat.completions.create.return_value = mock_chat_completion
+
+ llm_integration.chat_completion("Test prompt")
+
+ assert llm_integration._messages == [
+ {
+ "content": "Test system message",
+ "role": "system",
+ },
+ {
+ "content": "Test prompt",
+ "role": "user",
+ },
+ {
+ "content": "Test response",
+ "role": "assistant",
+ },
+ ]
+
+
+def test_chat_completion_with_error(llm_integration):
+ """Test chat completion with error handling"""
+ llm_integration._client.chat.completions.create.side_effect = ArithmeticError(
+ "Test error"
+ )
+
+ llm_integration.chat_completion("Test prompt")
+
+ assert (
+ "Error in chat completion: Test error"
+ in llm_integration._messages[-1]["content"]
+ )
+
+
+def test_parse_model_response(llm_integration, mock_tool_call_completion):
+ """Test parsing model response with tool calls"""
+ content, tool_calls = llm_integration._parse_model_response(
+ mock_tool_call_completion
+ )
+
+ assert content == ""
+ assert len(tool_calls) == 1
+ assert tool_calls[0].id == "test-id"
+ assert tool_calls[0].type == "function"
+
+
+def test_chat_completion_with_content_and_tool_calls(llm_integration):
+ """Test that chat completion raises error when receiving both content and tool calls"""
+ mock_response = Mock(spec=ChatCompletion)
+ mock_response.choices = [
+ Mock(
+ message=ChatCompletionMessage(
+ role="assistant",
+ content="Some content",
+ tool_calls=[
+ ChatCompletionMessageToolCall(
+ id="test-id",
+ type="function",
+ function={"name": "test_function", "arguments": "{}"},
+ )
+ ],
+ )
+ )
+ ]
+ llm_integration._client.chat.completions.create.return_value = mock_response
+
+ with pytest.raises(ValueError, match="Unexpected content.*with tool calls"):
+ llm_integration.chat_completion("Test prompt")
+
+
+@pytest.mark.parametrize(
+ "function_name,function_args,expected_result",
+ [
+ (
+ "test_general_function",
+ {"param1": "value1", "param2": "value2"},
+ "General function called with value1 and value2",
+ ),
+ ],
+)
+def test_execute_general_function(
+ llm_integration,
+ mock_general_function_mapping,
+ function_name,
+ function_args,
+ expected_result,
+):
+ """Test execution of functions from GENERAL_FUNCTION_MAPPING"""
+ with patch(
+ "alphastats.llm.llm_integration.GENERAL_FUNCTION_MAPPING",
+ mock_general_function_mapping,
+ ):
+ result = llm_integration._execute_function(function_name, function_args)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "gene_name,plot_args,expected_protein_id",
+ [
+ ("GENE1", {"param1": "value1"}, "PROT1"),
+ ("GENE2", {"param2": "value2"}, "PROT2"),
+ ],
+)
+def test_execute_plot_intensity(
+ llm_integration, gene_name, plot_args, expected_protein_id
+):
+ """Test execution of plot_intensity with gene name translation"""
+ function_args = {"protein_id": gene_name, **plot_args}
+
+ result = llm_integration._execute_function("plot_intensity", function_args)
+
+ # Verify the dataset's plot_intensity was called with correct protein ID
+ llm_integration._dataset.plot_intensity.assert_called_once()
+ call_args = llm_integration._dataset.plot_intensity.call_args[1]
+ assert call_args["protein_id"] == expected_protein_id
+ assert result == "Plot created"
+
+
+def test_execute_dataset_function(llm_integration):
+ """Test execution of a function from the dataset"""
+ result = llm_integration._execute_function("custom_function", {"param1": "value1"})
+
+ assert result == "Dataset function called"
+ llm_integration._dataset.custom_function.assert_called_once_with(param1="value1")
+
+
+def test_execute_dataset_function_with_dots(llm_integration):
+ """Test execution of a dataset function when name contains dots"""
+ result = llm_integration._execute_function(
+ "dataset.custom_function", {"param1": "value1"}
+ )
+
+ assert result == "Dataset function called"
+ llm_integration._dataset.custom_function.assert_called_once_with(param1="value1")
+
+
+@pytest.mark.skip(reason="TODO: fix this test")
+def test_execute_nonexistent_function(llm_integration):
+ """Test execution of a non-existent function"""
+
+ result = llm_integration._execute_function(
+ "nonexistent_function", {"param1": "value1"}
+ )
+
+ assert "Error executing nonexistent_function" in result
+ assert "not implemented or dataset not available" in result
+
+
+def test_execute_function_with_error(llm_integration, mock_general_function_mapping):
+ """Test handling of function execution errors"""
+
+ def failing_function(**kwargs):
+ raise ValueError("Test error")
+
+ with patch(
+ "alphastats.llm.llm_integration.GENERAL_FUNCTION_MAPPING",
+ {"failing_function": failing_function},
+ ), pytest.raises(ValueError, match="Test error"):
+ llm_integration._execute_function("failing_function", {"param1": "value1"})
+
+
+def test_execute_function_without_dataset(mock_openai_client):
+ """Test function execution when dataset is not available"""
+ llm = LLMIntegration(model_name=Models.GPT4O, api_key="test-key")
+
+ with pytest.raises(
+ ValueError,
+ match="Function dataset_function not implemented or dataset not available",
+ ):
+ llm._execute_function("dataset_function", {"param1": "value1"})
+
+
+@patch("alphastats.llm.llm_integration.LLMIntegration._execute_function")
+def test_handle_function_calls(
+ mock_execute_function, mock_openai_client, mock_chat_completion
+):
+ """Test handling of function calls in the chat completion response."""
+ mock_execute_function.return_value = "some_function_result"
+
+ llm_integration = LLMIntegration(
+ model_name=Models.GPT4O,
+ api_key="test-key", # pragma: allowlist secret
+ system_message="Test system message",
+ )
+
+ tool_calls = [
+ ChatCompletionMessageToolCall(
+ id="test-id",
+ type="function",
+ function={"name": "test_function", "arguments": '{"arg1": "value1"}'},
+ )
+ ]
+
+ mock_openai_client.return_value.chat.completions.create.return_value = (
+ mock_chat_completion
+ )
+ result = llm_integration._handle_function_calls(tool_calls)
+
+ assert result == ("Test response", None)
+
+ mock_execute_function.assert_called_once_with("test_function", {"arg1": "value1"})
+
+ expected_messages = [
+ {"role": "system", "content": "Test system message"},
+ {
+ "role": "assistant",
+ "content": 'Calling function: test_function with arguments: {"arg1": "value1"}',
+ "tool_calls": [
+ ChatCompletionMessageToolCall(
+ id="test-id",
+ function=Function(
+ arguments='{"arg1": "value1"}', name="test_function"
+ ),
+ type="function",
+ )
+ ],
+ },
+ {
+ "role": "tool",
+ "content": '{"result": "some_function_result", "artifact_id": "test_function_test-id"}',
+ "tool_call_id": "test-id",
+ },
+ ]
+ mock_openai_client.return_value.chat.completions.create.assert_called_once_with(
+ model="gpt-4o", messages=expected_messages, tools=llm_integration._tools
+ )
+
+ assert list(llm_integration._artifacts[3]) == ["some_function_result"]
+
+ assert llm_integration._messages == expected_messages
+
+
+def test_get_print_view_default(llm_with_conversation):
+ """Test get_print_view with default settings (show_all=False)"""
+ print_view = llm_with_conversation.get_print_view()
+
+ # Should only include user and assistant messages without tool_calls
+ assert print_view == [
+ {"artifacts": [], "content": "User message 1", "role": "user"},
+ {
+ "artifacts": ["Artifact for message 2"],
+ "content": "Assistant message 1",
+ "role": "assistant",
+ },
+ {"artifacts": [], "content": "User message 2", "role": "user"},
+ {
+ "artifacts": ["Artifact for message 6"],
+ "content": "Assistant message 2",
+ "role": "assistant",
+ },
+ ]
+
+
+def test_get_print_view_show_all(llm_with_conversation):
+ """Test get_print_view with default settings (show_all=True)"""
+ print_view = llm_with_conversation.get_print_view(show_all=True)
+
+    # Should include all message types, including system and tool messages
+ assert print_view == [
+ {"artifacts": [], "content": "System message", "role": "system"},
+ {"artifacts": [], "content": "User message 1", "role": "user"},
+ {
+ "artifacts": ["Artifact for message 2"],
+ "content": "Assistant message 1",
+ "role": "assistant",
+ },
+ {"artifacts": [], "content": "Assistant with tool calls", "role": "assistant"},
+ {
+ "artifacts": ["Tool artifact 1", "Tool artifact 2"],
+ "content": "Tool response",
+ "role": "tool",
+ },
+ {"artifacts": [], "content": "User message 2", "role": "user"},
+ {
+ "artifacts": ["Artifact for message 6"],
+ "content": "Assistant message 2",
+ "role": "assistant",
+ },
+ ]
diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py
new file mode 100644
index 00000000..1dc3ed43
--- /dev/null
+++ b/tests/llm/test_llm_utils.py
@@ -0,0 +1,100 @@
+import pandas as pd
+import pytest
+
+from alphastats.llm.llm_utils import (
+ get_protein_id_for_gene_name,
+ get_subgroups_for_each_group,
+)
+
+
+def test_get_subgroups_for_each_group_basic():
+ """Test basic functionality with simple metadata."""
+ # Create test metadata
+ data = {
+ "disease": ["cancer", "healthy", "cancer"],
+ "treatment": ["drug_a", "placebo", "drug_b"],
+ }
+ metadata = pd.DataFrame(data)
+
+ result = get_subgroups_for_each_group(metadata)
+
+ expected = {
+ "disease": ["cancer", "healthy"],
+ "treatment": ["drug_a", "placebo", "drug_b"],
+ }
+ assert result == expected
+
+
+def test_get_subgroups_for_each_group_empty():
+ """Test with empty DataFrame."""
+ metadata = pd.DataFrame()
+ result = get_subgroups_for_each_group(metadata)
+ assert result == {}
+
+
+def test_get_subgroups_for_each_group_single_column():
+ """Test with single column DataFrame."""
+ data = {"condition": ["A", "B", "A"]}
+ metadata = pd.DataFrame(data)
+
+ result = get_subgroups_for_each_group(metadata)
+
+ expected = {"condition": ["A", "B"]}
+ assert result == expected
+
+
+def test_get_subgroups_for_each_group_numeric_values():
+ """Test with numeric values in DataFrame."""
+ data = {"age_group": [20, 30, 20], "score": [1.5, 2.5, 1.5]}
+ metadata = pd.DataFrame(data)
+
+ result = get_subgroups_for_each_group(metadata)
+
+ expected = {"age_group": ["20", "30"], "score": ["1.5", "2.5"]}
+ assert result == expected
+
+
+@pytest.fixture
+def gene_to_prot_map():
+ """Fixture for protein mapping dictionary."""
+ return {
+ "VCL": "P18206",
+ "VCL;HEL114": "P18206;A0A024QZN4",
+ "MULTI;GENE": "PROT1;PROT2;PROT3",
+ }
+
+
+def test_get_protein_id_direct_match(gene_to_prot_map):
+ """Test when gene name directly matches a key."""
+ result = get_protein_id_for_gene_name("VCL", gene_to_prot_map)
+ assert result == "P18206"
+
+
+def test_get_protein_id_compound_key(gene_to_prot_map):
+ """Test when gene name is part of a compound key."""
+ result = get_protein_id_for_gene_name("HEL114", gene_to_prot_map)
+ assert result == "P18206;A0A024QZN4"
+
+
+def test_get_protein_id_not_found(gene_to_prot_map):
+ """Test when gene name is not found in mapping."""
+ result = get_protein_id_for_gene_name("UNKNOWN", gene_to_prot_map)
+ assert result == "UNKNOWN"
+
+
+def test_get_protein_id_empty_map():
+ """Test with empty mapping dictionary."""
+ result = get_protein_id_for_gene_name("VCL", {})
+ assert result == "VCL"
+
+
+def test_get_protein_id_multiple_matches(gene_to_prot_map):
+ """Test with a gene that appears in multiple compound keys."""
+ result = get_protein_id_for_gene_name("MULTI", gene_to_prot_map)
+ assert result == "PROT1;PROT2;PROT3"
+
+
+def test_get_protein_id_case_sensitivity(gene_to_prot_map):
+ """Test case sensitivity of gene name matching."""
+ result = get_protein_id_for_gene_name("vcl", gene_to_prot_map)
+ assert result == "vcl" # Should not match 'VCL' due to case sensitivity