Skip to content

Commit

Permalink
Add load_logprob_dataset and load_logprob_datasets functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
jaidhyani committed Feb 23, 2024
1 parent 043c408 commit 3005267
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/delphi/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
PreTrainedTokenizerFast,
)

from delphi.eval import constants

# Type alias covering both tokenizer classes imported above, so that
# annotations can accept either PreTrainedTokenizer or PreTrainedTokenizerFast.
GenericPreTrainedTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


Expand Down Expand Up @@ -94,3 +96,16 @@ def tokenize(
Int[torch.Tensor, "seq"],
tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
)


def load_logprob_dataset(model: str):
    """Load the validation-logprobs dataset published for ``model``.

    The repository id is built as
    ``transcendingvictor/{model}-validation-logprobs`` and handed to
    ``load_dataset``. The result is returned typed as ``Dataset`` through
    ``cast``, which only informs the type checker and performs no runtime
    conversion.

    NOTE(review): ``load_logprob_datasets`` indexes this return value by
    split name, which suggests the loaded object may be keyed by split
    rather than being a single ``Dataset`` -- confirm the cast target.
    """
    repo_id = f"transcendingvictor/{model}-validation-logprobs"
    loaded = load_dataset(repo_id)
    return cast(Dataset, loaded)


def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
    """Collect the ``"logprobs"`` column of ``split`` for every LLaMA2 model.

    Iterates over ``constants.LLAMA2_MODELS``, loads each model's dataset via
    ``load_logprob_dataset``, selects ``split`` from it and reads the
    ``"logprobs"`` entry.

    Returns a dict mapping each model name to that entry. The annotated
    ``list[list[float]]`` shape is taken from the signature; it is not
    checked here -- presumably one list of log-probabilities per sample,
    TODO confirm against the published datasets.
    """
    logprobs_by_model: dict[str, list[list[float]]] = {}
    for model in constants.LLAMA2_MODELS:
        # cast() is a type-checker hint only; the split object is used as-is.
        split_data = cast(dict, load_logprob_dataset(model)[split])
        logprobs_by_model[model] = split_data["logprobs"]
    return logprobs_by_model

0 comments on commit 3005267

Please sign in to comment.