From bc5643002e1d64ac7b6a69b53305a6d455bfc8fe Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:02:25 +0200 Subject: [PATCH] preserve selected model in selectbox --- alphastats/gui/pages/05_LLM.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 97b662b2..42bbf2ee 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -36,11 +36,15 @@ def llm_config(): """Show the configuration options for the LLM analysis.""" c1, _ = st.columns((1, 2)) with c1: - model_before = st.session_state.get(StateKeys.MODEL_NAME, None) + current_model = st.session_state.get(StateKeys.MODEL_NAME, None) + models = [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B] st.session_state[StateKeys.MODEL_NAME] = st.selectbox( "Select LLM", - [Models.GPT4O, Models.OLLAMA_31_70B, Models.OLLAMA_31_8B], + models, + index=models.index(st.session_state.get(StateKeys.MODEL_NAME)) + if current_model is not None + else 0, ) base_url = None @@ -67,9 +71,9 @@ def llm_config(): if error is None: st.success(f"Connection to {model_name} successful!") else: - st.error(f"❌ Connection to {model_name} failed: {str(error)}") + st.error(f"Connection to {model_name} failed: {str(error)}") - if model_before != st.session_state[StateKeys.MODEL_NAME]: + if current_model != st.session_state[StateKeys.MODEL_NAME]: st.rerun(scope="app")