Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 20, 2024
1 parent 96a3f83 commit 663063c
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 235 deletions.
24 changes: 15 additions & 9 deletions arc_finetuning_st/streamlit/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pandas as pd
import random
import streamlit as st
from typing import Tuple

import streamlit as st
from llama_index.core.tools.function_tool import async_to_sync

from arc_finetuning_st.streamlit.controller import Controller
Expand Down Expand Up @@ -41,7 +39,9 @@ def startup() -> Tuple[Controller,]:
key="selected_task",
)

train_col, test_col = st.columns([1, 1], vertical_alignment="top", gap="medium")
train_col, test_col = st.columns(
[1, 1], vertical_alignment="top", gap="medium"
)

with train_col:
st.subheader("Train Examples")
Expand All @@ -50,7 +50,9 @@ def startup() -> Tuple[Controller,]:
if selected_task:
task = controller.load_task(selected_task)
num_examples = len(task["train"])
tabs = st.tabs([f"Example {ix}" for ix in range(1, num_examples + 1)])
tabs = st.tabs(
[f"Example {ix}" for ix in range(1, num_examples + 1)]
)
for ix, tab in enumerate(tabs):
with tab:
left, right = st.columns(
Expand Down Expand Up @@ -86,9 +88,9 @@ def startup() -> Tuple[Controller,]:
with abort_col:

@st.dialog("Are you sure you want to abort the session?")
def abort_solving():
def abort_solving() -> None:
st.write(
f"Confirm that you want to abort the session by clicking 'confirm' button below."
"Confirm that you want to abort the session by clicking 'confirm' button below."
)
if st.button("Confirm"):
controller.reset()
Expand All @@ -105,7 +107,9 @@ def abort_solving():
if selected_task:
task = controller.load_task(selected_task)
num_cases = len(task["test"])
tabs = st.tabs([f"Test Case {ix}" for ix in range(1, num_cases + 1)])
tabs = st.tabs(
[f"Test Case {ix}" for ix in range(1, num_cases + 1)]
)
for ix, tab in enumerate(tabs):
with tab:
left, right = st.columns(
Expand All @@ -118,7 +122,9 @@ def abort_solving():
st.plotly_chart(fig, use_container_width=True)

with right:
prediction_fig = st.session_state.get("prediction", None)
prediction_fig = st.session_state.get(
"prediction", None
)
if prediction_fig:
st.plotly_chart(
prediction_fig,
Expand Down
54 changes: 30 additions & 24 deletions arc_finetuning_st/streamlit/controller.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import asyncio
import logging
import streamlit as st
import plotly.express as px
import pandas as pd
from typing import Any, Dict, Optional, List, Literal
from pathlib import Path
from os import listdir
from pathlib import Path
from typing import Any, List, Literal, Optional, cast

import pandas as pd
import plotly.express as px
import streamlit as st
from llama_index.core.workflow.handler import WorkflowHandler
from llama_index.llms.openai import OpenAI

from arc_finetuning_st.workflows.models import Prediction
from arc_finetuning_st.workflows.arc_task_solver import (
ARCTaskSolverWorkflow,
WorkflowOutput,
)
from arc_finetuning_st.workflows.models import Prediction

logger = logging.getLogger(__name__)


class Controller:
def __init__(self) -> None:
self._handler = None
self._attempts = []
self._passing_results = []
self._handler: Optional[WorkflowHandler] = None
self._attempts: List[Prediction] = []
self._passing_results: List[bool] = []
parent_path = Path(__file__).parents[2].absolute()
self._data_path = Path(parent_path, "data", "training")

def reset(self):
def reset(self) -> None:
# clear prediction
st.session_state.prediction = None
st.session_state.disable_continue_button = True
Expand All @@ -39,20 +40,22 @@ def reset(self):
self._attempts = []
self._passing_results = []

def selectbox_selection_change_handler(self):
def selectbox_selection_change_handler(self) -> None:
# only reset states
# loading of task is delegated to relevant calls made with each
# streamlit element
self.reset()

@staticmethod
def plot_grid(
grid: List[List[int]], kind=Literal["input", "output", "prediction"]
grid: List[List[int]], kind: Literal["input", "output", "prediction"]
) -> Any:
m = len(grid)
n = len(grid[0])
fig = px.imshow(
grid, text_auto=True, labels={"x": f"{kind.title()}<br><sup>{m}x{n}</sup>"}
grid,
text_auto=True,
labels={"x": f"{kind.title()}<br><sup>{m}x{n}</sup>"},
)
fig.update_coloraxes(showscale=False)
fig.update_layout(
Expand All @@ -67,7 +70,7 @@ def plot_grid(
)
return fig

async def show_progress_bar(self, handler) -> None:
async def show_progress_bar(self, handler: WorkflowHandler) -> None:
progress_text_template = "{event} completed. Next step in progress."
my_bar = st.progress(0, text="Workflow run in progress. Please wait.")
num_steps = 5.0
Expand All @@ -88,13 +91,14 @@ async def handle_prediction_click(self) -> None:
selected_task = st.session_state.selected_task
if selected_task:
task = self.load_task(selected_task)
w = ARCTaskSolverWorkflow(timeout=None, verbose=False, llm=OpenAI("gpt-4o"))
w = ARCTaskSolverWorkflow(
timeout=None, verbose=False, llm=OpenAI("gpt-4o")
)

if not self._handler: # start a new solver
handler = w.run(task=task)

else: # continuing from past Workflow execution

# need to reset this queue otherwise will use nested event loops
self._handler.ctx._streaming_queue = asyncio.Queue()

Expand Down Expand Up @@ -125,6 +129,7 @@ async def handle_prediction_click(self) -> None:

res: WorkflowOutput = await handler

handler = cast(WorkflowHandler, handler)
self._handler = handler
self._passing_results.append(res.passing)
self._attempts = res.attempts
Expand All @@ -147,7 +152,7 @@ async def handle_prediction_click(self) -> None:
def task_file_names(self) -> List[str]:
return listdir(self._data_path)

def load_task(self, selected_task: str) -> Dict:
def load_task(self, selected_task: str) -> Any:
import json

task_path = Path(self._data_path, selected_task)
Expand All @@ -160,18 +165,17 @@ def load_task(self, selected_task: str) -> Dict:
def passing(self) -> Optional[bool]:
if self._passing_results:
return self._passing_results[-1]
return
return None

@property
def attempts_history_df(
self,
) -> pd.DataFrame:

if self._attempts:
attempt_number_list = []
passings = []
rationales = []
predictions = []
attempt_number_list: List[int] = []
passings: List[str] = []
rationales: List[str] = []
predictions: List[str] = []
for ix, (a, passing) in enumerate(
zip(self._attempts, self._passing_results)
):
Expand All @@ -192,7 +196,9 @@ def attempts_history_df(

def handle_workflow_run_selection(self) -> None:
selected_rows = (
st.session_state.get("attempts_history_df").get("selection").get("rows")
st.session_state.get("attempts_history_df")
.get("selection")
.get("rows")
)
if selected_rows:
row_ix = selected_rows[0]
Expand Down
38 changes: 23 additions & 15 deletions arc_finetuning_st/workflows/arc_task_solver.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
from typing import Dict, List
from typing import Any, Dict, List

from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.llms import LLM
from llama_index.core.workflow import (
Workflow,
Context,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.core.llms import LLM

from arc_finetuning_st.workflows.events import (
EvaluationEvent,
FormatTaskEvent,
PredictionEvent,
EvaluationEvent,
)
from arc_finetuning_st.workflows.models import Correction, Critique, Prediction
from arc_finetuning_st.workflows.prompts import (
REFLECTION_PROMPT_TEMPLATE,
PREDICTION_PROMPT_TEMPLATE,
CORRECTION_PROMPT_TEMPLATE,
PREDICTION_PROMPT_TEMPLATE,
REFLECTION_PROMPT_TEMPLATE,
)
from arc_finetuning_st.workflows.models import Prediction, Correction, Critique

example_template = """===
EXAMPLE
Expand All @@ -37,14 +39,15 @@ class WorkflowOutput(BaseModel):


class ARCTaskSolverWorkflow(Workflow):

def __init__(self, llm: LLM, max_attempts: int = 3, **kwargs) -> None:
def __init__(self, llm: LLM, max_attempts: int = 3, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.llm = llm
self._max_attempts = max_attempts

@step
async def format_task(self, ctx: Context, ev: StartEvent) -> FormatTaskEvent:
async def format_task(
self, ctx: Context, ev: StartEvent
) -> FormatTaskEvent:
ctx.write_event_to_stream(ev)

def _format_row(row: List[int]) -> str:
Expand Down Expand Up @@ -92,7 +95,8 @@ async def prediction(
)
attempts.append(
Prediction(
rationale=prompt_vars["critique"], prediction=corr.correction
rationale=prompt_vars["critique"],
prediction=corr.correction,
)
)
else:
Expand All @@ -106,7 +110,9 @@ async def prediction(
return PredictionEvent()

@step
async def evaluation(self, ctx: Context, ev: PredictionEvent) -> EvaluationEvent:
async def evaluation(
self, ctx: Context, ev: PredictionEvent
) -> EvaluationEvent:
ctx.write_event_to_stream(ev)
task = await ctx.get("task")
attempts: List[Prediction] = await ctx.get("attempts")
Expand Down Expand Up @@ -142,9 +148,10 @@ async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent:
return StopEvent(result=result)


async def _test_workflow():
async def _test_workflow() -> None:
import json
from pathlib import Path

from llama_index.llms.openai import OpenAI

task_path = Path(
Expand All @@ -153,8 +160,9 @@ async def _test_workflow():
with open(task_path) as f:
task = json.load(f)

w = ARCTaskSolverWorkflow(timeout=None, verbose=False, llm=OpenAI("gpt-4o"))
w.add_workflows(human_input_workflow=HumanInputWorkflow())
w = ARCTaskSolverWorkflow(
timeout=None, verbose=False, llm=OpenAI("gpt-4o")
)
attempts = await w.run(task=task)

print(attempts)
Expand Down
Loading

0 comments on commit 663063c

Please sign in to comment.