Skip to content

Commit

Permalink
UI improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 20, 2024
1 parent 79e8181 commit 55973c0
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 57 deletions.
113 changes: 62 additions & 51 deletions arc_finetuning_st/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def startup() -> Tuple[Controller,]:
st.session_state["disable_start_button"] = False
if "disable_abort_button" not in st.session_state:
st.session_state["disable_abort_button"] = True
if "disable_preview_button" not in st.session_state:
st.session_state["disable_preview_button"] = True
if "metric_value" not in st.session_state:
st.session_state["metric_value"] = "N/A"

Expand All @@ -34,7 +36,7 @@ def startup() -> Tuple[Controller,]:
task_selection = st.radio(
label="Tasks",
options=controller.task_file_names,
index=None,
index=0,
on_change=controller.selectbox_selection_change_handler,
key="selected_task",
)
Expand Down Expand Up @@ -72,8 +74,8 @@ def startup() -> Tuple[Controller,]:


with test_col:
header_col, start_col, abort_col = st.columns(
[4, 1, 1], vertical_alignment="bottom", gap="small"
header_col, start_col, preview_col = st.columns(
[4, 1, 2], vertical_alignment="bottom", gap="small"
)
with header_col:
st.subheader("Test")
Expand All @@ -85,23 +87,15 @@ def startup() -> Tuple[Controller,]:
type="primary",
disabled=st.session_state.get("disable_start_button"),
)
with abort_col:

@st.dialog("Are you sure you want to abort the session?")
def abort_solving() -> None:
st.write(
"Confirm that you want to abort the session by clicking 'confirm' button below."
)
if st.button("Confirm"):
controller.reset()
st.rerun()

with preview_col:
st.button(
"abort",
on_click=abort_solving,
"fine-tuning example",
on_click=async_to_sync(controller.handle_prediction_click),
use_container_width=True,
disabled=st.session_state.get("disable_abort_button"),
disabled=st.session_state.get("disable_preview_button"),
key="preview_button",
)

with st.container():
selected_task = st.session_state.selected_task
if selected_task:
Expand Down Expand Up @@ -134,51 +128,68 @@ def abort_solving() -> None:

