diff --git a/src/alpaca_eval/annotators/base.py b/src/alpaca_eval/annotators/base.py index 72beba42..349c0835 100644 --- a/src/alpaca_eval/annotators/base.py +++ b/src/alpaca_eval/annotators/base.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd -from .. import completion_parsers, constants, processors, utils +from .. import completion_parsers, constants, processors, types, utils from ..decoders import get_fn_completions CURRENT_DIR = Path(__file__).parent @@ -23,11 +23,11 @@ class BaseAnnotator(abc.ABC): Parameters ---------- - annotators_config : Path or list of dict, optional - A dictionary or path to a yaml file containing the configuration for the pool of annotators. If a directory, - we search for 'configs.yaml' in it. The keys in the first dictionary should be the annotator's name, and - the value should be a dictionary of the annotator's configuration which should have the following keys: - The path is relative to `base_dir` directory. + annotators_config : Path, optional + A path to a yaml file containing the configuration for the pool of annotators. The path can be absolute or + relative to `base_dir` directory. If a directory, we search for 'configs.yaml' in it. After loading, the keys + in the first dictionary should be the annotator's name, and the value should be a dictionary of the annotator's + configuration which should have the following keys: - prompt_template (str): a prompt template or path to it. The template should contain placeholders for keys in the example dictionary, typically {instruction} and {output_1} {output_2}. - fn_completions (str): function in `alpaca_farm.decoders` for completions. Needs to accept as first argument @@ -58,9 +58,10 @@ class BaseAnnotator(abc.ABC): is_store_missing_annotations : bool, optional Whether to store missing annotations. If True it avoids trying to reannotate examples that have errors. - base_dir : Path, optional + base_dir : Path or list of Path, optional Path to the directory containing the annotators configs. I.e. annotators_config will be relative - to this directory. If None uses self.DEFAULT_BASE_DIR + to this directory. If None uses self.DEFAULT_BASE_DIR. If a list we will use the first such that + annotators_config can be loaded. is_raise_if_missing_primary_keys : bool, optional Whether to ensure that the primary keys are in the example dictionary. If True, raises an error. @@ -85,7 +86,7 @@ class BaseAnnotator(abc.ABC): def __init__( self, primary_keys: Sequence[str], - annotators_config: Union[utils.AnyPath, list[dict[str, Any]]] = constants.DEFAULT_ANNOTATOR_CONFIG, + annotators_config: Union[types.AnyPath] = constants.DEFAULT_ANNOTATOR_CONFIG, seed: Optional[int] = 0, is_avoid_reannotations: bool = True, other_output_keys_to_keep: Sequence[str] = ( @@ -95,13 +96,13 @@ def __init__( ), other_input_keys_to_keep: Sequence[str] = (), is_store_missing_annotations: bool = True, - base_dir: Optional[utils.AnyPath] = None, + base_dir: Optional[Union[types.AnyPath, Sequence[types.AnyPath]]] = None, is_raise_if_missing_primary_keys: bool = True, annotation_type: Optional[Type] = None, is_reapply_parsing: bool = False, ): logging.info(f"Creating the annotator from `{annotators_config}`.") - self.base_dir = Path(base_dir or self.DEFAULT_BASE_DIR) + base_dir = base_dir or self.DEFAULT_BASE_DIR self.seed = seed self.is_avoid_reannotations = is_avoid_reannotations self.primary_keys = list(primary_keys) @@ -113,7 +114,15 @@ def __init__( self.annotation_type = annotation_type or self.DEFAULT_ANNOTATION_TYPE self.is_reapply_parsing = is_reapply_parsing - self.annotators_config = self._initialize_annotators_config(annotators_config) + # loop over all the base_dirs until you find the annotators_config + if not isinstance(base_dir, (list, tuple, set)): + base_dir = [base_dir] + for d in base_dir: + self.base_dir = Path(d) + self.annotators_config = self._initialize_annotators_config(annotators_config) + if self.annotators_config.exists(): + break + self.annotators = self._initialize_annotators() self.df_annotations = None @@ -151,7 +160,7 @@ def annotator_name(self) -> str: def __call__( self, - to_annotate: utils.AnyData, + to_annotate: types.AnyData, chunksize: Optional[int] = 128, **decoding_kwargs, ) -> list[dict[str, Any]]: @@ -207,6 +216,9 @@ def __call__( ### Private methods ### def _initialize_annotators_config(self, annotators_config): + if isinstance(annotators_config, (list, tuple)): + return annotators_config + # setting it relative to the config directory annotators_config = self.base_dir / annotators_config @@ -243,7 +255,7 @@ def _add_missing_primary_keys_(self, df: pd.DataFrame): for c in missing_primary_keys: df[c] = None - def _preprocess(self, to_annotate: utils.AnyData) -> pd.DataFrame: + def _preprocess(self, to_annotate: types.AnyData) -> pd.DataFrame: """Preprocess the examples to annotate. In particular takes care of filtering unnecessary examples.""" df_to_annotate = utils.convert_to_dataframe(to_annotate) @@ -316,7 +328,7 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataF def _postprocess_and_store_( self, df_annotated: pd.DataFrame, - to_annotate: utils.AnyData, + to_annotate: types.AnyData, ) -> list[dict[str, Any]]: """Convert the dataframe into a list of dictionaries to be returned, and store current anntations.""" @@ -476,11 +488,11 @@ class BaseAnnotatorJSON(BaseAnnotator): """ ) - def __init__(self, *args, caching_path: Optional[utils.AnyPath] = "auto", **kwargs): + def __init__(self, *args, caching_path: Optional[types.AnyPath] = "auto", **kwargs): super().__init__(*args, **kwargs) self.caching_path = self._initialize_cache(caching_path) - def save(self, path: Optional[utils.AnyPath] = None): + def save(self, path: Optional[types.AnyPath] = None): """Save all annotations to json.""" path = path or self.caching_path @@ -492,7 +504,7 @@ def save(self, path: Optional[utils.AnyPath] = None): self.df_annotations = self.df_annotations[~self.df_annotations[self.annotation_key].isna()] self.df_annotations.to_json(path, orient="records", indent=2) - def load_(self, path: Optional[utils.AnyPath] = None): + def load_(self, path: Optional[types.AnyPath] = None): """Load all the annotations from json.""" path = path or self.caching_path if path is not None: @@ -591,15 +603,15 @@ class SingleAnnotator: def __init__( self, - prompt_template: utils.AnyPath, - fn_completion_parser: Optional[Union[Callable, str]] = "regex_parser", + prompt_template: types.AnyPath, + fn_completion_parser: Optional[Union[Callable, str]] = None, completion_parser_kwargs: Optional[dict[str, Any]] = None, fn_completions: Union[Callable, str] = "openai_completions", completions_kwargs: Optional[dict[str, Any]] = None, is_shuffle: bool = True, seed: Optional[int] = 123, batch_size: int = 1, - base_dir: utils.AnyPath = constants.EVALUATORS_CONFIG_DIR, + base_dir: types.AnyPath = constants.EVALUATORS_CONFIG_DIR, annotation_column: str = "annotation", is_store_raw_completions: bool = True, processors_to_kwargs: Optional[dict[str, dict]] = None, @@ -719,7 +731,7 @@ def _search_processor(self, name: Union[str, Type["processors.BaseProcessor"]]) assert issubclass(name, processors.BaseProcessor) return name - def _get_prompt_template(self, prompt_template: utils.AnyPath): + def _get_prompt_template(self, prompt_template: types.AnyPath): return utils.read_or_return(self.base_dir / prompt_template) def _make_prompts(