diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
new file mode 100644
index 00000000..327f4651
--- /dev/null
+++ b/scripts/map_tokens.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+
+import argparse
+import pickle
+
+from delphi.eval.token_map import token_map
+from delphi.eval.utils import load_validation_dataset
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "dataset_name", help="Dataset from huggingface to run token_map on"
+    )
+    parser.add_argument("--output", help="Output file name", default="token_map.pkl")
+    args = parser.parse_args()
+
+    dataset = load_validation_dataset(args.dataset_name)
+
+    mapping = token_map(dataset)
+
+    with open(f"data/{args.output}", "wb") as f:
+        pickle.dump(mapping, file=f)
diff --git a/src/delphi/eval/token_map.py b/src/delphi/eval/token_map.py
new file mode 100644
index 00000000..17e8e2ca
--- /dev/null
+++ b/src/delphi/eval/token_map.py
@@ -0,0 +1,18 @@
+import os
+from typing import cast
+
+from datasets import Dataset
+
+
+def token_map(
+    tokenized_dataset: Dataset,
+) -> dict[int, list[tuple[int, int]]]:
+    """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""
+
+    mapping = {}
+    for prompt_idx, prompt in enumerate(tokenized_dataset):
+        prompt = cast(dict, prompt)
+        for token_idx, token in enumerate(prompt["tokens"]):
+            mapping.setdefault(token, []).append((prompt_idx, token_idx))
+
+    return mapping
diff --git a/tests/eval/test_token_map.py b/tests/eval/test_token_map.py
new file mode 100644
index 00000000..5a276080
--- /dev/null
+++ b/tests/eval/test_token_map.py
@@ -0,0 +1,48 @@
+import pytest
+from datasets import Dataset
+
+from delphi.eval.token_map import token_map
+
+
+def test_token_map():
+    tokenized_dataset = Dataset.from_dict(
+        {
+            "tokens": [
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+            ]
+        }
+    )
+    mapping = token_map(tokenized_dataset)
+    assert mapping == {
+        0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
+        1: [(0, 1), (1, 1), (2, 1)],
+        2: [(0, 2), (1, 2), (2, 2)],
+        3: [(0, 3), (1, 3), (2, 3)],
+        4: [(0, 4), (1, 4), (2, 4)],
+        5: [(0, 5), (1, 5), (2, 5)],
+        6: [(0, 7), (1, 7), (2, 7)],
+        7: [(0, 8), (1, 8), (2, 8)],
+    }
+
+    # fmt: off
+    tokenized_dataset = Dataset.from_dict(
+        { # one really long prompt
+            "tokens": [
+                [0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7]
+            ]
+        }
+    )
+    # fmt: on
+    mapping = token_map(tokenized_dataset)
+    assert mapping == {
+        0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
+        1: [(0, 1), (0, 10), (0, 19)],
+        2: [(0, 2), (0, 11), (0, 20)],
+        3: [(0, 3), (0, 12), (0, 21)],
+        4: [(0, 4), (0, 13), (0, 22)],
+        5: [(0, 5), (0, 14), (0, 23)],
+        6: [(0, 7), (0, 16), (0, 25)],
+        7: [(0, 8), (0, 17), (0, 26)],
+    }
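
A minimal usage sketch (illustrative only, not part of the diff), assuming scripts/map_tokens.py has already been run so that data/token_map.pkl exists; the token id 42 is an arbitrary placeholder:

import pickle

# The script above pickles the dict returned by token_map:
# token id -> list of (prompt_idx, token_idx) locations in the validation split.
with open("data/token_map.pkl", "rb") as f:
    mapping: dict[int, list[tuple[int, int]]] = pickle.load(f)

# Look up every position where an example token id occurs.
token_id = 42
locations = mapping.get(token_id, [])
print(f"token {token_id} occurs {len(locations)} times; first at {locations[0] if locations else None}")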