# metrics and past attempts
with st.container():
metric_col, attempts_history_col = st.columns(
metric_col, critique_col = st.columns(
[1, 7], vertical_alignment="top"
)
with metric_col:
metric_value = st.session_state.get("metric_value")
st.markdown(body="Passing")
st.markdown(body=f"# {metric_value}")

with attempts_history_col:
st.markdown(body="Past Attempts")
st.dataframe(
controller.attempts_history_df,
hide_index=True,
selection_mode="single-row",
on_select=controller.handle_workflow_run_selection,
column_order=(
"attempt #",
"passing",
"critique",
"rationale",
with critique_col:
st.markdown(body="Critique of Attempt")
st.text_area(
label="This critique is passed to the LLM to generate a new prediction.",
key="critique",
help=(
"An LLM was prompted to critique the prediction on why it might not fit the pattern. "
"This critique is passed in the PROMPT in the next prediction attempt. "
"Feel free to make edits to the critique or use your own."
),
key="attempts_history_df",
use_container_width=True,
height=100,
)

with st.container():
# console
st.markdown(body="Critique of Attempt")
st.text_area(
label="This critique is passed to the LLM to generate a new prediction.",
key="critique",
help=(
"An LLM was prompted to critique the prediction on why it might not fit the pattern. "
"This critique is passed in the PROMPT in the next prediction attempt. "
"Feel free to make edits to the critique or use your own."
with st.expander("Past Attempts"):
st.dataframe(
controller.attempts_history_df,
hide_index=True,
selection_mode="single-row",
on_select=controller.handle_workflow_run_selection,
column_order=(
"attempt #",
"passing",
"critique",
"rationale",
),
key="attempts_history_df",
use_container_width=True,
)

# controls
with st.container():
st.button(
"continue",
on_click=async_to_sync(controller.handle_prediction_click),
use_container_width=True,
disabled=st.session_state.get("disable_continue_button"),
key="continue_button",
)
continue_col, abort_col = st.columns([3, 1])
with continue_col:
st.button(
"continue",
on_click=async_to_sync(controller.handle_prediction_click),
use_container_width=True,
disabled=st.session_state.get("disable_continue_button"),
key="continue_button",
type="primary",
)

with abort_col:

@st.dialog("Are you sure you want to abort the session?")
def abort_solving() -> None:
st.write(
"Confirm that you want to abort the session by clicking 'confirm' button below."
)
if st.button("Confirm"):
controller.reset()
st.rerun()

st.button(
"abort",
on_click=abort_solving,
use_container_width=True,
disabled=st.session_state.get("disable_abort_button"),
)
38 changes: 32 additions & 6 deletions arc_finetuning_st/streamlit/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from os import listdir
from pathlib import Path
from typing import Any, List, Literal, Optional, cast
Expand All @@ -26,12 +27,14 @@ def __init__(self) -> None:
self._passing_results: List[bool] = []
parent_path = Path(__file__).parents[2].absolute()
self._data_path = Path(parent_path, "data", "training")
self._attempts_history_df_key = str(uuid.uuid4())

def reset(self) -> None:
# clear prediction
st.session_state.prediction = None
st.session_state.disable_continue_button = True
st.session_state.disable_abort_button = True
st.session_state.disable_preview_button = True
st.session_state.disable_start_button = False
st.session_state.critique = None
st.session_state.metric_value = "N/A"
Expand All @@ -48,7 +51,8 @@ def selectbox_selection_change_handler(self) -> None:

@staticmethod
def plot_grid(
grid: List[List[int]], kind: Literal["input", "output", "prediction"]
grid: List[List[int]],
kind: Literal["input", "output", "prediction", "latest prediction"],
) -> Any:
m = len(grid)
n = len(grid[0])
Expand Down Expand Up @@ -124,11 +128,14 @@ async def handle_prediction_click(self) -> None:
grid = Prediction.prediction_str_to_int_array(
prediction=str(res.attempts[-1].prediction)
)
prediction_fig = Controller.plot_grid(grid, kind="prediction")
prediction_fig = Controller.plot_grid(
grid, kind="latest prediction"
)
st.session_state.prediction = prediction_fig
st.session_state.critique = str(res.attempts[-1].critique)
st.session_state.disable_continue_button = False
st.session_state.disable_abort_button = False
st.session_state.disable_preview_button = False
st.session_state.disable_start_button = True
metric_value = "✅" if res.passing else "❌"
st.session_state.metric_value = metric_value
Expand All @@ -152,6 +159,10 @@ def passing(self) -> Optional[bool]:
return self._passing_results[-1]
return None

@property
def attempts_history_df_key(self) -> str:
return self._attempts_history_df_key

@property
def attempts_history_df(
self,
Expand Down Expand Up @@ -190,11 +201,23 @@ def attempts_history_df(
)

def handle_workflow_run_selection(self) -> None:
@st.dialog("Past Attempt")
def _display_attempt(
fig: Any, rationale: str, critique: str, passing: bool
) -> None:
st.plotly_chart(
prediction_fig,
use_container_width=True,
key="prediction",
)
st.write(rationale)

selected_rows = (
st.session_state.get("attempts_history_df")
.get("selection")
.get("rows")
)

if selected_rows:
row_ix = selected_rows[0]
df_row = self.attempts_history_df.iloc[row_ix]
Expand All @@ -203,7 +226,10 @@ def handle_workflow_run_selection(self) -> None:
prediction=df_row["prediction"]
)
prediction_fig = Controller.plot_grid(grid, kind="prediction")
st.session_state.prediction = prediction_fig
st.session_state.critique = df_row["critique"]
metric_value = df_row["passing"]
st.session_state.metric_value = metric_value

_display_attempt(
fig=prediction_fig,
rationale=df_row["rationale"],
critique=df_row["critique"],
passing=df_row["passing"],
)

0 comments on commit 55973c0

Please sign in to comment.