diff --git a/arc_finetuning_st/streamlit/app.py b/arc_finetuning_st/streamlit/app.py index 58bb776..0cbce99 100644 --- a/arc_finetuning_st/streamlit/app.py +++ b/arc_finetuning_st/streamlit/app.py @@ -30,6 +30,8 @@ def startup() -> Tuple[Controller,]: st.session_state["disable_preview_button"] = True if "metric_value" not in st.session_state: st.session_state["metric_value"] = "N/A" +if "is_valid_api_key" not in st.session_state: + st.session_state["is_valid_api_key"] = False logo = '[](https://github.com/run-llama/llama-agents "Check out the llama-agents Github repo!")' st.title("ARC Task Solver with Human Input") @@ -40,6 +42,13 @@ def startup() -> Tuple[Controller,]: # sidebar with st.sidebar: + api_key = st.text_input( + "OpenAI API key:", + type="password", + key="openai_api_key", + on_change=controller.check_openai_api_key, + ) + task_selection = st.radio( label="Tasks", options=controller.task_file_names, diff --git a/arc_finetuning_st/streamlit/controller.py b/arc_finetuning_st/streamlit/controller.py index f7eb07b..4e48ddb 100644 --- a/arc_finetuning_st/streamlit/controller.py +++ b/arc_finetuning_st/streamlit/controller.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from os import listdir +import os from pathlib import Path from typing import Any, List, Literal, Optional, cast @@ -10,6 +10,7 @@ import streamlit as st from llama_index.core.workflow.handler import WorkflowHandler from llama_index.llms.openai import OpenAI +from openai import AuthenticationError from arc_finetuning_st.finetuning.finetuning_example import FineTuningExample from arc_finetuning_st.workflows.arc_task_solver import ( @@ -54,6 +55,16 @@ def selectbox_selection_change_handler(self) -> None: # streamlit element self.reset() + def check_openai_api_key(self) -> None: + client = OpenAI(api_key=st.session_state.openai_api_key)._get_client() + try: + client.models.list() + except AuthenticationError: + st.session_state.is_valid_api_key = False + else: + st.session_state.is_valid_api_key = True + os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key + @staticmethod def plot_grid( grid: List[List[int]], @@ -97,6 +108,10 @@ def handle_abort_click(self) -> None: async def handle_prediction_click(self) -> None: """Run workflow to generate prediction.""" + if not st.session_state.is_valid_api_key: + st.error("The OPENAI API KEY entered is invalid.") + return None + selected_task = st.session_state.selected_task if selected_task: task = self.load_task(selected_task) @@ -147,11 +162,11 @@ async def handle_prediction_click(self) -> None: @property def saved_finetuning_examples(self) -> List[str]: - return listdir(self._finetuning_examples_path) + return os.listdir(self._finetuning_examples_path) @property def task_file_names(self) -> List[str]: - return listdir(self._data_path) + return os.listdir(self._data_path) def radio_format_task_name(self, selected_task: str) -> str: if selected_task in self.saved_finetuning_examples: