[ENH] enable base_dir to be a list (#392)
* [ENH] enable base_dir to be a list

* remove breakpoint

* by default, don't parse.
YannDubs authored Aug 17, 2024
1 parent c6a4164 commit c4def44
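In short, `base_dir` may now be a single path or a sequence of candidate paths, and the annotator keeps the first candidate under which `annotators_config` can actually be found. Below is a minimal standalone sketch of that resolution behaviour; the helper name, example paths, and the config name in the usage comment are purely illustrative and not part of alpaca_eval's API.

from pathlib import Path
from typing import Optional, Sequence, Union

AnyPath = Union[str, Path]  # stand-in for alpaca_eval's types.AnyPath

def resolve_annotators_config(
    base_dir: Union[AnyPath, Sequence[AnyPath]],
    annotators_config: AnyPath,
) -> Optional[Path]:
    """Return the config path under the first base_dir where it exists (illustrative only)."""
    if not isinstance(base_dir, (list, tuple, set)):
        base_dir = [base_dir]  # a single path keeps the previous behaviour
    candidate = None
    for d in base_dir:
        candidate = Path(d) / annotators_config
        if candidate.is_dir():  # the class searches for 'configs.yaml' inside a directory
            candidate = candidate / "configs.yaml"
        if candidate.exists():
            break  # first hit wins, mirroring the loop added to __init__ below
    return candidate

# e.g. resolve_annotators_config(["./my_configs", "./evaluators_configs"], "alpaca_eval_gpt4")

Existing callers that pass a single path are unaffected: a scalar base_dir is wrapped into a one-element list before the loop.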
56 changes: 34 additions & 22 deletions src/alpaca_eval/annotators/base.py
@@ -9,7 +9,7 @@
 import numpy as np
 import pandas as pd

-from .. import completion_parsers, constants, processors, utils
+from .. import completion_parsers, constants, processors, types, utils
 from ..decoders import get_fn_completions

 CURRENT_DIR = Path(__file__).parent
@@ -23,11 +23,11 @@ class BaseAnnotator(abc.ABC):
     Parameters
     ----------
-    annotators_config : Path or list of dict, optional
-        A dictionary or path to a yaml file containing the configuration for the pool of annotators. If a directory,
-        we search for 'configs.yaml' in it. The keys in the first dictionary should be the annotator's name, and
-        the value should be a dictionary of the annotator's configuration which should have the following keys:
-        The path is relative to `base_dir` directory.
+    annotators_config : Path, optional
+        A path to a yaml file containing the configuration for the pool of annotators. The path can be absolute or
+        relative to `base_dir` directory. If a directory, we search for 'configs.yaml' in it. After loading, the keys
+        in the first dictionary should be the annotator's name, and the value should be a dictionary of the annotator's
+        configuration which should have the following keys:
         - prompt_template (str): a prompt template or path to it. The template should contain placeholders for keys in
           the example dictionary, typically {instruction} and {output_1} {output_2}.
         - fn_completions (str): function in `alpaca_farm.decoders` for completions. Needs to accept as first argument
@@ -58,9 +58,10 @@ class BaseAnnotator(abc.ABC):
     is_store_missing_annotations : bool, optional
         Whether to store missing annotations. If True it avoids trying to reannotate examples that have errors.
-    base_dir : Path, optional
+    base_dir : Path or list of Path, optional
         Path to the directory containing the annotators configs. I.e. annotators_config will be relative
-        to this directory. If None uses self.DEFAULT_BASE_DIR
+        to this directory. If None uses self.DEFAULT_BASE_DIR. If a list we will use the first such that
+        annotators_config can be loaded.
     is_raise_if_missing_primary_keys : bool, optional
         Whether to ensure that the primary keys are in the example dictionary. If True, raises an error.
@@ -85,7 +86,7 @@ class BaseAnnotator(abc.ABC):
     def __init__(
         self,
         primary_keys: Sequence[str],
-        annotators_config: Union[utils.AnyPath, list[dict[str, Any]]] = constants.DEFAULT_ANNOTATOR_CONFIG,
+        annotators_config: Union[types.AnyPath] = constants.DEFAULT_ANNOTATOR_CONFIG,
         seed: Optional[int] = 0,
         is_avoid_reannotations: bool = True,
         other_output_keys_to_keep: Sequence[str] = (
@@ -95,13 +96,13 @@ def __init__(
         ),
         other_input_keys_to_keep: Sequence[str] = (),
         is_store_missing_annotations: bool = True,
-        base_dir: Optional[utils.AnyPath] = None,
+        base_dir: Optional[Union[types.AnyPath, Sequence[types.AnyPath]]] = None,
         is_raise_if_missing_primary_keys: bool = True,
         annotation_type: Optional[Type] = None,
         is_reapply_parsing: bool = False,
     ):
         logging.info(f"Creating the annotator from `{annotators_config}`.")
-        self.base_dir = Path(base_dir or self.DEFAULT_BASE_DIR)
+        base_dir = base_dir or self.DEFAULT_BASE_DIR
         self.seed = seed
         self.is_avoid_reannotations = is_avoid_reannotations
         self.primary_keys = list(primary_keys)
@@ -113,7 +114,15 @@ def __init__(
         self.annotation_type = annotation_type or self.DEFAULT_ANNOTATION_TYPE
         self.is_reapply_parsing = is_reapply_parsing

-        self.annotators_config = self._initialize_annotators_config(annotators_config)
+        # loop over all the base_dirs until you find the annotators_config
+        if not isinstance(base_dir, (list, tuple, set)):
+            base_dir = [base_dir]
+        for d in base_dir:
+            self.base_dir = Path(d)
+            self.annotators_config = self._initialize_annotators_config(annotators_config)
+            if self.annotators_config.exists():
+                break
+
         self.annotators = self._initialize_annotators()
         self.df_annotations = None

@@ -151,7 +160,7 @@ def annotator_name(self) -> str:

     def __call__(
         self,
-        to_annotate: utils.AnyData,
+        to_annotate: types.AnyData,
         chunksize: Optional[int] = 128,
         **decoding_kwargs,
     ) -> list[dict[str, Any]]:
@@ -207,6 +216,9 @@ def __call__(

     ### Private methods ###
     def _initialize_annotators_config(self, annotators_config):
+        if isinstance(annotators_config, (list, tuple)):
+            return annotators_config
+
         # setting it relative to the config directory
         annotators_config = self.base_dir / annotators_config

@@ -243,7 +255,7 @@ def _add_missing_primary_keys_(self, df: pd.DataFrame):
         for c in missing_primary_keys:
             df[c] = None

-    def _preprocess(self, to_annotate: utils.AnyData) -> pd.DataFrame:
+    def _preprocess(self, to_annotate: types.AnyData) -> pd.DataFrame:
         """Preprocess the examples to annotate. In particular takes care of filtering unnecessary examples."""

         df_to_annotate = utils.convert_to_dataframe(to_annotate)
@@ -316,7 +328,7 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataF
     def _postprocess_and_store_(
         self,
         df_annotated: pd.DataFrame,
-        to_annotate: utils.AnyData,
+        to_annotate: types.AnyData,
     ) -> list[dict[str, Any]]:
         """Convert the dataframe into a list of dictionaries to be returned, and store current anntations."""

@@ -476,11 +488,11 @@ class BaseAnnotatorJSON(BaseAnnotator):
     """
     )

-    def __init__(self, *args, caching_path: Optional[utils.AnyPath] = "auto", **kwargs):
+    def __init__(self, *args, caching_path: Optional[types.AnyPath] = "auto", **kwargs):
         super().__init__(*args, **kwargs)
         self.caching_path = self._initialize_cache(caching_path)

-    def save(self, path: Optional[utils.AnyPath] = None):
+    def save(self, path: Optional[types.AnyPath] = None):
         """Save all annotations to json."""

         path = path or self.caching_path
@@ -492,7 +504,7 @@ def save(self, path: Optional[utils.AnyPath] = None):
             self.df_annotations = self.df_annotations[~self.df_annotations[self.annotation_key].isna()]
             self.df_annotations.to_json(path, orient="records", indent=2)

-    def load_(self, path: Optional[utils.AnyPath] = None):
+    def load_(self, path: Optional[types.AnyPath] = None):
         """Load all the annotations from json."""
         path = path or self.caching_path
         if path is not None:
@@ -591,15 +603,15 @@ class SingleAnnotator:

     def __init__(
         self,
-        prompt_template: utils.AnyPath,
-        fn_completion_parser: Optional[Union[Callable, str]] = "regex_parser",
+        prompt_template: types.AnyPath,
+        fn_completion_parser: Optional[Union[Callable, str]] = None,
         completion_parser_kwargs: Optional[dict[str, Any]] = None,
         fn_completions: Union[Callable, str] = "openai_completions",
         completions_kwargs: Optional[dict[str, Any]] = None,
         is_shuffle: bool = True,
         seed: Optional[int] = 123,
         batch_size: int = 1,
-        base_dir: utils.AnyPath = constants.EVALUATORS_CONFIG_DIR,
+        base_dir: types.AnyPath = constants.EVALUATORS_CONFIG_DIR,
         annotation_column: str = "annotation",
         is_store_raw_completions: bool = True,
         processors_to_kwargs: Optional[dict[str, dict]] = None,
@@ -719,7 +731,7 @@ def _search_processor(self, name: Union[str, Type["processors.BaseProcessor"]])
             assert issubclass(name, processors.BaseProcessor)
             return name

-    def _get_prompt_template(self, prompt_template: utils.AnyPath):
+    def _get_prompt_template(self, prompt_template: types.AnyPath):
         return utils.read_or_return(self.base_dir / prompt_template)

     def _make_prompts(
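For context on the annotators_config docstring above: once configs.yaml is loaded, the class expects a mapping from each annotator's name to that annotator's settings. A hypothetical example of the loaded structure is sketched here; all names and values are made up for illustration, and only the prompt_template and fn_completions keys are taken directly from the docstring (other keys would typically mirror SingleAnnotator's parameters, e.g. completions_kwargs).

example_annotators_config = {
    "my_gpt4_annotator": {  # top-level key: the annotator's name
        # template (or path to it) with placeholders such as {instruction}, {output_1}, {output_2}
        "prompt_template": "my_gpt4_annotator/prompt.txt",
        # completions function to use; "openai_completions" is the default seen in SingleAnnotator above
        "fn_completions": "openai_completions",
    },
}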
