From 30052679111fb33902e21005a9c264c7d8e4c4b2 Mon Sep 17 00:00:00 2001
From: JaiDhyani
Date: Fri, 23 Feb 2024 03:13:08 -0800
Subject: [PATCH] Add load_logprob_dataset and load_logprob_datasets functions

---
 src/delphi/eval/utils.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
index 72863140..ee66ea74 100644
--- a/src/delphi/eval/utils.py
+++ b/src/delphi/eval/utils.py
@@ -11,6 +11,8 @@
     PreTrainedTokenizerFast,
 )
 
+from delphi.eval import constants
+
 GenericPreTrainedTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 
@@ -94,3 +96,36 @@ def tokenize(
         Int[torch.Tensor, "seq"],
         tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
     )
+
+
+def load_logprob_dataset(model: str) -> Dataset:
+    """Load the ``transcendingvictor/{model}-validation-logprobs`` dataset.
+
+    Args:
+        model: Model name interpolated into the dataset id.
+
+    Returns:
+        The object returned by ``load_dataset`` for that id, cast to ``Dataset``.
+    """
+    # NOTE(review): no ``split`` is passed to load_dataset, and
+    # load_logprob_datasets indexes this result by split name -- confirm that
+    # ``Dataset`` (rather than a split-keyed container) is the intended cast.
+    return cast(
+        Dataset, load_dataset(f"transcendingvictor/{model}-validation-logprobs")
+    )
+
+
+def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
+    """Collect the ``"logprobs"`` entry of ``split`` for each configured model.
+
+    Args:
+        split: Key used to index each dataset from ``load_logprob_dataset``.
+
+    Returns:
+        A dict with one entry per name in ``constants.LLAMA2_MODELS``; each
+        value is the ``"logprobs"`` entry of that model's ``split``.
+    """
+    return {
+        model: cast(dict, load_logprob_dataset(model)[split])["logprobs"]
+        for model in constants.LLAMA2_MODELS
+    }