diff --git a/arc_finetuning_st/finetuning/finetuning_example.py b/arc_finetuning_st/finetuning/finetuning_example.py
index 72f5102..aba77c0 100644
--- a/arc_finetuning_st/finetuning/finetuning_example.py
+++ b/arc_finetuning_st/finetuning/finetuning_example.py
@@ -3,7 +3,7 @@
 from typing import Annotated, Any, Callable, List, Optional
 
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
-from llama_index.core.bridge.pydantic import BaseModel, WrapSerializer
+from llama_index.core.bridge.pydantic import BaseModel, Field, WrapSerializer
 
 from arc_finetuning_st.finetuning.templates import (
     ASSISTANT_TEMPLATE,
@@ -26,10 +26,12 @@ class FineTuningExample(BaseModel):
     messages: List[
         Annotated[ChatMessage, WrapSerializer(remove_additional_kwargs)]
     ]
+    task_name: str = Field(exclude=True)
 
     @classmethod
     def from_attempts(
         cls,
+        task_name: str,
         examples: str,
         test_input: str,
         attempts: List[Attempt],
@@ -65,7 +67,7 @@ def from_attempts(
                 ),
             ]
         )
-        return cls(messages=messages)
+        return cls(messages=messages, task_name=task_name)
 
     def to_json(self) -> str:
         data = self.model_dump()
@@ -79,5 +81,5 @@ def write_json(
         data = self.model_dump()
         dir = Path(dirpath or Path(__file__).parents[2].absolute(), dirname)
         dir.mkdir(exist_ok=True, parents=True)
-        with open(Path(dir, "test.json"), "w") as f:
+        with open(Path(dir, self.task_name), "w") as f:
             json.dump(data, f)
diff --git a/arc_finetuning_st/streamlit/controller.py b/arc_finetuning_st/streamlit/controller.py
index 76b579a..87a08ac 100644
--- a/arc_finetuning_st/streamlit/controller.py
+++ b/arc_finetuning_st/streamlit/controller.py
@@ -1,7 +1,6 @@
 import asyncio
 import json
 import logging
-import uuid
 from os import listdir
 from pathlib import Path
 from typing import Any, List, Literal, Optional, cast
@@ -27,9 +26,9 @@ def __init__(self) -> None:
         self._handler: Optional[WorkflowHandler] = None
         self._attempts: List[Attempt] = []
         self._passing_results: List[bool] = []
-        parent_path = Path(__file__).parents[2].absolute()
-        self._data_path = Path(parent_path, "data", "training")
-        self._attempts_history_df_key = str(uuid.uuid4())
+        self._data_path = Path(
+            Path(__file__).parents[2].absolute(), "data", "training"
+        )
 
     def reset(self) -> None:
         # clear prediction
@@ -159,10 +158,6 @@ def passing(self) -> Optional[bool]:
             return self._passing_results[-1]
         return None
 
-    @property
-    def attempts_history_df_key(self) -> str:
-        return self._attempts_history_df_key
-
     @property
     def attempts_history_df(
         self,
@@ -246,6 +241,7 @@ def _display_finetuning_example() -> None:
             nonlocal prompt_vars
 
             finetuning_example = FineTuningExample.from_attempts(
+                task_name=st.session_state.selected_task,
                 attempts=self._attempts,
                 examples=prompt_vars["examples"],
                 test_input=prompt_vars["test_input"],
@@ -256,9 +252,6 @@ def _display_finetuning_example() -> None:
             with save_col:
                 if st.button("Save", use_container_width=True):
                     finetuning_example.write_json()
-                    st.success(
-                        "Successfully saved finetuning example."
-                    )
                     st.session_state.show_finetuning_preview_dialog = (
                         False
                     )