From b8d0d8ce09a59fc35c42a2e2370bf1c5e68b5071 Mon Sep 17 00:00:00 2001
From: Rai <62800649+menamerai@users.noreply.github.com>
Date: Wed, 14 Feb 2024 14:15:25 -0500
Subject: [PATCH] Ported token mapping function (#22)

* add token mapping

* use cast for typing

* moved token_map function to library

* added test case for token_map

* added test cases for token_map

* review changes

* add load_hf_dataset function

* shorten docstring

* revert pickling changes

* added token mapping script
---
 scripts/map_tokens.py        | 22 +++++++++++++++++
 src/delphi/eval/token_map.py | 18 ++++++++++++++
 tests/eval/test_token_map.py | 48 ++++++++++++++++++++++++++++++++++++
 3 files changed, 88 insertions(+)
 create mode 100644 scripts/map_tokens.py
 create mode 100644 src/delphi/eval/token_map.py
 create mode 100644 tests/eval/test_token_map.py

diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
new file mode 100644
index 00000000..327f4651
--- /dev/null
+++ b/scripts/map_tokens.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+
+import argparse
+import pickle
+
+from delphi.eval.token_map import token_map
+from delphi.eval.utils import load_validation_dataset
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "dataset_name", help="Dataset from huggingface to run token_map on"
+    )
+    parser.add_argument("--output", help="Output file name", default="token_map.pkl")
+    args = parser.parse_args()
+
+    dataset = load_validation_dataset(args.dataset_name)
+
+    mapping = token_map(dataset)
+
+    with open(f"data/{args.output}", "wb") as f:
+        pickle.dump(mapping, file=f)
diff --git a/src/delphi/eval/token_map.py b/src/delphi/eval/token_map.py
new file mode 100644
index 00000000..17e8e2ca
--- /dev/null
+++ b/src/delphi/eval/token_map.py
@@ -0,0 +1,18 @@
+import os
+from typing import cast
+
+from datasets import Dataset
+
+
+def token_map(
+    tokenized_dataset: Dataset,
+) -> dict[int, list[tuple[int, int]]]:
+    """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""
+
+    mapping = {}
+    for prompt_idx, prompt in enumerate(tokenized_dataset):
+        prompt = cast(dict, prompt)
+        for token_idx, token in enumerate(prompt["tokens"]):
+            mapping.setdefault(token, []).append((prompt_idx, token_idx))
+
+    return mapping
diff --git a/tests/eval/test_token_map.py b/tests/eval/test_token_map.py
new file mode 100644
index 00000000..5a276080
--- /dev/null
+++ b/tests/eval/test_token_map.py
@@ -0,0 +1,48 @@
+import pytest
+from datasets import Dataset
+
+from delphi.eval.token_map import token_map
+
+
+def test_token_map():
+    tokenized_dataset = Dataset.from_dict(
+        {
+            "tokens": [
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+                [0, 1, 2, 3, 4, 5, 0, 6, 7],
+            ]
+        }
+    )
+    mapping = token_map(tokenized_dataset)
+    assert mapping == {
+        0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
+        1: [(0, 1), (1, 1), (2, 1)],
+        2: [(0, 2), (1, 2), (2, 2)],
+        3: [(0, 3), (1, 3), (2, 3)],
+        4: [(0, 4), (1, 4), (2, 4)],
+        5: [(0, 5), (1, 5), (2, 5)],
+        6: [(0, 7), (1, 7), (2, 7)],
+        7: [(0, 8), (1, 8), (2, 8)],
+    }
+
+    # fmt: off
+    tokenized_dataset = Dataset.from_dict(
+        { # one really long prompt
+            "tokens": [
+                [0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7]
+            ]
+        }
+    )
+    # fmt: on
+    mapping = token_map(tokenized_dataset)
+    assert mapping == {
+        0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
+        1: [(0, 1), (0, 10), (0, 19)],
+        2: [(0, 2), (0, 11), (0, 20)],
+        3: [(0, 3), (0, 12), (0, 21)],
+        4: [(0, 4), (0, 13), (0, 22)],
+        5: [(0, 5), (0, 14), (0, 23)],
+        6: [(0, 7), (0, 16), (0, 25)],
+        7: [(0, 8), (0, 17), (0, 26)],
+    }
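
For reviewers, a minimal usage sketch of the new script and mapping (not part of the patch): it assumes the delphi package is installed, a data/ directory exists for the pickled output, and uses a placeholder dataset name and an arbitrary token id.

    # Build the mapping for a Hugging Face dataset's validation split
    # (replace <hf-dataset-name> with a real dataset):
    #   python scripts/map_tokens.py <hf-dataset-name> --output token_map.pkl

    # Load the pickled mapping and look up every (prompt_idx, token_idx)
    # occurrence of a token id (42 is arbitrary and only for illustration):
    import pickle

    with open("data/token_map.pkl", "rb") as f:
        mapping = pickle.load(f)

    occurrences = mapping.get(42, [])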