Skip to content

Commit

Permalink
make api key input from user
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 24, 2024
1 parent 2da3628 commit 95a8c0e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
9 changes: 9 additions & 0 deletions arc_finetuning_st/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def startup() -> Tuple[Controller,]:
st.session_state["disable_preview_button"] = True
if "metric_value" not in st.session_state:
st.session_state["metric_value"] = "N/A"
if "is_valid_api_key" not in st.session_state:
st.session_state["is_valid_api_key"] = False

logo = '[<img src="https://d3ddy8balm3goa.cloudfront.net/llamaindex/LlamaLogoSmall.png" width="28" height="28" />](https://github.com/run-llama/llama-agents "Check out the llama-agents Github repo!")'
st.title("ARC Task Solver with Human Input")
Expand All @@ -40,6 +42,13 @@ def startup() -> Tuple[Controller,]:

# sidebar
with st.sidebar:
api_key = st.text_input(
"OpenAI API key:",
type="password",
key="openai_api_key",
on_change=controller.check_openai_api_key,
)

task_selection = st.radio(
label="Tasks",
options=controller.task_file_names,
Expand Down
21 changes: 18 additions & 3 deletions arc_finetuning_st/streamlit/controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import logging
from os import listdir
import os
from pathlib import Path
from typing import Any, List, Literal, Optional, cast

Expand All @@ -10,6 +10,7 @@
import streamlit as st
from llama_index.core.workflow.handler import WorkflowHandler
from llama_index.llms.openai import OpenAI
from openai import AuthenticationError

from arc_finetuning_st.finetuning.finetuning_example import FineTuningExample
from arc_finetuning_st.workflows.arc_task_solver import (
Expand Down Expand Up @@ -54,6 +55,16 @@ def selectbox_selection_change_handler(self) -> None:
# streamlit element
self.reset()

def check_openai_api_key(self) -> None:
client = OpenAI(api_key=st.session_state.openai_api_key)._get_client()
try:
client.models.list()
except AuthenticationError:
st.session_state.is_valid_api_key = False
else:
st.session_state.is_valid_api_key = True
os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key

@staticmethod
def plot_grid(
grid: List[List[int]],
Expand Down Expand Up @@ -97,6 +108,10 @@ def handle_abort_click(self) -> None:

async def handle_prediction_click(self) -> None:
"""Run workflow to generate prediction."""
if not st.session_state.is_valid_api_key:
st.error("The OPENAI API KEY entered is invalid.")
return None

selected_task = st.session_state.selected_task
if selected_task:
task = self.load_task(selected_task)
Expand Down Expand Up @@ -147,11 +162,11 @@ async def handle_prediction_click(self) -> None:

@property
def saved_finetuning_examples(self) -> List[str]:
return listdir(self._finetuning_examples_path)
return os.listdir(self._finetuning_examples_path)

@property
def task_file_names(self) -> List[str]:
return listdir(self._data_path)
return os.listdir(self._data_path)

def radio_format_task_name(self, selected_task: str) -> str:
if selected_task in self.saved_finetuning_examples:
Expand Down

0 comments on commit 95a8c0e

Please sign in to comment.