diff --git a/alphastats/DataSet.py b/alphastats/DataSet.py index b845a72e..32a0d6ae 100644 --- a/alphastats/DataSet.py +++ b/alphastats/DataSet.py @@ -311,6 +311,27 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False): ) return dimensionality_reduction.plot + def perform_dimensionality_reduction( + self, method: str, group: Optional[str] = None, circle: bool = False + ): + """Generic wrapper for dimensionality reduction methods to be used by LLM. + + Args: + method (str): "pca", "tsne", "umap" + group (str, optional): column in metadata that should be used for coloring. Defaults to None. + circle (bool, optional): draw circle around each group. Defaults to False. + """ + + result = { + "pca": self.plot_pca, + "tsne": self.plot_tsne, + "umap": self.plot_umap, + }.get(method) + if result is None: + raise ValueError(f"Invalid method: {method}") + + return result(group=group, circle=circle) + @ignore_warning(RuntimeWarning) def plot_volcano( self, diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index a6f70d0c..8188d28f 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -89,7 +89,7 @@ def show_start_llm_button(method: str) -> None: msg = ( "(this will overwrite the existing LLM analysis!)" - if StateKeys.LLM_INTEGRATION in st.session_state + if st.session_state.get(StateKeys.LLM_INTEGRATION, {}) != {} else "" ) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 6dcf2919..0e337a2e 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -8,9 +8,9 @@ display_figure, ) from alphastats.gui.utils.llm_helper import ( - display_proteins, + get_display_proteins_html, + llm_connection_test, set_api_key, - test_llm_connection, ) from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info from alphastats.llm.llm_integration import LLMIntegration, Models @@ -36,20 +36,24 @@ def llm_config(): """Show the configuration options for the LLM analysis.""" c1, _ = st.columns((1, 2)) with c1: - model_before = st.session_state.get(StateKeys.API_TYPE, None) + current_model = st.session_state.get(StateKeys.MODEL_NAME, None) - st.session_state[StateKeys.API_TYPE] = st.selectbox( + models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B] + st.session_state[StateKeys.MODEL_NAME] = st.selectbox( "Select LLM", - [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B], + models, + index=models.index(st.session_state.get(StateKeys.MODEL_NAME)) + if current_model is not None + else 0, ) base_url = None - if st.session_state[StateKeys.API_TYPE] in [Models.GPT4O]: + if st.session_state[StateKeys.MODEL_NAME] in [Models.GPT4O]: api_key = st.text_input( "Enter OpenAI API Key and press Enter", type="password" ) set_api_key(api_key) - elif st.session_state[StateKeys.API_TYPE] in [ + elif st.session_state[StateKeys.MODEL_NAME] in [ Models.OLLAMA_31_70B, Models.OLLAMA_31_8B, ]: @@ -58,13 +62,18 @@ def llm_config(): test_connection = st.button("Test connection") if test_connection: - test_llm_connection( - api_type=st.session_state[StateKeys.API_TYPE], - api_key=st.session_state[StateKeys.OPENAI_API_KEY], - base_url=base_url, - ) - - if model_before != st.session_state[StateKeys.API_TYPE]: + with st.spinner(f"Testing connection to {model_name}.."): + error = llm_connection_test( + model_name=st.session_state[StateKeys.MODEL_NAME], + api_key=st.session_state[StateKeys.OPENAI_API_KEY], + base_url=base_url, + ) + if error is None: + 
st.success(f"Connection to {model_name} successful!") + else: + st.error(f"Connection to {model_name} failed: {str(error)}") + + if current_model != st.session_state[StateKeys.MODEL_NAME]: st.rerun(scope="app") @@ -120,17 +129,24 @@ def llm_config(): c1, c2 = st.columns((1, 2), gap="medium") with c1: st.write("Upregulated genes") - display_proteins(upregulated_genes, []) + st.markdown( + get_display_proteins_html(upregulated_genes, True), unsafe_allow_html=True + ) + with c2: st.write("Downregulated genes") - display_proteins([], downregulated_genes) + st.markdown( + get_display_proteins_html(downregulated_genes, False), + unsafe_allow_html=True, + ) st.markdown("##### Prompts generated based on analysis input") -api_type = st.session_state[StateKeys.API_TYPE] +model_name = st.session_state[StateKeys.MODEL_NAME] llm_integration_set_for_model = ( - st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(api_type, None) is not None + st.session_state.get(StateKeys.LLM_INTEGRATION, {}).get(model_name, None) + is not None ) with st.expander("System message", expanded=False): system_message = st.text_area( @@ -151,19 +167,28 @@ def llm_config(): ) -st.markdown(f"##### LLM Analysis with {api_type}") +st.markdown(f"##### LLM Analysis with {model_name}") -llm_submitted = st.button( +c1, c2, _ = st.columns((0.2, 0.2, 0.6)) +llm_submitted = c1.button( "Run LLM analysis ...", disabled=llm_integration_set_for_model ) -if st.session_state[StateKeys.LLM_INTEGRATION].get(api_type) is None: +llm_reset = c2.button( + "Reset LLM analysis ...", disabled=not llm_integration_set_for_model +) +if llm_reset: + del st.session_state[StateKeys.LLM_INTEGRATION] + st.rerun() + + +if st.session_state[StateKeys.LLM_INTEGRATION].get(model_name) is None: if not llm_submitted: st.stop() try: llm_integration = LLMIntegration( - api_type=api_type, + model_name=model_name, system_message=system_message, api_key=st.session_state[StateKeys.OPENAI_API_KEY], base_url=OLLAMA_BASE_URL, @@ -171,10 +196,10 @@ def llm_config(): gene_to_prot_id_map=gene_to_prot_id_map, ) - st.session_state[StateKeys.LLM_INTEGRATION][api_type] = llm_integration + st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration st.success( - f"{st.session_state[StateKeys.API_TYPE]} integration initialized successfully!" + f"{st.session_state[StateKeys.MODEL_NAME]} integration initialized successfully!" ) with st.spinner("Processing initial prompt..."): @@ -221,4 +246,4 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False): help="Show all messages in the chat interface.", ) -llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][api_type], show_all) +llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all) diff --git a/alphastats/gui/utils/llm_helper.py b/alphastats/gui/utils/llm_helper.py index 33474a16..028abe20 100644 --- a/alphastats/gui/utils/llm_helper.py +++ b/alphastats/gui/utils/llm_helper.py @@ -7,32 +7,24 @@ from alphastats.llm.llm_integration import LLMIntegration -def display_proteins(overexpressed: List[str], underexpressed: List[str]) -> None: +def get_display_proteins_html(protein_ids: List[str], is_upregulated: True) -> str: """ - Display a list of overexpressed and underexpressed proteins in a Streamlit app. + Get HTML code for displaying a list of proteins, color according to expression. Args: - overexpressed (list[str]): A list of overexpressed proteins. - underexpressed (list[str]): A list of underexpressed proteins. + protein_ids (list[str]): a list of proteins. 
+ is_upregulated (bool): whether the proteins are up- or down-regulated. """ - # Start with the overexpressed proteins - link = "https://www.uniprot.org/uniprotkb?query=" - overexpressed_html = "".join( - f'
<a href = {link + protein}><li style="color: green;">{protein}</li></a>' - for protein in overexpressed - ) - # Continue with the underexpressed proteins - underexpressed_html = "".join( - f'<a href = {link + protein}><li style="color: red;">{protein}</li></a>' - for protein in underexpressed - )
+ uniprot_url = "https://www.uniprot.org/uniprotkb?query=" - # Combine both lists into one HTML string - full_html = f"<ul>{overexpressed_html}{underexpressed_html}</ul>" + color = "green" if is_upregulated else "red" + protein_ids_html = "".join( + f'<a href = {uniprot_url + protein}><li style="color: {color};">{protein}</li></a>
  • ' + for protein in protein_ids + ) - # Display in Streamlit - st.markdown(full_html, unsafe_allow_html=True) + return f"" def set_api_key(api_key: str = None) -> None: @@ -71,21 +63,20 @@ def set_api_key(api_key: str = None) -> None: st.session_state[StateKeys.OPENAI_API_KEY] = api_key -def test_llm_connection( - api_type: str, +def llm_connection_test( + model_name: str, base_url: Optional[str] = None, api_key: Optional[str] = None, -): - """Test the connection to the LLM API.""" +) -> Optional[str]: + """Test the connection to the LLM API, return None in case of success, error message otherwise.""" try: - llm = LLMIntegration(api_type, base_url=base_url, api_key=api_key) - - with st.spinner(f"Testing connection to {api_type} ..."): - llm.chat_completion("Hello, this is a test!") - - st.success(f"Connection to {api_type} successful!") - return True + llm = LLMIntegration( + model_name, base_url=base_url, api_key=api_key, load_tools=False + ) + llm.chat_completion( + "This is a test. Simply respond 'yes' if you got this message." + ) + return None except Exception as e: - st.error(f"❌ Connection to {api_type} failed: {e}") - return False + return str(e) diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index 38e0027b..0b3170a5 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -108,7 +108,7 @@ class StateKeys: # LLM OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret - API_TYPE = "api_type" + MODEL_NAME = "model_name" LLM_INPUT = "llm_input" diff --git a/alphastats/llm/llm_functions.py b/alphastats/llm/llm_functions.py index f6738438..90cebb6b 100644 --- a/alphastats/llm/llm_functions.py +++ b/alphastats/llm/llm_functions.py @@ -4,7 +4,14 @@ import pandas as pd -from alphastats.plots.DimensionalityReduction import DimensionalityReduction +from alphastats.DataSet import DataSet +from alphastats.llm.enrichment_analysis import get_enrichment_data +from alphastats.llm.uniprot_utils import get_gene_function + +GENERAL_FUNCTION_MAPPING = { + "get_gene_function": get_gene_function, + "get_enrichment_data": get_enrichment_data, +} def get_general_assistant_functions() -> List[Dict]: @@ -17,7 +24,7 @@ def get_general_assistant_functions() -> List[Dict]: { "type": "function", "function": { - "name": "get_gene_function", + "name": get_gene_function.__name__, "description": "Get the gene function and description by UniProt lookup of gene identifier/name", "parameters": { "type": "object", @@ -34,7 +41,7 @@ def get_general_assistant_functions() -> List[Dict]: { "type": "function", "function": { - "name": "get_enrichment_data", + "name": get_enrichment_data.__name__, "description": "Get enrichment data for a list of differentially expressed genes", "parameters": { "type": "object", @@ -83,12 +90,12 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "plot_intensity", + "name": DataSet.plot_intensity.__name__, "description": "Create an intensity plot based on protein data and analytical methods.", "parameters": { "type": "object", "properties": { - "gene_name": { # this will be mapped to "protein_id" when calling the function + "protein_id": { # LLM will provide gene_name, mapping to protein_id is done when calling the function "type": "string", "enum": gene_names, "description": "Identifier for the gene of interest", @@ -125,7 +132,7 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "perform_dimensionality_reduction", + "name": 
DataSet.perform_dimensionality_reduction.__name__, "description": "Perform dimensionality reduction on a given dataset and generate a plot.", "parameters": { "type": "object", @@ -152,7 +159,7 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "plot_sampledistribution", + "name": DataSet.plot_sampledistribution.__name__, "description": "Generates a histogram plot for each sample in the dataset matrix.", "parameters": { "type": "object", @@ -175,7 +182,7 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "plot_volcano", + "name": DataSet.plot_volcano.__name__, "description": "Generates a volcano plot based on two subgroups of the same group", "parameters": { "type": "object", @@ -235,17 +242,3 @@ def get_assistant_functions( }, # {"type": "code_interpreter"}, ] - - -def perform_dimensionality_reduction(dataset, group, method, circle, **kwargs): - dr = DimensionalityReduction( - mat=dataset.mat, - metadate=dataset.metadata, - sample=dataset.sample, - preprocessing_info=dataset.preprocessing_info, - group=group, - circle=circle, - method=method, - **kwargs, - ) - return dr.plot diff --git a/alphastats/llm/llm_integration.py b/alphastats/llm/llm_integration.py index 60392cb2..c866e7f3 100644 --- a/alphastats/llm/llm_integration.py +++ b/alphastats/llm/llm_integration.py @@ -11,23 +11,26 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from alphastats.DataSet import DataSet -from alphastats.llm.enrichment_analysis import get_enrichment_data from alphastats.llm.llm_functions import ( + GENERAL_FUNCTION_MAPPING, get_assistant_functions, get_general_assistant_functions, - perform_dimensionality_reduction, ) from alphastats.llm.llm_utils import ( get_protein_id_for_gene_name, get_subgroups_for_each_group, ) from alphastats.llm.prompts import get_tool_call_message -from alphastats.llm.uniprot_utils import get_gene_function logger = logging.getLogger(__name__) class Models: + """Names of the available models. + + Note that this will be directly passed to the OpenAI client. + """ + GPT4O = "gpt-4o" OLLAMA_31_8B = "llama3.1:8b" OLLAMA_31_70B = "llama3.1:70b" # for testing only @@ -41,7 +44,7 @@ class LLMIntegration: Parameters ---------- - api_type : str + model_name : str The type of API to use, will be forwarded to the client. system_message : str The system message that should be given to the model. @@ -57,34 +60,34 @@ class LLMIntegration: def __init__( self, - api_type: str, + model_name: str, *, base_url: Optional[str] = None, api_key: Optional[str] = None, system_message: str = None, + load_tools: bool = True, dataset: Optional[DataSet] = None, gene_to_prot_id_map: Optional[Dict[str, str]] = None, ): - self._model = api_type + self._model = model_name - if api_type in [Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]: + if model_name in [Models.OLLAMA_31_70B, Models.OLLAMA_31_8B]: url = f"{base_url}/v1" self._client = OpenAI(base_url=url, api_key="ollama") - elif api_type in [Models.GPT4O]: + elif model_name in [Models.GPT4O]: self._client = OpenAI(api_key=api_key) else: - raise ValueError(f"Invalid API type: {api_type}") + raise ValueError(f"Invalid model name: {model_name}") self._dataset = dataset self._metadata = None if dataset is None else dataset.metadata self._gene_to_prot_id_map = gene_to_prot_id_map - self._tools = self._get_tools() + self._tools = self._get_tools() if load_tools else None + self._artifacts = {} self._messages = [] # the conversation history used for the LLM, could be truncated at some point. 
self._all_messages = [] # full conversation history for display - self._artifacts = {} - if system_message is not None: self._append_message("system", system_message) @@ -190,47 +193,36 @@ def _execute_function( Any The result of the function execution """ - try: - # first try to find the function in the non-Dataset functions - if ( - function := { - "get_gene_function": get_gene_function, - "get_enrichment_data": get_enrichment_data, - }.get(function_name) - ) is not None: - return function(**function_args) - - # special treatment for these functions - elif function_name == "perform_dimensionality_reduction": - # TODO add API in dataset - perform_dimensionality_reduction(self._dataset, **function_args) - - elif function_name == "plot_intensity": - # TODO move this logic to dataset - gene_name = function_args.pop("gene_name") - protein_id = get_protein_id_for_gene_name( - gene_name, self._gene_to_prot_id_map - ) - function_args["protein_id"] = protein_id + # first try to find the function in the non-Dataset functions + if (function := GENERAL_FUNCTION_MAPPING.get(function_name)) is not None: + return function(**function_args) + + # special treatment for this function + elif function_name == "plot_intensity": + # TODO move this logic to dataset + gene_name = function_args.pop( + "protein_id" + ) # no typo, the LLM sets "protein_id" to gene_name + protein_id = get_protein_id_for_gene_name( + gene_name, self._gene_to_prot_id_map + ) + function_args["protein_id"] = protein_id - return self._dataset.plot_intensity(**function_args) + return self._dataset.plot_intensity(**function_args) - # fallback: try to find the function in the Dataset functions - else: - function = getattr( - self._dataset, - function_name.split(".")[-1], - None, # TODO why split? - ) - if function: - return function(**function_args) - - raise ValueError( - f"Function {function_name} not implemented or dataset not available" + # look up the function in the DataSet class + else: + function = getattr( + self._dataset, + function_name.split(".")[-1], + None, # TODO why split? ) + if function: + return function(**function_args) - except Exception as e: - return f"Error executing {function_name}: {str(e)}" + raise ValueError( + f"Function {function_name} not implemented or dataset not available" + ) def _handle_function_calls( self, @@ -260,7 +252,11 @@ def _handle_function_calls( function_name = tool_call.function.name function_args = json.loads(tool_call.function.arguments) - function_result = self._execute_function(function_name, function_args) + try: + function_result = self._execute_function(function_name, function_args) + except Exception as e: + function_result = f"Error executing {function_name}: {str(e)}" + artifact_id = f"{function_name}_{tool_call.id}" new_artifacts[artifact_id] = function_result @@ -272,15 +268,20 @@ def _handle_function_calls( post_artifact_message_idx = len(self._all_messages) self._artifacts[post_artifact_message_idx] = new_artifacts.values() - logger.info(f"Calling 'chat.completions.create' {self._messages[-1]=} ..") - response = self._client.chat.completions.create( + response = self._chat_completion_create() + + return self._parse_model_response(response) + + def _chat_completion_create(self) -> ChatCompletion: + """Create a chat completion based on the current conversation history.""" + logger.info(f"Calling 'chat.completions.create' {self._messages[-1]} ..") + result = self._client.chat.completions.create( model=self._model, messages=self._messages, tools=self._tools, ) logger.info(".. 
done") - - return self._parse_model_response(response) + return result def get_print_view(self, show_all=False) -> List[Dict[str, Any]]: """Get a structured view of the conversation history for display purposes.""" @@ -301,9 +302,7 @@ def get_print_view(self, show_all=False) -> List[Dict[str, Any]]: ) return print_view - def chat_completion( - self, prompt: str, role: str = "user" - ) -> Tuple[str, Dict[str, Any]]: + def chat_completion(self, prompt: str, role: str = "user") -> None: """ Generate a chat completion based on the given prompt and manage any resulting artifacts. @@ -322,13 +321,7 @@ def chat_completion( self._append_message(role, prompt) try: - logger.info(f"Calling 'chat.completions.create' {self._messages[-1]} ..") - response = self._client.chat.completions.create( - model=self._model, - messages=self._messages, - tools=self._tools, - ) - logger.info(".. done") + response = self._chat_completion_create() content, tool_calls = self._parse_model_response(response) @@ -345,7 +338,6 @@ def chat_completion( except ArithmeticError as e: error_message = f"Error in chat completion: {str(e)}" self._append_message("system", error_message) - return error_message, {} # TODO this seems to be for notebooks? # we need some "export mode" where everything is shown diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/llm/test_llm_functions.py b/tests/llm/test_llm_functions.py new file mode 100644 index 00000000..c0424fa8 --- /dev/null +++ b/tests/llm/test_llm_functions.py @@ -0,0 +1,95 @@ +"""Test that the function definitions in the LLM match the actual functions: + +- all parameters defined for the LLM are present in the function definitions +- all non-default parameters in the function definitions are required for the LLM +""" + +import inspect +from typing import Callable, Dict + +import pandas as pd + +from alphastats.DataSet import DataSet +from alphastats.llm.llm_functions import ( + GENERAL_FUNCTION_MAPPING, + get_assistant_functions, + get_general_assistant_functions, +) + + +def _get_method_parameters(method: Callable) -> Dict: + """Get the parameters of a method as a dictionary, excluding "self".""" + signature = inspect.signature(method) + + params = dict(signature.parameters) + if "self" in params: + del params["self"] + + return params + + +def _get_class_methods(cls): + """Get all methods of a class as a dictionary, excluding "self".""" + return { + member[0]: member[1] + for member in inspect.getmembers(cls, predicate=inspect.isfunction) + if member[0] != "self" + } + + +def assert_parameters(method_definition: Callable, llm_function_dict_): + """Assert that the parameters of a method match the parameters defined in the dict used by the LLM.""" + + # suffix '_' denotes LLM-related variables + parameters = _get_method_parameters(method_definition) + parameters_without_default = [ + param + for param in parameters + if parameters[param].default == inspect.Parameter.empty + ] + + parameters_dict_ = llm_function_dict_["function"]["parameters"] + parameters_ = parameters_dict_["properties"].keys() + + # are all in parameters_ available in the function? + assert set(parameters_).issubset(set(parameters)) + + # are all the parameters w/o default in the function filled in parameters_? + assert set(parameters_without_default).issubset(set(parameters_)) + + # are all required parameters marked as 'required'? 
+ assert set(parameters_without_default).issubset(set(parameters_dict_["required"])) + + +def test_general_assistant_functions(): + """Test that the general assistant functions in the LLM match the actual functions.""" + # suffix '_' denotes LLM-related variables + assistant_functions_dict = get_general_assistant_functions() + + for llm_function_dict_ in assistant_functions_dict: + name_ = llm_function_dict_["function"]["name"] + + method_definition = GENERAL_FUNCTION_MAPPING.get(name_, None) + + if method_definition is None: + raise ValueError(f"Function not found in test: {name_}") + + assert_parameters(method_definition, llm_function_dict_) + + +def test_assistant_functions(): + """Test that the assistant functions in the LLM match the actual functions.""" + # suffix '_' denotes LLM-related variables + assistant_functions_dict = get_assistant_functions({}, pd.DataFrame(), {}) + + all_dataset_methods = _get_class_methods(DataSet) + + for llm_function_dict_ in assistant_functions_dict: + name_ = llm_function_dict_["function"]["name"] + + method_definition = all_dataset_methods.get(name_, None) + + if method_definition is None: + raise ValueError(f"Function not found in test: {name_}") + + assert_parameters(method_definition, llm_function_dict_) diff --git a/tests/llm/test_llm_helper.py b/tests/llm/test_llm_helper.py new file mode 100644 index 00000000..5f09534e --- /dev/null +++ b/tests/llm/test_llm_helper.py @@ -0,0 +1,136 @@ +from unittest.mock import patch + +import pytest + +from alphastats.gui.utils.llm_helper import ( + get_display_proteins_html, + llm_connection_test, + set_api_key, +) +from alphastats.gui.utils.ui_helper import StateKeys + + +@pytest.fixture +def mock_streamlit(): + """Fixture to mock streamlit module.""" + with patch("streamlit.info") as mock_info, patch( + "streamlit.error" + ) as mock_error, patch("streamlit.success") as mock_success, patch( + "streamlit.session_state", {} + ) as mock_session_state: + yield { + "info": mock_info, + "error": mock_error, + "success": mock_success, + "session_state": mock_session_state, + } + + +def test_display_proteins_upregulated(mock_streamlit): + """Test displaying upregulated proteins.""" + protein_ids = ["P12345", "Q67890"] + result = get_display_proteins_html(protein_ids, is_upregulated=True) + + expected_html = ( + "' + ) + + assert result == expected_html + + +def test_display_proteins_downregulated(mock_streamlit): + """Test displaying downregulated proteins.""" + protein_ids = ["P12345"] + result = get_display_proteins_html(protein_ids, is_upregulated=False) + + expected_html = ( + "' + ) + + assert result == expected_html + + +def test_display_proteins_empty_list(mock_streamlit): + """Test displaying empty protein list.""" + assert get_display_proteins_html([], is_upregulated=True) == "" + + +@pytest.mark.parametrize( + "api_key,expected_message", + [ + ("abc123xyz", "OpenAI API key set: abc***xyz"), + ( + None, + "Please enter an OpenAI key or provide it in a secrets.toml file in the alphastats/gui/.streamlit directory like `openai_api_key = `", + ), + ], +) +def test_set_api_key_direct(mock_streamlit, api_key, expected_message): + """Test setting API key directly.""" + set_api_key(api_key) + + if api_key: + mock_streamlit["info"].assert_called_once_with(expected_message) + assert mock_streamlit["session_state"][StateKeys.OPENAI_API_KEY] == api_key + else: + mock_streamlit["info"].assert_called_with(expected_message) + + +@patch("streamlit.secrets") +@patch("pathlib.Path.exists") +def 
test_set_api_key_from_secrets(mock_exists, mock_st_secrets, mock_streamlit): + """Test loading API key from secrets.toml.""" + mock_exists.return_value = True + + mock_st_secrets.__getitem__.return_value = ( + "test_secret_key" # pragma: allowlist secret + ) + + set_api_key() + + mock_streamlit["info"].assert_called_with( + "OpenAI API key loaded from secrets.toml." + ) + assert ( + mock_streamlit["session_state"][StateKeys.OPENAI_API_KEY] + == "test_secret_key" # pragma: allowlist secret + ) + + +@patch("pathlib.Path.exists") +def test_set_api_key_missing_secrets(mock_exists, mock_streamlit): + """Test handling missing secrets.toml.""" + mock_exists.return_value = False + + set_api_key() + + mock_streamlit["info"].assert_called_with( + "Please enter an OpenAI key or provide it in a secrets.toml file in the " + "alphastats/gui/.streamlit directory like `openai_api_key = `" + ) + + +@patch("alphastats.gui.utils.llm_helper.LLMIntegration") +def test_llm_connection_test_success(mock_llm): + """Test successful LLM connection.""" + assert llm_connection_test("some_model") is None + + mock_llm.assert_called_once_with( + "some_model", base_url=None, api_key=None, load_tools=False + ) + + +@patch("alphastats.gui.utils.llm_helper.LLMIntegration") +def test_llm_connection_test_failure(mock_llm, mock_streamlit): + """Test failed LLM connection.""" + mock_llm.return_value.chat_completion.side_effect = ValueError("API Error") + + assert llm_connection_test("some_model") == "API Error" + + mock_llm.assert_called_once_with( + "some_model", base_url=None, api_key=None, load_tools=False + ) diff --git a/tests/llm/test_llm_integration.py b/tests/llm/test_llm_integration.py new file mode 100644 index 00000000..9074bf3c --- /dev/null +++ b/tests/llm/test_llm_integration.py @@ -0,0 +1,471 @@ +from unittest import skip +from unittest.mock import Mock, patch + +import pandas as pd +import pytest +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion_message_tool_call import Function + +from alphastats.llm.llm_integration import LLMIntegration, Models + + +@pytest.fixture +def mock_openai_client(): + with patch("alphastats.llm.llm_integration.OpenAI") as mock_client: + yield mock_client + + +@pytest.fixture +def llm_integration(mock_openai_client): + """Fixture providing a basic LLM instance with test configuration""" + dataset = Mock() + dataset.plot_intensity = Mock(return_value="Plot created") + dataset.custom_function = Mock(return_value="Dataset function called") + dataset.metadata = pd.DataFrame({"group1": ["A", "B"], "group2": ["C", "D"]}) + return LLMIntegration( + model_name=Models.GPT4O, + api_key="test-key", # pragma: allowlist secret + system_message="Test system message", + dataset=dataset, + gene_to_prot_id_map={"GENE1": "PROT1", "GENE2": "PROT2"}, + ) + + +@pytest.fixture +def llm_with_conversation(llm_integration): + """Setup LLM with a sample conversation history""" + # Add various message types to conversation history + llm_integration._all_messages = [ + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message 1"}, + {"role": "assistant", "content": "Assistant message 1"}, + { + "role": "assistant", + "content": "Assistant with tool calls", + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "test"}} + ], + }, + {"role": "tool", "content": "Tool response"}, + {"role": "user", "content": "User message 2"}, + {"role": "assistant", "content": 
"Assistant message 2"}, + ] + + # Add some artifacts + llm_integration._artifacts = { + 2: ["Artifact for message 2"], + 4: ["Tool artifact 1", "Tool artifact 2"], + 6: ["Artifact for message 6"], + } + + return llm_integration + + +@pytest.fixture +def mock_chat_completion(): + """Fixture providing a mock successful chat completion""" + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", content="Test response", tool_calls=None + ) + ) + ] + return mock_response + + +@pytest.fixture +def mock_tool_call_completion(): + """Fixture providing a mock completion with tool calls""" + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content="", + tool_calls=[ + ChatCompletionMessageToolCall( + id="test-id", + type="function", + function={"name": "test_function", "arguments": "{}"}, + ) + ], + ) + ) + ] + return mock_response + + +@pytest.fixture +def mock_general_function_mapping(): + def mock_general_function(param1, param2): + """An example for a function.""" + return f"General function called with {param1} and {param2}" + + return {"test_general_function": mock_general_function} + + +def test_initialization_gpt4(mock_openai_client): + """Test initialization with GPT-4 configuration""" + LLMIntegration( + model_name=Models.GPT4O, + api_key="test-key", # pragma: allowlist secret + ) + + mock_openai_client.assert_called_once_with( + api_key="test-key" # pragma: allowlist secret + ) + + +def test_initialization_ollama(mock_openai_client): + """Test initialization with Ollama configuration""" + LLMIntegration( + model_name=Models.OLLAMA_31_8B, + base_url="http://localhost:11434", + ) + + mock_openai_client.assert_called_once_with( + base_url="http://localhost:11434/v1", + api_key="ollama", # pragma: allowlist secret + ) + + +def test_initialization_invalid_model(): + """Test initialization with invalid model type""" + with pytest.raises(ValueError, match="Invalid model name"): + LLMIntegration(model_name="invalid-model") + + +def test_append_message(llm_integration): + """Test message appending functionality""" + llm_integration._append_message("user", "Test message") + + assert len(llm_integration._messages) == 2 # Including system message + assert len(llm_integration._all_messages) == 2 + assert llm_integration._messages[-1] == {"role": "user", "content": "Test message"} + + +def test_append_message_with_tool_calls(llm_integration): + """Test message appending with tool calls""" + tool_calls = [ + ChatCompletionMessageToolCall( + id="test-id", + type="function", + function={"name": "test_function", "arguments": "{}"}, + ) + ] + + llm_integration._append_message("assistant", "Test message", tool_calls=tool_calls) + + assert llm_integration._messages[-1]["tool_calls"] == tool_calls + + +@pytest.mark.parametrize( + "num_messages,message_length,max_tokens,expected_messages", + [ + (5, 100, 200, 2), # Should truncate to 2 messages + (3, 50, 1000, 3), # Should keep all messages + (10, 20, 100, 5), # Should truncate to 5 messages + ], +) +def test_truncate_conversation_history( + llm_integration, num_messages, message_length, max_tokens, expected_messages +): + """Test conversation history truncation with different scenarios""" + # Add multiple messages + message_content = "Test " * message_length + for _ in range(num_messages): + llm_integration._append_message("user", message_content) + + 
llm_integration._truncate_conversation_history(max_tokens=max_tokens) + + # Adding 1 to account for the initial system message + assert len(llm_integration._messages) <= expected_messages + 1 + + +def test_chat_completion_success(llm_integration, mock_chat_completion): + """Test successful chat completion""" + llm_integration._client.chat.completions.create.return_value = mock_chat_completion + + llm_integration.chat_completion("Test prompt") + + assert llm_integration._messages == [ + { + "content": "Test system message", + "role": "system", + }, + { + "content": "Test prompt", + "role": "user", + }, + { + "content": "Test response", + "role": "assistant", + }, + ] + + +def test_chat_completion_with_error(llm_integration): + """Test chat completion with error handling""" + llm_integration._client.chat.completions.create.side_effect = ArithmeticError( + "Test error" + ) + + llm_integration.chat_completion("Test prompt") + + assert ( + "Error in chat completion: Test error" + in llm_integration._messages[-1]["content"] + ) + + +def test_parse_model_response(llm_integration, mock_tool_call_completion): + """Test parsing model response with tool calls""" + content, tool_calls = llm_integration._parse_model_response( + mock_tool_call_completion + ) + + assert content == "" + assert len(tool_calls) == 1 + assert tool_calls[0].id == "test-id" + assert tool_calls[0].type == "function" + + +def test_chat_completion_with_content_and_tool_calls(llm_integration): + """Test that chat completion raises error when receiving both content and tool calls""" + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content="Some content", + tool_calls=[ + ChatCompletionMessageToolCall( + id="test-id", + type="function", + function={"name": "test_function", "arguments": "{}"}, + ) + ], + ) + ) + ] + llm_integration._client.chat.completions.create.return_value = mock_response + + with pytest.raises(ValueError, match="Unexpected content.*with tool calls"): + llm_integration.chat_completion("Test prompt") + + +@pytest.mark.parametrize( + "function_name,function_args,expected_result", + [ + ( + "test_general_function", + {"param1": "value1", "param2": "value2"}, + "General function called with value1 and value2", + ), + ], +) +def test_execute_general_function( + llm_integration, + mock_general_function_mapping, + function_name, + function_args, + expected_result, +): + """Test execution of functions from GENERAL_FUNCTION_MAPPING""" + with patch( + "alphastats.llm.llm_integration.GENERAL_FUNCTION_MAPPING", + mock_general_function_mapping, + ): + result = llm_integration._execute_function(function_name, function_args) + assert result == expected_result + + +@pytest.mark.parametrize( + "gene_name,plot_args,expected_protein_id", + [ + ("GENE1", {"param1": "value1"}, "PROT1"), + ("GENE2", {"param2": "value2"}, "PROT2"), + ], +) +def test_execute_plot_intensity( + llm_integration, gene_name, plot_args, expected_protein_id +): + """Test execution of plot_intensity with gene name translation""" + function_args = {"protein_id": gene_name, **plot_args} + + result = llm_integration._execute_function("plot_intensity", function_args) + + # Verify the dataset's plot_intensity was called with correct protein ID + llm_integration._dataset.plot_intensity.assert_called_once() + call_args = llm_integration._dataset.plot_intensity.call_args[1] + assert call_args["protein_id"] == expected_protein_id + assert result == "Plot created" + + +def 
test_execute_dataset_function(llm_integration): + """Test execution of a function from the dataset""" + result = llm_integration._execute_function("custom_function", {"param1": "value1"}) + + assert result == "Dataset function called" + llm_integration._dataset.custom_function.assert_called_once_with(param1="value1") + + +def test_execute_dataset_function_with_dots(llm_integration): + """Test execution of a dataset function when name contains dots""" + result = llm_integration._execute_function( + "dataset.custom_function", {"param1": "value1"} + ) + + assert result == "Dataset function called" + llm_integration._dataset.custom_function.assert_called_once_with(param1="value1") + + +@skip # TODO fix this test +def test_execute_nonexistent_function(llm_integration): + """Test execution of a non-existent function""" + + result = llm_integration._execute_function( + "nonexistent_function", {"param1": "value1"} + ) + + assert "Error executing nonexistent_function" in result + assert "not implemented or dataset not available" in result + + +def test_execute_function_with_error(llm_integration, mock_general_function_mapping): + """Test handling of function execution errors""" + + def failing_function(**kwargs): + raise ValueError("Test error") + + with patch( + "alphastats.llm.llm_integration.GENERAL_FUNCTION_MAPPING", + {"failing_function": failing_function}, + ), pytest.raises(ValueError, match="Test error"): + llm_integration._execute_function("failing_function", {"param1": "value1"}) + + +def test_execute_function_without_dataset(mock_openai_client): + """Test function execution when dataset is not available""" + llm = LLMIntegration(model_name=Models.GPT4O, api_key="test-key") + + with pytest.raises( + ValueError, + match="Function dataset_function not implemented or dataset not available", + ): + llm._execute_function("dataset_function", {"param1": "value1"}) + + +@patch("alphastats.llm.llm_integration.LLMIntegration._execute_function") +def test_handle_function_calls( + mock_execute_function, mock_openai_client, mock_chat_completion +): + """Test handling of function calls in the chat completion response.""" + mock_execute_function.return_value = "some_function_result" + + llm_integration = LLMIntegration( + model_name=Models.GPT4O, + api_key="test-key", # pragma: allowlist secret + system_message="Test system message", + ) + + tool_calls = [ + ChatCompletionMessageToolCall( + id="test-id", + type="function", + function={"name": "test_function", "arguments": '{"arg1": "value1"}'}, + ) + ] + + mock_openai_client.return_value.chat.completions.create.return_value = ( + mock_chat_completion + ) + result = llm_integration._handle_function_calls(tool_calls) + + assert result == ("Test response", None) + + mock_execute_function.assert_called_once_with("test_function", {"arg1": "value1"}) + + expected_messages = [ + {"role": "system", "content": "Test system message"}, + { + "role": "assistant", + "content": 'Calling function: test_function with arguments: {"arg1": "value1"}', + "tool_calls": [ + ChatCompletionMessageToolCall( + id="test-id", + function=Function( + arguments='{"arg1": "value1"}', name="test_function" + ), + type="function", + ) + ], + }, + { + "role": "tool", + "content": '{"result": "some_function_result", "artifact_id": "test_function_test-id"}', + "tool_call_id": "test-id", + }, + ] + mock_openai_client.return_value.chat.completions.create.assert_called_once_with( + model="gpt-4o", messages=expected_messages, tools=llm_integration._tools + ) + + assert 
list(llm_integration._artifacts[3]) == ["some_function_result"] + + assert llm_integration._messages == expected_messages + + +def test_get_print_view_default(llm_with_conversation): + """Test get_print_view with default settings (show_all=False)""" + print_view = llm_with_conversation.get_print_view() + + # Should only include user and assistant messages without tool_calls + assert print_view == [ + {"artifacts": [], "content": "User message 1", "role": "user"}, + { + "artifacts": ["Artifact for message 2"], + "content": "Assistant message 1", + "role": "assistant", + }, + {"artifacts": [], "content": "User message 2", "role": "user"}, + { + "artifacts": ["Artifact for message 6"], + "content": "Assistant message 2", + "role": "assistant", + }, + ] + + +def test_get_print_view_show_all(llm_with_conversation): + """Test get_print_view with default settings (show_all=True)""" + print_view = llm_with_conversation.get_print_view(show_all=True) + + # Should only include user and assistant messages without tool_calls + assert print_view == [ + {"artifacts": [], "content": "System message", "role": "system"}, + {"artifacts": [], "content": "User message 1", "role": "user"}, + { + "artifacts": ["Artifact for message 2"], + "content": "Assistant message 1", + "role": "assistant", + }, + {"artifacts": [], "content": "Assistant with tool calls", "role": "assistant"}, + { + "artifacts": ["Tool artifact 1", "Tool artifact 2"], + "content": "Tool response", + "role": "tool", + }, + {"artifacts": [], "content": "User message 2", "role": "user"}, + { + "artifacts": ["Artifact for message 6"], + "content": "Assistant message 2", + "role": "assistant", + }, + ] diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py new file mode 100644 index 00000000..1dc3ed43 --- /dev/null +++ b/tests/llm/test_llm_utils.py @@ -0,0 +1,100 @@ +import pandas as pd +import pytest + +from alphastats.llm.llm_utils import ( + get_protein_id_for_gene_name, + get_subgroups_for_each_group, +) + + +def test_get_subgroups_for_each_group_basic(): + """Test basic functionality with simple metadata.""" + # Create test metadata + data = { + "disease": ["cancer", "healthy", "cancer"], + "treatment": ["drug_a", "placebo", "drug_b"], + } + metadata = pd.DataFrame(data) + + result = get_subgroups_for_each_group(metadata) + + expected = { + "disease": ["cancer", "healthy"], + "treatment": ["drug_a", "placebo", "drug_b"], + } + assert result == expected + + +def test_get_subgroups_for_each_group_empty(): + """Test with empty DataFrame.""" + metadata = pd.DataFrame() + result = get_subgroups_for_each_group(metadata) + assert result == {} + + +def test_get_subgroups_for_each_group_single_column(): + """Test with single column DataFrame.""" + data = {"condition": ["A", "B", "A"]} + metadata = pd.DataFrame(data) + + result = get_subgroups_for_each_group(metadata) + + expected = {"condition": ["A", "B"]} + assert result == expected + + +def test_get_subgroups_for_each_group_numeric_values(): + """Test with numeric values in DataFrame.""" + data = {"age_group": [20, 30, 20], "score": [1.5, 2.5, 1.5]} + metadata = pd.DataFrame(data) + + result = get_subgroups_for_each_group(metadata) + + expected = {"age_group": ["20", "30"], "score": ["1.5", "2.5"]} + assert result == expected + + +@pytest.fixture +def gene_to_prot_map(): + """Fixture for protein mapping dictionary.""" + return { + "VCL": "P18206", + "VCL;HEL114": "P18206;A0A024QZN4", + "MULTI;GENE": "PROT1;PROT2;PROT3", + } + + +def 
test_get_protein_id_direct_match(gene_to_prot_map): + """Test when gene name directly matches a key.""" + result = get_protein_id_for_gene_name("VCL", gene_to_prot_map) + assert result == "P18206" + + +def test_get_protein_id_compound_key(gene_to_prot_map): + """Test when gene name is part of a compound key.""" + result = get_protein_id_for_gene_name("HEL114", gene_to_prot_map) + assert result == "P18206;A0A024QZN4" + + +def test_get_protein_id_not_found(gene_to_prot_map): + """Test when gene name is not found in mapping.""" + result = get_protein_id_for_gene_name("UNKNOWN", gene_to_prot_map) + assert result == "UNKNOWN" + + +def test_get_protein_id_empty_map(): + """Test with empty mapping dictionary.""" + result = get_protein_id_for_gene_name("VCL", {}) + assert result == "VCL" + + +def test_get_protein_id_multiple_matches(gene_to_prot_map): + """Test with a gene that appears in multiple compound keys.""" + result = get_protein_id_for_gene_name("MULTI", gene_to_prot_map) + assert result == "PROT1;PROT2;PROT3" + + +def test_get_protein_id_case_sensitivity(gene_to_prot_map): + """Test case sensitivity of gene name matching.""" + result = get_protein_id_for_gene_name("vcl", gene_to_prot_map) + assert result == "vcl" # Should not match 'VCL' due to case sensitivity
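
For reference, the tool-call routing introduced in LLMIntegration._execute_function is a plain name-based dispatch: tool names registered via __name__ in llm_functions.py are first looked up in GENERAL_FUNCTION_MAPPING, plot_intensity gets its gene-name-to-protein-ID translation, and everything else is resolved as a DataSet method via getattr. The sketch below illustrates that pattern in isolation; the stub functions and stub dataset class are hypothetical stand-ins for the real alphastats objects, not their actual implementations.

    from typing import Any, Callable, Dict

    def get_gene_function(gene: str) -> str:
        """Stub standing in for alphastats.llm.uniprot_utils.get_gene_function."""
        return f"description of {gene}"

    # Keys are derived from __name__, mirroring the PR's GENERAL_FUNCTION_MAPPING.
    GENERAL_FUNCTION_MAPPING: Dict[str, Callable] = {
        get_gene_function.__name__: get_gene_function,
    }

    class StubDataset:
        """Stub standing in for alphastats.DataSet (illustration only)."""

        def plot_sampledistribution(self, method: str = "violin") -> str:
            return f"sample distribution plot ({method})"

    def execute_function(dataset: StubDataset, name: str, args: Dict[str, Any]) -> Any:
        # 1) non-DataSet helper functions
        if (func := GENERAL_FUNCTION_MAPPING.get(name)) is not None:
            return func(**args)
        # 2) otherwise resolve a DataSet method of the same name (part after any dot)
        method = getattr(dataset, name.split(".")[-1], None)
        if method is not None:
            return method(**args)
        raise ValueError(f"Function {name} not implemented or dataset not available")

    if __name__ == "__main__":
        dataset = StubDataset()
        print(execute_function(dataset, "get_gene_function", {"gene": "VCL"}))
        print(execute_function(dataset, "plot_sampledistribution", {"method": "box"}))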
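
The new tests in tests/llm/test_llm_functions.py compare each LLM tool definition against the signature of the Python function it wraps, using inspect. Below is a reduced, standalone version of that check; plot_example and its tool dictionary are made up for illustration and are not part of alphastats.

    import inspect

    def plot_example(group: str, circle: bool = False):
        """Hypothetical analysis function exposed to the LLM."""

    # Hypothetical OpenAI-style tool definition describing plot_example.
    tool = {
        "function": {
            "name": plot_example.__name__,
            "parameters": {
                "type": "object",
                "properties": {"group": {"type": "string"}, "circle": {"type": "boolean"}},
                "required": ["group"],
            },
        }
    }

    params = dict(inspect.signature(plot_example).parameters)
    params.pop("self", None)
    required = {n for n, p in params.items() if p.default is inspect.Parameter.empty}
    llm_params = set(tool["function"]["parameters"]["properties"])

    # every parameter offered to the LLM exists on the function
    assert llm_params.issubset(params)
    # every parameter without a default is exposed to the LLM ...
    assert required.issubset(llm_params)
    # ... and marked as required in the tool definition
    assert required.issubset(set(tool["function"]["parameters"]["required"]))
    print("tool definition matches signature")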