diff --git a/arc_finetuning_st/streamlit/app.py b/arc_finetuning_st/streamlit/app.py index 19cf270..a8b17f7 100644 --- a/arc_finetuning_st/streamlit/app.py +++ b/arc_finetuning_st/streamlit/app.py @@ -23,6 +23,8 @@ def startup() -> Tuple[Controller,]: st.session_state["disable_start_button"] = False if "disable_abort_button" not in st.session_state: st.session_state["disable_abort_button"] = True +if "disable_preview_button" not in st.session_state: + st.session_state["disable_preview_button"] = True if "metric_value" not in st.session_state: st.session_state["metric_value"] = "N/A" @@ -34,7 +36,7 @@ def startup() -> Tuple[Controller,]: task_selection = st.radio( label="Tasks", options=controller.task_file_names, - index=None, + index=0, on_change=controller.selectbox_selection_change_handler, key="selected_task", ) @@ -72,8 +74,8 @@ def startup() -> Tuple[Controller,]: with test_col: - header_col, start_col, abort_col = st.columns( - [4, 1, 1], vertical_alignment="bottom", gap="small" + header_col, start_col, preview_col = st.columns( + [4, 1, 2], vertical_alignment="bottom", gap="small" ) with header_col: st.subheader("Test") @@ -85,23 +87,15 @@ def startup() -> Tuple[Controller,]: type="primary", disabled=st.session_state.get("disable_start_button"), ) - with abort_col: - - @st.dialog("Are you sure you want to abort the session?") - def abort_solving() -> None: - st.write( - "Confirm that you want to abort the session by clicking 'confirm' button below." - ) - if st.button("Confirm"): - controller.reset() - st.rerun() - + with preview_col: st.button( - "abort", - on_click=abort_solving, + "fine-tuning example", + on_click=async_to_sync(controller.handle_prediction_click), use_container_width=True, - disabled=st.session_state.get("disable_abort_button"), + disabled=st.session_state.get("disable_preview_button"), + key="preview_button", ) + with st.container(): selected_task = st.session_state.selected_task if selected_task: @@ -134,7 +128,7 @@ def abort_solving() -> None: # metrics and past attempts with st.container(): - metric_col, attempts_history_col = st.columns( + metric_col, critique_col = st.columns( [1, 7], vertical_alignment="top" ) with metric_col: @@ -142,43 +136,60 @@ def abort_solving() -> None: st.markdown(body="Passing") st.markdown(body=f"# {metric_value}") - with attempts_history_col: - st.markdown(body="Past Attempts") - st.dataframe( - controller.attempts_history_df, - hide_index=True, - selection_mode="single-row", - on_select=controller.handle_workflow_run_selection, - column_order=( - "attempt #", - "passing", - "critique", - "rationale", + with critique_col: + st.markdown(body="Critique of Attempt") + st.text_area( + label="This critique is passed to the LLM to generate a new prediction.", + key="critique", + help=( + "An LLM was prompted to critique the prediction on why it might not fit the pattern. " + "This critique is passed in the PROMPT in the next prediction attempt. " + "Feel free to make edits to the critique or use your own." ), - key="attempts_history_df", - use_container_width=True, - height=100, ) - with st.container(): - # console - st.markdown(body="Critique of Attempt") - st.text_area( - label="This critique is passed to the LLM to generate a new prediction.", - key="critique", - help=( - "An LLM was prompted to critique the prediction on why it might not fit the pattern. " - "This critique is passed in the PROMPT in the next prediction attempt. " - "Feel free to make edits to the critique or use your own." + with st.expander("Past Attempts"): + st.dataframe( + controller.attempts_history_df, + hide_index=True, + selection_mode="single-row", + on_select=controller.handle_workflow_run_selection, + column_order=( + "attempt #", + "passing", + "critique", + "rationale", ), + key="attempts_history_df", + use_container_width=True, ) - # controls with st.container(): - st.button( - "continue", - on_click=async_to_sync(controller.handle_prediction_click), - use_container_width=True, - disabled=st.session_state.get("disable_continue_button"), - key="continue_button", - ) + continue_col, abort_col = st.columns([3, 1]) + with continue_col: + st.button( + "continue", + on_click=async_to_sync(controller.handle_prediction_click), + use_container_width=True, + disabled=st.session_state.get("disable_continue_button"), + key="continue_button", + type="primary", + ) + + with abort_col: + + @st.dialog("Are you sure you want to abort the session?") + def abort_solving() -> None: + st.write( + "Confirm that you want to abort the session by clicking 'confirm' button below." + ) + if st.button("Confirm"): + controller.reset() + st.rerun() + + st.button( + "abort", + on_click=abort_solving, + use_container_width=True, + disabled=st.session_state.get("disable_abort_button"), + ) diff --git a/arc_finetuning_st/streamlit/controller.py b/arc_finetuning_st/streamlit/controller.py index 0093f00..ce792dd 100644 --- a/arc_finetuning_st/streamlit/controller.py +++ b/arc_finetuning_st/streamlit/controller.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from os import listdir from pathlib import Path from typing import Any, List, Literal, Optional, cast @@ -26,12 +27,14 @@ def __init__(self) -> None: self._passing_results: List[bool] = [] parent_path = Path(__file__).parents[2].absolute() self._data_path = Path(parent_path, "data", "training") + self._attempts_history_df_key = str(uuid.uuid4()) def reset(self) -> None: # clear prediction st.session_state.prediction = None st.session_state.disable_continue_button = True st.session_state.disable_abort_button = True + st.session_state.disable_preview_button = True st.session_state.disable_start_button = False st.session_state.critique = None st.session_state.metric_value = "N/A" @@ -48,7 +51,8 @@ def selectbox_selection_change_handler(self) -> None: @staticmethod def plot_grid( - grid: List[List[int]], kind: Literal["input", "output", "prediction"] + grid: List[List[int]], + kind: Literal["input", "output", "prediction", "latest prediction"], ) -> Any: m = len(grid) n = len(grid[0]) @@ -124,11 +128,14 @@ async def handle_prediction_click(self) -> None: grid = Prediction.prediction_str_to_int_array( prediction=str(res.attempts[-1].prediction) ) - prediction_fig = Controller.plot_grid(grid, kind="prediction") + prediction_fig = Controller.plot_grid( + grid, kind="latest prediction" + ) st.session_state.prediction = prediction_fig st.session_state.critique = str(res.attempts[-1].critique) st.session_state.disable_continue_button = False st.session_state.disable_abort_button = False + st.session_state.disable_preview_button = False st.session_state.disable_start_button = True metric_value = "✅" if res.passing else "❌" st.session_state.metric_value = metric_value @@ -152,6 +159,10 @@ def passing(self) -> Optional[bool]: return self._passing_results[-1] return None + @property + def attempts_history_df_key(self) -> str: + return self._attempts_history_df_key + @property def attempts_history_df( self, @@ -190,11 +201,23 @@ def attempts_history_df( ) def handle_workflow_run_selection(self) -> None: + @st.dialog("Past Attempt") + def _display_attempt( + fig: Any, rationale: str, critique: str, passing: bool + ) -> None: + st.plotly_chart( + prediction_fig, + use_container_width=True, + key="prediction", + ) + st.write(rationale) + selected_rows = ( st.session_state.get("attempts_history_df") .get("selection") .get("rows") ) + if selected_rows: row_ix = selected_rows[0] df_row = self.attempts_history_df.iloc[row_ix] @@ -203,7 +226,10 @@ def handle_workflow_run_selection(self) -> None: prediction=df_row["prediction"] ) prediction_fig = Controller.plot_grid(grid, kind="prediction") - st.session_state.prediction = prediction_fig - st.session_state.critique = df_row["critique"] - metric_value = df_row["passing"] - st.session_state.metric_value = metric_value + + _display_attempt( + fig=prediction_fig, + rationale=df_row["rationale"], + critique=df_row["critique"], + passing=df_row["passing"], + )