[ENH] add metadata to completion: date, version,... #402

Merged · 7 commits · Aug 26, 2024
73 changes: 55 additions & 18 deletions src/alpaca_eval/annotators/base.py
@@ -2,13 +2,16 @@
import json
import logging
import os
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Type, Union

import numpy as np
import pandas as pd

import alpaca_eval

from .. import completion_parsers, constants, processors, types, utils
from ..decoders import get_fn_completions

@@ -50,10 +53,12 @@ class BaseAnnotator(abc.ABC):
Keys used to distinguish the example.

other_output_keys_to_keep : sequence of str, optional
Other output columns to store besides the annotations.
Other output columns to store besides the annotations. You can use `{annotation_key}` to refer to the name
of the annotation column.

other_input_keys_to_keep : sequence of str, optional
Other columns to keep from the input dataframe besides the primary keys.
Other columns to keep from the input dataframe besides the primary keys. You can use `{annotation_key}` to refer
to the name of the annotation column.

is_store_missing_annotations : bool, optional
Whether to store missing annotations. If True it avoids trying to reannotate examples that have errors.
@@ -90,16 +95,19 @@ def __init__(
seed: Optional[int] = 0,
is_avoid_reannotations: bool = True,
other_output_keys_to_keep: Sequence[str] = (
"price_per_example",
"time_per_example",
"raw_completion",
"{annotation_key}_price_per_example",
"{annotation_key}_time_per_example",
"{annotation_key}_version",
"{annotation_key}_date",
"{annotation_key}_raw_completion",
),
other_input_keys_to_keep: Sequence[str] = (),
is_store_missing_annotations: bool = True,
base_dir: Optional[Union[types.AnyPath, Sequence[types.AnyPath]]] = None,
is_raise_if_missing_primary_keys: bool = True,
annotation_type: Optional[Type] = None,
is_reapply_parsing: bool = False,
**single_annotator_kwargs,
):
logging.info(f"Creating the annotator from `{annotators_config}`.")
base_dir = base_dir or self.DEFAULT_BASE_DIR
@@ -123,9 +131,11 @@ def __init__(
if self.annotators_config.exists():
break

self.annotators = self._initialize_annotators()
self.annotators = self._initialize_annotators(**single_annotator_kwargs)
self.df_annotations = None

other_output_keys_to_keep = [c.format(annotation_key=self.annotation_key) for c in other_output_keys_to_keep]
other_input_keys_to_keep = [c.format(annotation_key=self.annotation_key) for c in other_input_keys_to_keep]
self.other_input_keys_to_keep = self._get_other_input_keys_to_keep(other_input_keys_to_keep)
self.other_output_keys_to_keep = self._get_other_output_keys_to_keep(other_output_keys_to_keep)
self.other_keys_to_keep = self.other_output_keys_to_keep + self.other_input_keys_to_keep
Expand All @@ -148,6 +158,11 @@ def annotation_key(self) -> str:
"""How to refer to the annotations, this will be the key for annotations in the output."""
return "annotation"

@property
def completion_key(self) -> str:
"""How to refer to the raw completions, this will be the key for raw completions in the output."""
return f"{self.annotation_key}_raw_completion"

@property
def random_seed_keys(self) -> list[str]:
"""What key / column to seed on for the random generator."""
@@ -227,7 +242,7 @@ def _initialize_annotators_config(self, annotators_config):

return annotators_config

def _initialize_annotators(self) -> dict[str, "SingleAnnotator"]:
def _initialize_annotators(self, **kwargs) -> dict[str, "SingleAnnotator"]:
"""Load all the configs and prompts if necessary."""
annotators_config = utils.load_configs(self.annotators_config)
try:
@@ -241,7 +256,9 @@ def _initialize_annotators(self) -> dict[str, "SingleAnnotator"]:
seed=self.seed,
base_dir=base_dir,
annotation_column=self.annotation_key,
completion_column=self.completion_key,
**annotator_config,
**kwargs,
)
for name, annotator_config in annotators_config.items()
}
@@ -311,8 +328,8 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFrame:
]
# if df_to_annotate "raw_completion" is a dict, put it back to a json string so that you can reparse it
# TODO: this is for backward compatibility, remove in the future
if "raw_completion" in df_to_annotate.columns:
df_to_annotate["raw_completion"] = df_to_annotate["raw_completion"].apply(
if self.completion_key in df_to_annotate.columns:
df_to_annotate[self.completion_key] = df_to_annotate[self.completion_key].apply(
lambda x: json.dumps(x) if isinstance(x, dict) else x
)

@@ -583,11 +600,11 @@ class SingleAnnotator:
annotation_column : str, optional
Name of the annotation column in the output dataframe.

is_store_raw_completions : bool, optional
Whether to store raw completions at `"raw_completion"` column in the output dataframe. Note that raw_completion
will not be modified by the postprocessors. E.g. if we switch the columns output_1 and output_2 in the prompt
then the raw completion will show the switched order, which makes interpretation harder. This should
nevertheless not be an issue when using reapply_parsing because of seeding.
completion_column : str, optional
Name of the raw completion column in the output dataframe. If None will not store the raw completions. Note that
raw_completion will not be modified by the postprocessors. E.g. if we switch the columns output_1 and output_2
in the prompt then the raw completion will show the switched order, which makes interpretation harder. This
should nevertheless not be an issue when using reapply_parsing because of seeding.

processors_to_kwargs : Sequence[dict(str, dict)], optional
A dictionary of BaseProcessor objects to apply for preprocessing the dataframe before making the prompts and
@@ -599,6 +616,9 @@ class SingleAnnotator:

completion_key : str, optional
Key of the output of `fn_completions` to use for parsing the completions into annotations.

packages_for_which_to_show_version : Sequence[str], optional
List of packages for which to show the version in the metadata of the completions.
"""

def __init__(
@@ -613,10 +633,12 @@ def __init__(
batch_size: int = 1,
base_dir: types.AnyPath = constants.EVALUATORS_CONFIG_DIR,
annotation_column: str = "annotation",
is_store_raw_completions: bool = True,
completion_column: Optional[str] = "raw_completion",
processors_to_kwargs: Optional[dict[str, dict]] = None,
is_add_default_processors: bool = True,
completion_key: str = "completions",
packages_for_which_to_show_version: Optional[Sequence[str]] = ("alpaca_eval",),
prfx_to_completion_cols: Optional[str] = "{annotation_column}_",
# The following two keys are only for the documentation
pretty_name: Optional[str] = None,
link: Optional[str] = None,
@@ -637,7 +659,11 @@ def __init__(
self.is_shuffle = is_shuffle
self.batch_size = batch_size
self.annotation_column = annotation_column
self.completion_column = "raw_completion" if is_store_raw_completions else None
self.completion_column = completion_column
self.packages_for_which_to_show_version = packages_for_which_to_show_version
if prfx_to_completion_cols is None:
prfx_to_completion_cols = ""
self.prfx_to_completion_cols = prfx_to_completion_cols.format(annotation_column=annotation_column)

self.is_add_default_processors = is_add_default_processors
self.processors = []
@@ -690,9 +716,14 @@ def __call__(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFrame:
# prompts and completions here will not be the same length as the dataframe due to batching
prompts, df_to_annotate = self._make_prompts(df_to_annotate)
completions = self.fn_completions(prompts=prompts, **self.completions_kwargs, **decoding_kwargs)
self._add_metadata_to_completions_(completions)
completions = {
f"{self.prfx_to_completion_cols}{k}" if k != self.completion_key else k: v
for k, v in completions.items()
}

for k, v in completions.items():
if k != "completions":
if k != self.completion_key:
if self.batch_size != 1 and (len(df_to_annotate) == len(v) * self.batch_size):
v = [el for el in v for _ in range(self.batch_size)]
df_to_annotate[k] = v
@@ -735,7 +766,7 @@ def _search_processor(self, name: Union[str, Type["processors.BaseProcessor"]])
return name

def _get_prompt_template(self, prompt_template: types.AnyPath):
return utils.read_or_return(self.base_dir / prompt_template)
return utils.read_or_return(prompt_template, relative_to=self.base_dir)

def _make_prompts(
self, df_to_annotate: pd.DataFrame, prompt_template: Optional[str] = None
@@ -762,6 +793,12 @@ def _make_prompts(
prompt_template = self.prompt_template
return utils.make_prompts(df=df_to_annotate, template=prompt_template, batch_size=self.batch_size)

def _add_metadata_to_completions_(self, completions: dict[str, Any]):
"""Add metadata to the completions."""
completions["date"] = datetime.now().isoformat()
if self.packages_for_which_to_show_version is not None:
completions["version"] = utils.get_multi_package_version(self.packages_for_which_to_show_version)

def _preprocess(self, df_to_annotate: pd.DataFrame) -> pd.DataFrame:
"""Preprocess the examples before annotating. In particular, takes care of all the randomization."""

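Taken together, the changes in this file mean that `_add_metadata_to_completions_` adds `date` and `version` keys to the decoder output, `prfx_to_completion_cols` prefixes those keys with the annotation column, and `other_output_keys_to_keep` is formatted with the annotation key so the new columns survive into the output dataframe. A minimal sketch of how the `{annotation_key}` placeholders resolve, assuming an annotator whose annotation key is `preference` (the name is illustrative, not taken from this diff):

```python
# Illustrative only: "preference" stands in for whatever annotation_key the annotator defines.
annotation_key = "preference"

default_keys = (
    "{annotation_key}_price_per_example",
    "{annotation_key}_time_per_example",
    "{annotation_key}_version",
    "{annotation_key}_date",
    "{annotation_key}_raw_completion",
)
resolved = [c.format(annotation_key=annotation_key) for c in default_keys]
# -> ["preference_price_per_example", "preference_time_per_example",
#     "preference_version", "preference_date", "preference_raw_completion"]
```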
10 changes: 10 additions & 0 deletions src/alpaca_eval/decoders/__init__.py
@@ -102,5 +102,15 @@ def get_fn_completions(name: Union[str, Callable]) -> Callable:
logging.exception(f"You need {packages} to use bedrock_anthropic. Error:")
raise e

elif name == "cache_completions":
from .cache import cache_completions

return cache_completions

elif name == "test_completions":
from .test import test_completions

return test_completions

else:
raise ValueError(f"Unknown decoder: {name}")
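Both new names can then be resolved like any other decoder; a short sketch, assuming the rest of `get_fn_completions` is unchanged:

```python
from alpaca_eval.decoders import get_fn_completions

# The two names registered above dispatch lazily to the new cache.py and test.py modules.
cache_fn = get_fn_completions("cache_completions")
test_fn = get_fn_completions("test_completions")
```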
51 changes: 51 additions & 0 deletions src/alpaca_eval/decoders/cache.py
@@ -0,0 +1,51 @@
import json
from pathlib import Path
from typing import Sequence

from alpaca_eval.decoders import get_fn_completions
from alpaca_eval.types import AnyPath

__all__ = ["cache_completions"]


def cache_completions(prompts: Sequence[str], fn_completions: str, cache_path: AnyPath, **completions_kwargs):
"""Simple wrapper around a completion function to cache the results to JSON on disk.
Parameters
----------
prompts : list of str
Prompts to get completions for.

fn_completions : str
Function in `decoders.py` to use for decoding the output.

cache_path : str
Path to the cache file.

completions_kwargs : dict
kwargs for fn_completions. E.g. model_name, max_tokens, temperature, top_p, top_k, stop_seq.

"""
assert isinstance(fn_completions, str), "fn_completions must be a string to be hashable."
all_args = [dict(prompt=p, fn_completions=fn_completions, completions_kwargs=completions_kwargs) for p in prompts]

cache_path = Path(cache_path)

try:
with open(cache_path, "r") as f:
cache = json.load(f)
except FileNotFoundError:
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache = {}

outs = []
fn_completions = get_fn_completions(fn_completions)
for args in all_args:
hashable_args = json.dumps(args, sort_keys=True)
if hashable_args not in cache:
cache[hashable_args] = fn_completions(prompts=[args["prompt"]], **args["completions_kwargs"])
outs.append(cache[hashable_args])

with open(cache_path, "w") as f:
json.dump(cache, f)

return outs
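A hedged usage sketch of the wrapper above; the cache path and keyword arguments are illustrative, not taken from this PR:

```python
from alpaca_eval.decoders.cache import cache_completions

outs = cache_completions(
    prompts=["2+2=?", "What is the capital of France?"],
    fn_completions="test_completions",            # name resolved via get_fn_completions
    cache_path="outputs/completions_cache.json",  # illustrative path
    model_name="test",                            # forwarded to the underlying decoder
)
# Each element of `outs` is the dict returned by the underlying decoder for one prompt.
# A rerun with identical prompt / fn_completions / kwargs is served from the JSON cache.
```

Note that the cache key is the sorted-key JSON serialization of `(prompt, fn_completions, completions_kwargs)`, so only exact repeats of all three skip the underlying call.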
27 changes: 27 additions & 0 deletions src/alpaca_eval/decoders/test.py
@@ -0,0 +1,27 @@
import logging
from typing import Sequence

from .. import utils

__all__ = ["test_completions"]


def test_completions(
prompts: Sequence[str],
model_name="test",
value: str = "{'name': 'test'}",
**decoding_kwargs,
) -> dict[str, list]:
"""Completion function for testing purposes. Returns the same value for all prompts."""

n_examples = len(prompts)

kwargs = dict(model_name=model_name, **decoding_kwargs)
logging.info(f"Kwargs to completion: {kwargs}")
with utils.Timer() as t:
responses = [value for _ in prompts]
avg_time = [t.duration / n_examples] * len(responses)
price_per_example = [0] * len(responses)
return dict(
completions=responses, price_per_example=price_per_example, time_per_example=avg_time, completions_all=responses
)
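And a small sketch of what this test decoder returns for two prompts:

```python
from alpaca_eval.decoders.test import test_completions

out = test_completions(prompts=["a", "b"])
# out["completions"]       -> ["{'name': 'test'}", "{'name': 'test'}"]
# out["price_per_example"] -> [0, 0]
# out["time_per_example"]  -> the measured total time split evenly over the two prompts
# out["completions_all"]   -> same list as out["completions"]
```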
2 changes: 1 addition & 1 deletion src/alpaca_eval/main.py
@@ -322,7 +322,7 @@ def get_completions(configs, df: pd.DataFrame, old_output_path: Optional[Path] =
if len(curr_outputs) > 0:
prompts, _ = utils.make_prompts(
curr_outputs,
template=utils.read_or_return(base_dir / configs["prompt_template"]),
template=utils.read_or_return(configs["prompt_template"], relative_to=base_dir),
)
fn_completions = decoders.get_fn_completions(configs["fn_completions"])
completions = fn_completions(prompts=prompts, **configs["completions_kwargs"])["completions"]
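These call sites switch from joining `base_dir / prompt_template` upfront to passing `relative_to=base_dir` into `utils.read_or_return`. The helper itself is not part of this diff; a rough sketch of the presumed behavior, labeled as an assumption:

```python
# Assumption: approximate behavior of utils.read_or_return(to_read, relative_to=...),
# sketched only to illustrate the call-site change; the real helper lives in
# src/alpaca_eval/utils.py and may differ.
from pathlib import Path
from typing import Optional, Union


def read_or_return_sketch(to_read: Union[str, Path], relative_to: Optional[Path] = None) -> str:
    path = Path(to_read)
    if relative_to is not None and not path.is_absolute():
        path = Path(relative_to) / path
    try:
        return path.read_text()
    except (OSError, ValueError):
        # not a readable file -> treat the argument as the literal template string
        return str(to_read)
```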