refactor prompts for simplicity
nerdai committed Sep 20, 2024
1 parent 7cf0702 commit 0f7e2c4
Showing 3 changed files with 49 additions and 38 deletions.
21 changes: 3 additions & 18 deletions arc_finetuning_st/streamlit/controller.py
@@ -104,22 +104,8 @@ async def handle_prediction_click(self) -> None:
 
         # use the critique and prediction str from streamlit
         critique = st.session_state.get("critique")
-        prompt_vars = await self._handler.ctx.get("prompt_vars")
-        prompt_vars.update(critique=critique)
-
-        # check if selected rows
-        selected_rows = (
-            st.session_state.get("attempts_history_df")
-            .get("selection")
-            .get("rows")
-        )
-        if selected_rows:
-            row_ix = selected_rows[0]
-            df_row = self.attempts_history_df.iloc[row_ix]
-            prediction_str = df_row["prediction"]
-            prompt_vars.update(predicted_output=prediction_str)
-
-        await self._handler.ctx.set("prompt_vars", prompt_vars)
+        self._attempts[-1].critique = critique
+        await self._handler.ctx.set("attempts", self._attempts)
 
         # run Workflow
         handler = w.run(ctx=self._handler.ctx, task=task)
@@ -135,13 +121,12 @@ async def handle_prediction_click(self) -> None:
         self._attempts = res.attempts
 
         # update streamlit states
-        prompt_vars = await self._handler.ctx.get("prompt_vars")
         grid = Prediction.prediction_str_to_int_array(
             prediction=str(res.attempts[-1].prediction)
         )
         prediction_fig = Controller.plot_grid(grid, kind="prediction")
         st.session_state.prediction = prediction_fig
-        st.session_state.critique = prompt_vars["critique"]
+        st.session_state.critique = str(res.attempts[-1].critique)
         st.session_state.disable_continue_button = False
         st.session_state.disable_abort_button = False
         st.session_state.disable_start_button = True
52 changes: 40 additions & 12 deletions arc_finetuning_st/workflows/arc_task_solver.py
@@ -32,6 +32,16 @@
 {output}
 """
 
+past_attempt_template = """
+PAST ATTEMPT {past_attempt_number}
+PREDICTED_OUTPUT:
+{past_predicted_output}
+CRITIQUE:
+{past_critique}
+"""
+
+
 class WorkflowOutput(BaseModel):
     passing: bool
@@ -44,6 +54,13 @@ def __init__(self, llm: LLM, max_attempts: int = 3, **kwargs: Any) -> None:
         self.llm = llm
         self._max_attempts = max_attempts
 
+    def _format_past_attempt(self, attempt: Attempt, attempt_num: int) -> str:
+        return past_attempt_template.format(
+            past_attempt_number=attempt_num,
+            past_predicted_output=str(attempt.prediction),
+            past_critique=str(attempt.critique) if attempt.critique else "",
+        )
+
     @step
     async def format_task(
         self, ctx: Context, ev: StartEvent
@@ -66,11 +83,19 @@ def format_train_example(train_pair: Dict) -> str:
         task = ev.get("task", {})
         await ctx.set("task", task)
 
-        # check if ctx has data from previous run
-        # if there is, don't overwrite it
-        prompt_vars = await ctx.get("prompt_vars", {})
-
-        if not prompt_vars:
-            # prepare prompt_vars
+        attempts = await ctx.get("attempts", [])
+        if attempts:
+            # update past predictions
+            prompt_vars = await ctx.get("prompt_vars")
+            formatted_past_attempts = [
+                self._format_past_attempt(a, ix + 1)
+                for ix, a in enumerate(attempts)
+            ]
+            prompt_vars.update(
+                past_attempts="\n".join(formatted_past_attempts)
+            )
+        else:
             examples = [format_train_example(t) for t in task["train"]]
             prompt_vars = {
                 "test_input": pretty_print_grid(task["test"][0]["input"]),
@@ -85,16 +110,15 @@ async def prediction(
         self, ctx: Context, ev: FormatTaskEvent
     ) -> PredictionEvent | StopEvent:
         ctx.write_event_to_stream(ev)
+        attempts = await ctx.get("attempts", [])
+        attempts = cast(List[Attempt], attempts)
         prompt_vars = await ctx.get("prompt_vars")
 
-        if "critique" in prompt_vars:
+        if attempts:
             # generating a correction from last Workflow run
-            attempts = await ctx.get("attempts")
-            attempts = cast(List[Attempt], attempts)
             correction: Prediction = await self.llm.astructured_predict(
                 Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
             )
-
             attempts.append(Attempt(prediction=correction))
         else:
             # starting a new correction with no previous Workflow runs
@@ -133,7 +157,13 @@ async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent:
         # check if passing
         if not ev.passing:
             prompt_vars = await ctx.get("prompt_vars")
-            prompt_vars.update(predicted_output=str(latest_attempt.prediction))
+            formatted_past_attempts = [
+                self._format_past_attempt(a, ix + 1)
+                for ix, a in enumerate(attempts)
+            ]
+            prompt_vars.update(
+                past_attempts="\n".join(formatted_past_attempts)
+            )
 
             # generate critique
             critique: Critique = await self.llm.astructured_predict(
@@ -142,8 +172,6 @@ async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent:
 
             # update states
             latest_attempt.critique = critique
-            prompt_vars.update(critique=str(critique))
-            await ctx.set("prompt_vars", prompt_vars)
 
         latest_attempt.passing = ev.passing
         attempts[-1] = latest_attempt
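For readers following along, here is a minimal, self-contained sketch of what the new past-attempts formatting produces. The Attempt dataclass below is a hypothetical stand-in for the repo's own model (not shown in full on this page); the template and join logic mirror _format_past_attempt and the past_attempts update in the diff above.

from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-in for the repo's Attempt model, for illustration only.
@dataclass
class Attempt:
    prediction: str
    critique: Optional[str] = None

PAST_ATTEMPT_TEMPLATE = """PAST ATTEMPT {past_attempt_number}
PREDICTED_OUTPUT:
{past_predicted_output}
CRITIQUE:
{past_critique}
"""

def format_past_attempts(attempts: List[Attempt]) -> str:
    # Mirrors _format_past_attempt: number attempts from 1, then join the
    # blocks so the prompt's {past_attempts} slot carries the whole history.
    return "\n".join(
        PAST_ATTEMPT_TEMPLATE.format(
            past_attempt_number=ix + 1,
            past_predicted_output=str(a.prediction),
            past_critique=str(a.critique) if a.critique else "",
        )
        for ix, a in enumerate(attempts)
    )

print(format_past_attempts([Attempt("0,1\n1,0", "Rows appear transposed.")]))

On this sample it prints a single "PAST ATTEMPT 1" block; each retry appends another block, which is what lets the correction prompt see the full attempt history rather than only the latest prediction and critique.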
14 changes: 6 additions & 8 deletions arc_finetuning_st/workflows/prompts.py
@@ -31,8 +31,8 @@
 TEST INPUT:
 {test_input}
 
-LATEST PREDICTED OUTPUT:
-{predicted_output}
+PAST ATTEMPTS:
+{past_attempts}
 
 OUTPUT FORMAT:
 {{
@@ -48,19 +48,17 @@
 The predicted output was found to be incorrect and a critique has been articulated offering a potential
 reason as to why it may have been a flawed prediction.
-Your task now to create a new prediction that corrects the previous one using the critique.
+Your task now to create a new prediction that corrects from the previous attempts. Use
+the last attempt and critique.
 
 EXAMPLES:
 {examples}
 
 TEST INPUT:
 {test_input}
 
-LATEST PREDICTED OUTPUT:
-{predicted_output}
-
-CRITIQUE:
-{critique}
+PAST ATTEMPTS:
+{past_attempts}
 
 OUTPUT FORMAT:
 {{
