Skip to content

Commit

Permalink
save according to task name
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Sep 21, 2024
1 parent f8d8e33 commit d3dd085
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
8 changes: 5 additions & 3 deletions arc_finetuning_st/finetuning/finetuning_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Annotated, Any, Callable, List, Optional

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.bridge.pydantic import BaseModel, WrapSerializer
from llama_index.core.bridge.pydantic import BaseModel, Field, WrapSerializer

from arc_finetuning_st.finetuning.templates import (
ASSISTANT_TEMPLATE,
Expand All @@ -26,10 +26,12 @@ class FineTuningExample(BaseModel):
messages: List[
Annotated[ChatMessage, WrapSerializer(remove_additional_kwargs)]
]
task_name: str = Field(exclude=True)

@classmethod
def from_attempts(
cls,
task_name: str,
examples: str,
test_input: str,
attempts: List[Attempt],
Expand Down Expand Up @@ -65,7 +67,7 @@ def from_attempts(
),
]
)
return cls(messages=messages)
return cls(messages=messages, task_name=task_name)

def to_json(self) -> str:
data = self.model_dump()
Expand All @@ -79,5 +81,5 @@ def write_json(
data = self.model_dump()
dir = Path(dirpath or Path(__file__).parents[2].absolute(), dirname)
dir.mkdir(exist_ok=True, parents=True)
with open(Path(dir, "test.json"), "w") as f:
with open(Path(dir, self.task_name), "w") as f:
json.dump(data, f)
15 changes: 4 additions & 11 deletions arc_finetuning_st/streamlit/controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import logging
import uuid
from os import listdir
from pathlib import Path
from typing import Any, List, Literal, Optional, cast
Expand All @@ -27,9 +26,9 @@ def __init__(self) -> None:
self._handler: Optional[WorkflowHandler] = None
self._attempts: List[Attempt] = []
self._passing_results: List[bool] = []
parent_path = Path(__file__).parents[2].absolute()
self._data_path = Path(parent_path, "data", "training")
self._attempts_history_df_key = str(uuid.uuid4())
self._data_path = Path(
Path(__file__).parents[2].absolute(), "data", "training"
)

def reset(self) -> None:
# clear prediction
Expand Down Expand Up @@ -159,10 +158,6 @@ def passing(self) -> Optional[bool]:
return self._passing_results[-1]
return None

@property
def attempts_history_df_key(self) -> str:
return self._attempts_history_df_key

@property
def attempts_history_df(
self,
Expand Down Expand Up @@ -246,6 +241,7 @@ def _display_finetuning_example() -> None:
nonlocal prompt_vars

finetuning_example = FineTuningExample.from_attempts(
task_name=st.session_state.selected_task,
attempts=self._attempts,
examples=prompt_vars["examples"],
test_input=prompt_vars["test_input"],
Expand All @@ -256,9 +252,6 @@ def _display_finetuning_example() -> None:
with save_col:
if st.button("Save", use_container_width=True):
finetuning_example.write_json()
st.success(
"Successfully saved finetuning example."
)
st.session_state.show_finetuning_preview_dialog = (
False
)
Expand Down

0 comments on commit d3dd085

Please sign in to comment.