diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py
index a6f70d0c..8188d28f 100644
--- a/alphastats/gui/pages/04_Analysis.py
+++ b/alphastats/gui/pages/04_Analysis.py
@@ -89,7 +89,7 @@ def show_start_llm_button(method: str) -> None:
 
     msg = (
         "(this will overwrite the existing LLM analysis!)"
-        if StateKeys.LLM_INTEGRATION in st.session_state
+        if st.session_state.get(StateKeys.LLM_INTEGRATION, {}) != {}
         else ""
     )
 
diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py
index 27d9b1b3..6d869a25 100644
--- a/alphastats/gui/pages/05_LLM.py
+++ b/alphastats/gui/pages/05_LLM.py
@@ -64,11 +64,10 @@ def llm_config():
                 api_key=st.session_state[StateKeys.OPENAI_API_KEY],
                 base_url=base_url,
             )
-            st.success(
-                f"Connection to {api_type} successful!"
-            ) if error is None else st.error(
-                f"❌ Connection to {api_type} failed: {error}"
-            )
+            if error is None:
+                st.success(f"Connection to {api_type} successful!")
+            else:
+                st.error(f"❌ Connection to {api_type} failed: {str(error)}")
 
         if model_before != st.session_state[StateKeys.API_TYPE]:
             st.rerun(scope="app")
diff --git a/alphastats/gui/utils/llm_helper.py b/alphastats/gui/utils/llm_helper.py
index 801fa9ad..8895dd7c 100644
--- a/alphastats/gui/utils/llm_helper.py
+++ b/alphastats/gui/utils/llm_helper.py
@@ -70,7 +70,9 @@ def llm_connection_test(
 ) -> Optional[str]:
     """Test the connection to the LLM API, return None in case of success, error message otherwise."""
     try:
-        llm = LLMIntegration(api_type, base_url=base_url, api_key=api_key)
+        llm = LLMIntegration(
+            api_type, base_url=base_url, api_key=api_key, load_tools=False
+        )
         llm.chat_completion("Hello there!")
         return None
 
diff --git a/alphastats/llm/llm_integration.py b/alphastats/llm/llm_integration.py
index dfac63aa..4ebe3fdf 100644
--- a/alphastats/llm/llm_integration.py
+++ b/alphastats/llm/llm_integration.py
@@ -60,6 +60,7 @@ def __init__(
         base_url: Optional[str] = None,
         api_key: Optional[str] = None,
         system_message: str = None,
+        load_tools: bool = True,
         dataset: Optional[DataSet] = None,
         gene_to_prot_id_map: Optional[Dict[str, str]] = None,
     ):
@@ -77,12 +78,11 @@ def __init__(
         self._metadata = None if dataset is None else dataset.metadata
         self._gene_to_prot_id_map = gene_to_prot_id_map
 
-        self._tools = self._get_tools()
+        self._tools = self._get_tools() if load_tools else None
+        self._artifacts = {}
 
         self._messages = []  # the conversation history used for the LLM, could be truncated at some point.
         self._all_messages = []  # full conversation history for display
-        self._artifacts = {}
-
         if system_message is not None:
             self._append_message("system", system_message)
 