From 7cf0702a14650e291819a7b23eab9f9b9c4d0919 Mon Sep 17 00:00:00 2001
From: Andrei Fajardo
Date: Fri, 20 Sep 2024 11:38:24 -0400
Subject: [PATCH] Attempt data class re-factor

---
 arc_finetuning_st/streamlit/app.py         |  1 +
 arc_finetuning_st/streamlit/controller.py  | 21 ++++----
 .../workflows/arc_task_solver.py           | 53 ++++++++++---------
 arc_finetuning_st/workflows/human_input.py | 46 ----------------
 arc_finetuning_st/workflows/models.py      | 28 +++++++---
 arc_finetuning_st/workflows/prompts.py     |  8 +--
 6 files changed, 66 insertions(+), 91 deletions(-)
 delete mode 100644 arc_finetuning_st/workflows/human_input.py

diff --git a/arc_finetuning_st/streamlit/app.py b/arc_finetuning_st/streamlit/app.py
index ddcc1fd..d02a614 100644
--- a/arc_finetuning_st/streamlit/app.py
+++ b/arc_finetuning_st/streamlit/app.py
@@ -165,6 +165,7 @@ def abort_solving() -> None:
                     "attempt #",
                     "passing",
                     "rationale",
+                    "critique",
                 ),
                 key="attempts_history_df",
             )
diff --git a/arc_finetuning_st/streamlit/controller.py b/arc_finetuning_st/streamlit/controller.py
index b54a380..2f25a1d 100644
--- a/arc_finetuning_st/streamlit/controller.py
+++ b/arc_finetuning_st/streamlit/controller.py
@@ -14,7 +14,7 @@
     ARCTaskSolverWorkflow,
     WorkflowOutput,
 )
-from arc_finetuning_st.workflows.models import Prediction
+from arc_finetuning_st.workflows.models import Attempt, Prediction
 
 logger = logging.getLogger(__name__)
 
@@ -22,7 +22,7 @@
 class Controller:
     def __init__(self) -> None:
         self._handler: Optional[WorkflowHandler] = None
-        self._attempts: List[Prediction] = []
+        self._attempts: List[Attempt] = []
         self._passing_results: List[bool] = []
         parent_path = Path(__file__).parents[2].absolute()
         self._data_path = Path(parent_path, "data", "training")
@@ -137,7 +137,7 @@ async def handle_prediction_click(self) -> None:
         # update streamlit states
         prompt_vars = await self._handler.ctx.get("prompt_vars")
         grid = Prediction.prediction_str_to_int_array(
-            prediction=res.attempts[-1].prediction
+            prediction=str(res.attempts[-1].prediction)
         )
         prediction_fig = Controller.plot_grid(grid, kind="prediction")
         st.session_state.prediction = prediction_fig
@@ -175,19 +175,20 @@ def attempts_history_df(
         attempt_number_list: List[int] = []
         passings: List[str] = []
         rationales: List[str] = []
+        critiques: List[str] = []
         predictions: List[str] = []
 
-        for ix, (a, passing) in enumerate(
-            zip(self._attempts, self._passing_results)
-        ):
-            passings = ["✅" if passing else "❌"] + passings
-            rationales = [a.rationale] + rationales
-            predictions = [a.prediction] + predictions
+        for ix, a in enumerate(self._attempts):
+            passings = ["✅" if a.passing else "❌"] + passings
+            rationales = [a.prediction.rationale] + rationales
+            predictions = [str(a.prediction)] + predictions
+            critiques = [str(a.critique)] + critiques
             attempt_number_list = [ix + 1] + attempt_number_list
         return pd.DataFrame(
             {
                 "attempt #": attempt_number_list,
                 "passing": passings,
                 "rationale": rationales,
+                "critique": critiques,  # hidden from UI
                 "prediction": predictions,
             }
@@ -209,6 +210,6 @@ def handle_workflow_run_selection(self) -> None:
         )
         prediction_fig = Controller.plot_grid(grid, kind="prediction")
         st.session_state.prediction = prediction_fig
-        st.session_state.critique = df_row["rationale"]
+        st.session_state.critique = df_row["critique"]
         metric_value = df_row["passing"]
         st.session_state.metric_value = metric_value
diff --git a/arc_finetuning_st/workflows/arc_task_solver.py b/arc_finetuning_st/workflows/arc_task_solver.py
index b539f78..b2414b8 100644
--- a/arc_finetuning_st/workflows/arc_task_solver.py
+++ b/arc_finetuning_st/workflows/arc_task_solver.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, cast
 
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.llms import LLM
@@ -15,7 +15,7 @@
     FormatTaskEvent,
     PredictionEvent,
 )
-from arc_finetuning_st.workflows.models import Correction, Critique, Prediction
+from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction
 from arc_finetuning_st.workflows.prompts import (
     CORRECTION_PROMPT_TEMPLATE,
     PREDICTION_PROMPT_TEMPLATE,
@@ -35,7 +35,7 @@
 
 class WorkflowOutput(BaseModel):
     passing: bool
-    attempts: List[Prediction]
+    attempts: List[Attempt]
 
 
 class ARCTaskSolverWorkflow(Workflow):
@@ -90,21 +90,18 @@ async def prediction(
         if "critique" in prompt_vars:
             # generating a correction from last Workflow run
             attempts = await ctx.get("attempts")
-            corr: Correction = await self.llm.astructured_predict(
-                Correction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
-            )
-            attempts.append(
-                Prediction(
-                    rationale=prompt_vars["critique"],
-                    prediction=corr.correction,
-                )
+            attempts = cast(List[Attempt], attempts)
+            correction: Prediction = await self.llm.astructured_predict(
+                Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
             )
+
+            attempts.append(Attempt(prediction=correction))
         else:
             # starting a new correction with no previous Workflow runs
             pred: Prediction = await self.llm.astructured_predict(
                 Prediction, PREDICTION_PROMPT_TEMPLATE, **prompt_vars
             )
-            attempts = [pred]
+            attempts = [Attempt(prediction=pred)]
 
         await ctx.set("attempts", attempts)
         return PredictionEvent()
@@ -115,35 +112,43 @@ async def evaluation(
     ) -> EvaluationEvent:
         ctx.write_event_to_stream(ev)
         task = await ctx.get("task")
-        attempts: List[Prediction] = await ctx.get("attempts")
-        final_attempt = attempts[-1]
-        prediction_str = final_attempt.prediction
-        prediction = Prediction.prediction_str_to_int_array(prediction_str)
+        attempts: List[Attempt] = await ctx.get("attempts")
+        latest_prediction = attempts[-1].prediction
+        latest_prediction_as_array = Prediction.prediction_str_to_int_array(
+            str(latest_prediction)
+        )
         ground_truth = task["test"][0]["output"]
-        return EvaluationEvent(passing=(prediction == ground_truth))
+        return EvaluationEvent(
+            passing=(latest_prediction_as_array == ground_truth)
+        )
 
     @step
     async def reflection(
         self, ctx: Context, ev: EvaluationEvent
     ) -> StopEvent:
         ctx.write_event_to_stream(ev)
-        attempts: List[Prediction] = await ctx.get("attempts")
+        attempts = await ctx.get("attempts")
+        attempts = cast(List[Attempt], attempts)
+        latest_attempt = attempts[-1]
 
         # check if passing
         if not ev.passing:
             prompt_vars = await ctx.get("prompt_vars")
-            prompt_vars.update(
-                predicted_output=attempts[-1].prediction
-            )  # use last attempt
+            prompt_vars.update(predicted_output=str(latest_attempt.prediction))
 
             # generate critique
-            critique_model: Critique = await self.llm.astructured_predict(
+            critique: Critique = await self.llm.astructured_predict(
                 Critique, REFLECTION_PROMPT_TEMPLATE, **prompt_vars
             )
 
-            # generate correction
-            prompt_vars.update(critique=critique_model.critique)
+            # update states
+            latest_attempt.critique = critique
+            prompt_vars.update(critique=str(critique))
             await ctx.set("prompt_vars", prompt_vars)
 
+        latest_attempt.passing = ev.passing
+        attempts[-1] = latest_attempt
+        await ctx.set("attempts", attempts)
+
         result = WorkflowOutput(passing=ev.passing, attempts=attempts)
         return StopEvent(result=result)
diff --git a/arc_finetuning_st/workflows/human_input.py b/arc_finetuning_st/workflows/human_input.py
deleted file mode 100644
index 402868a..0000000
--- a/arc_finetuning_st/workflows/human_input.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from typing import Any, Awaitable, Protocol, runtime_checkable
-
-from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step
-
-
-@runtime_checkable
-class HumanInputFn(Protocol):
-    """Protocol for getting human input."""
-
-    def __call__(self, prompt: str, **kwargs: Any) -> Awaitable[str]:
-        ...
-
-
-async def default_human_input_fn(prompt: str, **kwargs: Any) -> str:
-    return input(prompt)
-
-
-class HumanInputWorkflow(Workflow):
-    def __init__(
-        self, input: HumanInputFn = default_human_input_fn, **kwargs: Any
-    ):
-        super().__init__(**kwargs)
-        self.input = input
-
-    @step
-    async def human_input(self, ev: StartEvent) -> StopEvent:
-        prompt = str(ev.get("prompt", ""))
-        critique = str(ev.get("critique", ""))
-        prediction_str = str(ev.get("prediction_str", ""))
-        human_input = await self.input(
-            prompt, critique=critique, prediction_str=prediction_str
-        )
-        return StopEvent(result=human_input)
-
-
-# Local Testing
-async def _test_workflow() -> None:
-    w = HumanInputWorkflow(timeout=None, verbose=False)
-    result = await w.run(prompt="How old are you?\n\n")
-    print(str(result))
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    asyncio.run(_test_workflow())
diff --git a/arc_finetuning_st/workflows/models.py b/arc_finetuning_st/workflows/models.py
index f4c1ed0..f872a7f 100644
--- a/arc_finetuning_st/workflows/models.py
+++ b/arc_finetuning_st/workflows/models.py
@@ -1,16 +1,22 @@
-from typing import List
+import uuid
+from typing import List, Optional
 
 from llama_index.core.bridge.pydantic import BaseModel, Field
 
 
 class Prediction(BaseModel):
+    """Prediction data class for LLM structured predict."""
+
     rationale: str = Field(
-        description="Brief description of pattern and why prediction was made. Limit to 150 words."
+        description="Brief description of pattern and why prediction was made. Limit to 250 words."
     )
     prediction: str = Field(
         description="Predicted grid as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'"
     )
 
+    def __str__(self) -> str:
+        return self.prediction
+
     @staticmethod
     def prediction_str_to_int_array(prediction: str) -> List[List[int]]:
         return [
@@ -19,12 +25,20 @@ def prediction_str_to_int_array(prediction: str) -> List[List[int]]:
 
 
 class Critique(BaseModel):
+    """Critique data class for LLM structured predict."""
+
     critique: str = Field(
-        description="Brief critique of the previous prediction and rationale. Limit to 150 words."
+        description="Brief critique of the previous prediction and rationale. Limit to 250 words."
    )
 
+    def __str__(self) -> str:
+        return self.critique
-
-class Correction(BaseModel):
-    correction: str = Field(
-        description="Corrected prediction as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'"
-    )
+
+
+class Attempt(BaseModel):
+    """Container class of a single solution attempt."""
+
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    prediction: Prediction
+    critique: Optional[Critique] = Field(default=None)
+    passing: bool = Field(default=False)
diff --git a/arc_finetuning_st/workflows/prompts.py b/arc_finetuning_st/workflows/prompts.py
index b4003f0..c379e55 100644
--- a/arc_finetuning_st/workflows/prompts.py
+++ b/arc_finetuning_st/workflows/prompts.py
@@ -23,7 +23,7 @@ REFLECTION_PROMPT_TEMPLATE = PromptTemplate(
     """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a common pattern.
 The TEST INPUT also shares this common pattern, and you've previously predicted the output for it.
-Your task now is critique the prediction on why it might not fit the pattern inherent in the example input/output pairs.
+Your task now is to critique the latest prediction on why it might not fit the pattern inherent in the example input/output pairs.
 
 EXAMPLES:
 {examples}
 
@@ -31,7 +31,7 @@
 TEST INPUT:
 {test_input}
 
-PREDICTED OUTPUT:
+LATEST PREDICTED OUTPUT:
 {predicted_output}
 
 OUTPUT FORMAT:
@@ -56,7 +56,7 @@
 TEST INPUT:
 {test_input}
 
-PREDICTED OUTPUT:
+LATEST PREDICTED OUTPUT:
 {predicted_output}
 
 CRITIQUE:
@@ -64,7 +64,7 @@
 
 OUTPUT FORMAT:
 {{
-    "correction": ...
+    "prediction": ...
 }}
 
 Return your response in JSON format given above. DO NOT RETURN markdown code."""
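
Usage sketch (illustrative only, not part of the patch): the refactor above replaces the old Correction model with an Attempt container that bundles one Prediction, the optional Critique produced by the reflection step, and a passing flag. The snippet below shows how those pieces compose and how str() on a Prediction recovers the raw grid string that the controller converts back into a 2D int array. It is a minimal sketch that assumes plain pydantic in place of llama_index.core.bridge.pydantic, omits the Field descriptions, and uses made-up rationale, critique, and grid strings.

# Illustrative sketch only; it mirrors the refactored models in
# arc_finetuning_st/workflows/models.py but uses plain pydantic in place of
# llama_index.core.bridge.pydantic, and the field descriptions are omitted.
import uuid
from typing import List, Optional

from pydantic import BaseModel, Field


class Prediction(BaseModel):
    rationale: str
    prediction: str

    def __str__(self) -> str:
        # str(prediction) yields the raw grid string, mirroring the patch.
        return self.prediction

    @staticmethod
    def prediction_str_to_int_array(prediction: str) -> List[List[int]]:
        return [
            [int(a) for a in row.split(",")] for row in prediction.split("\n")
        ]


class Critique(BaseModel):
    critique: str

    def __str__(self) -> str:
        return self.critique


class Attempt(BaseModel):
    """Container for one solution attempt, as introduced by this patch."""

    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
    prediction: Prediction
    critique: Optional[Critique] = None
    passing: bool = False


if __name__ == "__main__":
    # A failed first attempt: the reflection step would attach the critique
    # and leave passing=False before the next correction round.
    attempt = Attempt(
        prediction=Prediction(
            rationale="Each row of the input appears shifted down by one.",
            prediction="0,0,1\n1,1,1\n0,0,0",
        ),
        critique=Critique(critique="The bottom row should wrap to the top."),
        passing=False,
    )

    # The controller calls str(...) on the Prediction and converts the grid
    # string back into a 2D int array for plotting.
    grid = Prediction.prediction_str_to_int_array(str(attempt.prediction))
    print(grid)             # [[0, 0, 1], [1, 1, 1], [0, 0, 0]]
    print(attempt.passing)  # False

In the patched workflow, the reflection step fills in critique and passing on the latest Attempt and writes the list back to the Context, so WorkflowOutput.attempts carries the per-attempt history that attempts_history_df renders, with the critique column kept hidden from the UI.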