Skip to content

Commit

Permalink
Ported token mapping function (#22)
Browse files Browse the repository at this point in the history
* add token mapping

* use cast for typing

* moved token_map function to library

* added test case for token_map

* added test cases for token_map

* review changes

* add load_hf_dataset function

* shorten docstring

* revert pickling changes

* added token mapping script
  • Loading branch information
menamerai authored Feb 14, 2024
1 parent 4ed8b19 commit b8d0d8c
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
22 changes: 22 additions & 0 deletions scripts/map_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python3

import argparse
import pickle

from delphi.eval.token_map import token_map
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"dataset_name", help="Dataset from huggingface to run token_map on"
)
parser.add_argument("--output", help="Output file name", default="token_map.pkl")
args = parser.parse_args()

dataset = load_validation_dataset(args.dataset_name)

mapping = token_map(dataset)

with open(f"data/{args.output}", "wb") as f:
pickle.dump(mapping, file=f)
18 changes: 18 additions & 0 deletions src/delphi/eval/token_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from typing import cast

from datasets import Dataset


def token_map(
tokenized_dataset: Dataset,
) -> dict[int, list[tuple[int, int]]]:
"""Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""

mapping = {}
for prompt_idx, prompt in enumerate(tokenized_dataset):
prompt = cast(dict, prompt)
for token_idx, token in enumerate(prompt["tokens"]):
mapping.setdefault(token, []).append((prompt_idx, token_idx))

return mapping
48 changes: 48 additions & 0 deletions tests/eval/test_token_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from datasets import Dataset

from delphi.eval.token_map import token_map


def test_token_map():
tokenized_dataset = Dataset.from_dict(
{
"tokens": [
[0, 1, 2, 3, 4, 5, 0, 6, 7],
[0, 1, 2, 3, 4, 5, 0, 6, 7],
[0, 1, 2, 3, 4, 5, 0, 6, 7],
]
}
)
mapping = token_map(tokenized_dataset)
assert mapping == {
0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
1: [(0, 1), (1, 1), (2, 1)],
2: [(0, 2), (1, 2), (2, 2)],
3: [(0, 3), (1, 3), (2, 3)],
4: [(0, 4), (1, 4), (2, 4)],
5: [(0, 5), (1, 5), (2, 5)],
6: [(0, 7), (1, 7), (2, 7)],
7: [(0, 8), (1, 8), (2, 8)],
}

# fmt: off
tokenized_dataset = Dataset.from_dict(
{ # one really long prompt
"tokens": [
[0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7, 0, 1, 2, 3, 4, 5, 0, 6, 7]
]
}
)
# fmt: on
mapping = token_map(tokenized_dataset)
assert mapping == {
0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
1: [(0, 1), (0, 10), (0, 19)],
2: [(0, 2), (0, 11), (0, 20)],
3: [(0, 3), (0, 12), (0, 21)],
4: [(0, 4), (0, 13), (0, 22)],
5: [(0, 5), (0, 14), (0, 23)],
6: [(0, 7), (0, 16), (0, 25)],
7: [(0, 8), (0, 17), (0, 26)],
}

0 comments on commit b8d0d8c

Please sign in to comment.