Attempt data class re-factor
nerdai committed Sep 20, 2024
1 parent 7dfd639 commit 7cf0702
Showing 6 changed files with 66 additions and 91 deletions.
1 change: 1 addition & 0 deletions arc_finetuning_st/streamlit/app.py
@@ -165,6 +165,7 @@ def abort_solving() -> None:
"attempt #",
"passing",
"rationale",
"critique",
),
key="attempts_history_df",
)
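
Illustrative sketch (not part of this commit): the tuple above is the column order handed to the attempts-history dataframe widget, so adding "critique" surfaces the new field in the UI. A minimal example of how such a column_order is typically passed to st.dataframe, assuming a frame shaped like the one the controller builds; the sample values are made up:

import pandas as pd
import streamlit as st

# Hypothetical frame mirroring Controller.attempts_history_df output
df = pd.DataFrame(
    {
        "attempt #": [1, 2],
        "passing": ["❌", "✅"],
        "rationale": ["mirror each row", "mirror columns instead"],
        "critique": ["Mirroring axis looks wrong.", "None"],
        "prediction": ["0,1\n1,0", "1,0\n0,1"],
    }
)

# Only the listed columns are rendered; "prediction" stays hidden from the UI.
st.dataframe(
    df,
    column_order=("attempt #", "passing", "rationale", "critique"),
    key="attempts_history_df",
)
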
21 changes: 11 additions & 10 deletions arc_finetuning_st/streamlit/controller.py
@@ -14,15 +14,15 @@
ARCTaskSolverWorkflow,
WorkflowOutput,
)
from arc_finetuning_st.workflows.models import Prediction
from arc_finetuning_st.workflows.models import Attempt, Prediction

logger = logging.getLogger(__name__)


class Controller:
def __init__(self) -> None:
self._handler: Optional[WorkflowHandler] = None
self._attempts: List[Prediction] = []
self._attempts: List[Attempt] = []
self._passing_results: List[bool] = []
parent_path = Path(__file__).parents[2].absolute()
self._data_path = Path(parent_path, "data", "training")
@@ -137,7 +137,7 @@ async def handle_prediction_click(self) -> None:
# update streamlit states
prompt_vars = await self._handler.ctx.get("prompt_vars")
grid = Prediction.prediction_str_to_int_array(
prediction=res.attempts[-1].prediction
prediction=str(res.attempts[-1].prediction)
)
prediction_fig = Controller.plot_grid(grid, kind="prediction")
st.session_state.prediction = prediction_fig
@@ -175,19 +175,20 @@ def attempts_history_df(
attempt_number_list: List[int] = []
passings: List[str] = []
rationales: List[str] = []
critiques: List[str] = []
predictions: List[str] = []
for ix, (a, passing) in enumerate(
zip(self._attempts, self._passing_results)
):
passings = ["✅" if passing else "❌"] + passings
rationales = [a.rationale] + rationales
predictions = [a.prediction] + predictions
for ix, a in enumerate(self._attempts):
passings = ["✅" if a.passing else "❌"] + passings
rationales = [a.prediction.rationale] + rationales
predictions = [str(a.prediction)] + predictions
critiques = [str(a.critique)] + critiques
attempt_number_list = [ix + 1] + attempt_number_list
return pd.DataFrame(
{
"attempt #": attempt_number_list,
"passing": passings,
"rationale": rationales,
"critique": critiques,
# hidden from UI
"prediction": predictions,
}
@@ -209,6 +210,6 @@ def handle_workflow_run_selection(self) -> None:
)
prediction_fig = Controller.plot_grid(grid, kind="prediction")
st.session_state.prediction = prediction_fig
st.session_state.critique = df_row["rationale"]
st.session_state.critique = df_row["critique"]
metric_value = df_row["passing"]
st.session_state.metric_value = metric_value
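
For context, an illustrative sketch (not part of this commit): after the refactor, attempts_history_df reads passing, rationale, and critique straight off each Attempt instead of zipping a parallel _passing_results list. The same row-building logic in standalone form, assuming the Attempt/Prediction/Critique models defined in models.py below; the real method prepends rows so the newest attempt appears first:

import pandas as pd

from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction

attempts = [
    Attempt(
        prediction=Prediction(rationale="mirror each row", prediction="0,1\n1,0"),
        critique=Critique(critique="Mirroring axis looks wrong."),
        passing=False,
    ),
    Attempt(
        prediction=Prediction(rationale="mirror columns instead", prediction="1,0\n0,1"),
        passing=True,
    ),
]

df = pd.DataFrame(
    {
        "attempt #": [ix + 1 for ix in range(len(attempts))],
        "passing": ["✅" if a.passing else "❌" for a in attempts],
        "rationale": [a.prediction.rationale for a in attempts],
        "critique": [str(a.critique) for a in attempts],
        "prediction": [str(a.prediction) for a in attempts],  # hidden from UI
    }
)
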
53 changes: 29 additions & 24 deletions arc_finetuning_st/workflows/arc_task_solver.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, cast

from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.llms import LLM
@@ -15,7 +15,7 @@
FormatTaskEvent,
PredictionEvent,
)
from arc_finetuning_st.workflows.models import Correction, Critique, Prediction
from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction
from arc_finetuning_st.workflows.prompts import (
CORRECTION_PROMPT_TEMPLATE,
PREDICTION_PROMPT_TEMPLATE,
@@ -35,7 +35,7 @@

class WorkflowOutput(BaseModel):
passing: bool
attempts: List[Prediction]
attempts: List[Attempt]


class ARCTaskSolverWorkflow(Workflow):
@@ -90,21 +90,18 @@ async def prediction(
if "critique" in prompt_vars:
# generating a correction from last Workflow run
attempts = await ctx.get("attempts")
corr: Correction = await self.llm.astructured_predict(
Correction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
)
attempts.append(
Prediction(
rationale=prompt_vars["critique"],
prediction=corr.correction,
)
attempts = cast(List[Attempt], attempts)
correction: Prediction = await self.llm.astructured_predict(
Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
)

attempts.append(Attempt(prediction=correction))
else:
# starting a new correction with no previous Workflow runs
pred: Prediction = await self.llm.astructured_predict(
Prediction, PREDICTION_PROMPT_TEMPLATE, **prompt_vars
)
attempts = [pred]
attempts = [Attempt(prediction=pred)]

await ctx.set("attempts", attempts)
return PredictionEvent()
@@ -115,35 +112,43 @@ async def evaluation(
) -> EvaluationEvent:
ctx.write_event_to_stream(ev)
task = await ctx.get("task")
attempts: List[Prediction] = await ctx.get("attempts")
final_attempt = attempts[-1]
prediction_str = final_attempt.prediction
prediction = Prediction.prediction_str_to_int_array(prediction_str)
attempts: List[Attempt] = await ctx.get("attempts")
latest_prediction = attempts[-1].prediction
latest_prediction_as_array = Prediction.prediction_str_to_int_array(
str(latest_prediction)
)
ground_truth = task["test"][0]["output"]

return EvaluationEvent(passing=(prediction == ground_truth))
return EvaluationEvent(
passing=(latest_prediction_as_array == ground_truth)
)

@step
async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent:
ctx.write_event_to_stream(ev)
attempts: List[Prediction] = await ctx.get("attempts")
attempts = await ctx.get("attempts")
attempts = cast(List[Attempt], attempts)
latest_attempt = attempts[-1]

# check if passing
if not ev.passing:
prompt_vars = await ctx.get("prompt_vars")
prompt_vars.update(
predicted_output=attempts[-1].prediction
) # use last attempt
prompt_vars.update(predicted_output=str(latest_attempt.prediction))

# generate critique
critique_model: Critique = await self.llm.astructured_predict(
critique: Critique = await self.llm.astructured_predict(
Critique, REFLECTION_PROMPT_TEMPLATE, **prompt_vars
)

# generate correction
prompt_vars.update(critique=critique_model.critique)
# update states
latest_attempt.critique = critique
prompt_vars.update(critique=str(critique))
await ctx.set("prompt_vars", prompt_vars)

latest_attempt.passing = ev.passing
attempts[-1] = latest_attempt
await ctx.set("attempts", attempts)

result = WorkflowOutput(passing=ev.passing, attempts=attempts)
return StopEvent(result=result)

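Illustrative sketch (not part of this commit): the refactored workflow threads one list of Attempt objects through its steps. The prediction step appends a fresh Attempt, the evaluation step parses the latest prediction into a grid and compares it with the ground truth, and the reflection step writes the Critique and passing flag back onto that same Attempt. A minimal walk through that lifecycle outside the workflow, using the models from models.py below; grid values are placeholders:

from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction

# prediction step: wrap the LLM's structured output in an Attempt
attempt = Attempt(
    prediction=Prediction(
        rationale="Each row is mirrored left to right.",
        prediction="0,0,1\n1,1,1\n0,0,0",
    )
)

# evaluation step: parse the prediction string and compare against ground truth
ground_truth = [[0, 0, 1], [1, 1, 1], [0, 1, 0]]
predicted = Prediction.prediction_str_to_int_array(str(attempt.prediction))
passing = predicted == ground_truth

# reflection step: on failure, record the critique and the result on the attempt
if not passing:
    attempt.critique = Critique(critique="The bottom row lost its centre cell.")
attempt.passing = passing
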
46 changes: 0 additions & 46 deletions arc_finetuning_st/workflows/human_input.py

This file was deleted.

28 changes: 21 additions & 7 deletions arc_finetuning_st/workflows/models.py
@@ -1,16 +1,22 @@
from typing import List
import uuid
from typing import List, Optional

from llama_index.core.bridge.pydantic import BaseModel, Field


class Prediction(BaseModel):
"""Prediction data class for LLM structured predict."""

rationale: str = Field(
description="Brief description of pattern and why prediction was made. Limit to 150 words."
description="Brief description of pattern and why prediction was made. Limit to 250 words."
)
prediction: str = Field(
description="Predicted grid as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'"
)

def __str__(self) -> str:
return self.prediction

@staticmethod
def prediction_str_to_int_array(prediction: str) -> List[List[int]]:
return [
@@ -19,12 +25,20 @@ def prediction_str_to_int_array(prediction: str) -> List[List[int]]:


class Critique(BaseModel):
"""Critique data class for LLM structured predict."""

critique: str = Field(
description="Brief critique of the previous prediction and rationale. Limit to 150 words."
description="Brief critique of the previous prediction and rationale. Limit to 250 words."
)

def __str__(self) -> str:
return self.critique

class Correction(BaseModel):
correction: str = Field(
description="Corrected prediction as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'"
)

class Attempt(BaseModel):
"""Container class of a single solution attempt."""

id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
prediction: Prediction
critique: Optional[Critique] = Field(default=None)
passing: bool = Field(default=False)
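
Usage sketch (illustrative, not part of this commit): Attempt composes a required Prediction with an optional Critique, a passing flag that defaults to False, and an auto-generated id_, while both LLM-facing models render their payload string via __str__, which is what the workflow and controller rely on. A short example exercising the models as defined above; expected outputs are assumptions based on the field descriptions:

from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction

pred = Prediction(
    rationale="The output mirrors the input grid vertically.",
    prediction="0,0,1\n1,1,1\n0,0,0",
)
attempt = Attempt(prediction=pred)  # critique=None, passing=False, id_ auto-set

print(str(pred))  # the raw grid string: 0,0,1\n1,1,1\n0,0,0
print(Prediction.prediction_str_to_int_array(str(pred)))  # expected: [[0, 0, 1], [1, 1, 1], [0, 0, 0]]

attempt.critique = Critique(critique="Check the corner cells again.")
attempt.passing = False
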
8 changes: 4 additions & 4 deletions arc_finetuning_st/workflows/prompts.py
@@ -23,15 +23,15 @@
REFLECTION_PROMPT_TEMPLATE = PromptTemplate(
"""You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a
common pattern. The TEST INPUT also shares this common pattern, and you've previously predicted the output for it.
Your task now is critique the prediction on why it might not fit the pattern inherent in the example input/output pairs.
Your task now is critique the latest prediction on why it might not fit the pattern inherent in the example input/output pairs.
EXAMPLES:
{examples}
TEST INPUT:
{test_input}
PREDICTED OUTPUT:
LATEST PREDICTED OUTPUT:
{predicted_output}
OUTPUT FORMAT:
@@ -56,15 +56,15 @@
TEST INPUT:
{test_input}
PREDICTED OUTPUT:
LATEST PREDICTED OUTPUT:
{predicted_output}
CRITIQUE:
{critique}
OUTPUT FORMAT:
{{
"correction": ...
"prediction": ...
}}
Return your response in JSON format given above. DO NOT RETURN markdown code."""
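
For reference, a hedged sketch (not part of this commit): with the Correction model removed, the correction prompt's OUTPUT FORMAT now asks for a "prediction" key so the same Prediction model parses both first attempts and corrections. The exact set of template variables is assumed from the hunks above, and the values are placeholders:

from arc_finetuning_st.workflows.models import Prediction
from arc_finetuning_st.workflows.prompts import CORRECTION_PROMPT_TEMPLATE

prompt_vars = {
    "examples": "INPUT:\n0,1\n1,0\nOUTPUT:\n1,0\n0,1",
    "test_input": "0,0\n1,1",
    "predicted_output": "1,1\n0,0",
    "critique": "The rows were swapped instead of the columns.",
}

# Preview the rendered prompt text...
print(CORRECTION_PROMPT_TEMPLATE.format(**prompt_vars))

# ...or, inside the workflow, parse the LLM's reply straight into a Prediction:
# correction: Prediction = await llm.astructured_predict(
#     Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
# )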