Skip to content

Commit

Permalink
eval.utils.load_validation_dataset (#25)
Browse files: browse the repository at this point in the history
  • Loading branch information
jettjaniak authored Feb 8, 2024
1 parent 7f7f303 commit 7ae5d16
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/delphi/eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Callable
from typing import cast

import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int


Expand All @@ -25,3 +27,18 @@ def get_next_logprobs(
logprobs = get_all_logprobs(model, input_ids[:, :-1])
next_tokens = input_ids[:, 1:]
return gather_logprobs(logprobs, next_tokens)


def load_validation_dataset(dataset_name: str) -> Dataset:
    """Load the validation parquet files of a dataset.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``. A name
            without a ``/`` is prefixed with ``delphi-suite/``; a name that
            already contains ``/`` is used unchanged.

    Returns:
        Whatever ``load_dataset`` returns for the files matching
        ``data/validation-*.parquet``, typed as ``Dataset`` for static
        checkers (``cast`` performs no runtime conversion).
    """
    default_namespace = "delphi-suite"
    if "/" not in dataset_name:
        dataset_name = f"{default_namespace}/{dataset_name}"
    # Plain string: the glob has no placeholders, so no f-prefix is needed.
    validation_glob = "data/validation-*.parquet"
    dataset = load_dataset(
        dataset_name,
        data_files=validation_glob,
        verification_mode="no_checks",
        # this seems to be the only split when using data_files
        # regardless of the files we're actually loading
        split="train",
    )
    return cast(Dataset, dataset)
7 changes: 6 additions & 1 deletion tests/eval/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from delphi.eval.utils import gather_logprobs
from delphi.eval.utils import gather_logprobs, load_validation_dataset


def test_gather_logprobs():
Expand Down Expand Up @@ -41,3 +41,8 @@ def test_gather_logprobs():
)
result = gather_logprobs(logprobs, tokens)
assert torch.allclose(result, expected_output)


def test_load_validation_dataset():
    """Smoke test: both dataset names load without raising.

    No assertions are made on the returned datasets; the test only checks
    that ``load_validation_dataset`` completes for each name.
    """
    dataset_names = (
        "tinystories-v2-clean",
        "tinystories-v2-clean-tokenized-v0",
    )
    for name in dataset_names:
        load_validation_dataset(name)

0 comments on commit 7ae5d16

Please sign in to comment.