map tokens to hf (#44)

* map_tokens from risky pickle to safe hf

* mapping sorted and uploaded correctly

* using a list, not a dict

* (unimportant) cleaning up some lines

* change token_map to return list

* added split_slice arg to load_validation_dataset()
transcendingvictor authored Feb 28, 2024
1 parent 6dd4ac8 commit 6177d73
Showing 4 changed files with 62 additions and 36 deletions.
41 changes: 34 additions & 7 deletions scripts/map_tokens.py
@@ -1,24 +1,51 @@
 #!/usr/bin/env python3

 import argparse
 import os
-import pickle
+
+import pandas as pd
+from datasets import Dataset

 from delphi.constants import STATIC_ASSETS_DIR
 from delphi.eval.token_map import token_map
 from delphi.eval.utils import load_validation_dataset

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="")

     parser.add_argument(
-        "dataset_name", help="Dataset from huggingface to run token_map on"
+        "dataset_name",
+        type=str,
+        help="Dataset from huggingface to run token_map on",
+    )
+    parser.add_argument(
+        "--username",
+        type=str,
+        help="Hugging Face API username",
+    )
+    parser.add_argument(
+        "--token",
+        type=str,
+        help="Hugging Face API token",
+    )
+    parser.add_argument(
+        "--tokenizer-size",
+        type=int,
+        default=4096,
+        help="Size of the tokenizer",
     )
-    parser.add_argument("--output", help="Output file name", default="token_map.pkl")
     args = parser.parse_args()

     dataset = load_validation_dataset(args.dataset_name)

-    mapping = token_map(dataset)
-
-    with open(f"{STATIC_ASSETS_DIR}/{args.output}", "wb") as f:
-        pickle.dump(mapping, file=f)
+    hf_dataset = Dataset.from_dict(
+        {"prompt_pos_idx": token_map(dataset, args.tokenizer_size)}
+    )
+
+    repo_id = f"{args.username}/v0-token-map"  # location in hf
+
+    hf_dataset.push_to_hub(
+        repo_id=repo_id,
+        split="validation",
+        private=False,
+        token=args.token,
+    )
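
For orientation, a minimal usage sketch that is not part of the commit: a hypothetical invocation of the updated script and one way the pushed mapping could be read back. The dataset name, username, and token are placeholders.

# Hypothetical invocation; every value below is a placeholder:
#   python scripts/map_tokens.py delphi-suite/<tokenized-dataset> \
#       --username <hf-username> --token <hf-token> --tokenizer-size 4096

# Reading the mapping back from the Hub, assuming the push above succeeded:
from datasets import load_dataset

token_map_ds = load_dataset("<hf-username>/v0-token-map", split="validation")
# Row i holds every (prompt_idx, position_idx) pair at which token id i occurs
# in the validation dataset.
occurrences_of_token_42 = token_map_ds[42]["prompt_pos_idx"]
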
10 changes: 5 additions & 5 deletions src/delphi/eval/token_map.py
@@ -6,13 +6,13 @@

 def token_map(
     tokenized_dataset: Dataset,
-) -> dict[int, list[tuple[int, int]]]:
+    tokenizer_size: int,
+) -> list[list[tuple[int, int]]]:
     """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""

-    mapping = {}
+    mapping = [[] for _ in range(tokenizer_size)]
     for prompt_idx, prompt in enumerate(tokenized_dataset):
         prompt = cast(dict, prompt)
-        for token_idx, token in enumerate(prompt["tokens"]):
-            mapping.setdefault(token, []).append((prompt_idx, token_idx))
-
+        for position_idx, token in enumerate(prompt["tokens"]):
+            mapping[token].append((prompt_idx, position_idx))
     return mapping
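
A toy illustration, not part of the commit, of what the dict-to-list change means for callers: the return value is now a dense list indexed by token id, so every id below tokenizer_size has an entry, including tokens that never occur.

from datasets import Dataset

from delphi.eval.token_map import token_map

toy_dataset = Dataset.from_dict({"tokens": [[0, 2, 2], [1, 0]]})

mapping = token_map(toy_dataset, tokenizer_size=4)
# mapping[0] == [(0, 0), (1, 1)]  -- token 0: prompt 0 position 0, prompt 1 position 1
# mapping[2] == [(0, 1), (0, 2)]  -- token 2 appears twice in prompt 0
# mapping[3] == []                -- token 3 never occurs, but still has an entry
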
4 changes: 2 additions & 2 deletions src/delphi/eval/utils.py
@@ -66,7 +66,7 @@ def get_next_and_top_k_probs(
     return next_probs, top_k


-def load_validation_dataset(dataset_name: str) -> Dataset:
+def load_validation_dataset(dataset_name: str, split_slice: str = "") -> Dataset:
     if "/" not in dataset_name:
         dataset_name = f"delphi-suite/{dataset_name}"
     data_str = f"data/validation-*.parquet"
@@ -76,7 +76,7 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
         verification_mode="no_checks",
         # this seems to be the only split when using data_files
         # regardless of the files we're actually loading
-        split="train",
+        split=f"train{split_slice}",
     )
     return cast(Dataset, dataset)
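
Because split_slice is appended directly to the split name, callers can pass the Hugging Face datasets slicing syntax to load only part of the validation data. A small sketch, assuming a placeholder dataset name:

from delphi.eval.utils import load_validation_dataset

# Full validation set; the split resolves to "train" as before.
full = load_validation_dataset("some-tokenized-dataset")

# First 100 examples only; the split resolves to "train[:100]".
subset = load_validation_dataset("some-tokenized-dataset", split_slice="[:100]")
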

43 changes: 21 additions & 22 deletions tests/eval/test_token_map.py
@@ -14,17 +14,17 @@ def test_token_map():
             ]
         }
     )
-    mapping = token_map(tokenized_dataset)
-    assert mapping == {
-        0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
-        1: [(0, 1), (1, 1), (2, 1)],
-        2: [(0, 2), (1, 2), (2, 2)],
-        3: [(0, 3), (1, 3), (2, 3)],
-        4: [(0, 4), (1, 4), (2, 4)],
-        5: [(0, 5), (1, 5), (2, 5)],
-        6: [(0, 7), (1, 7), (2, 7)],
-        7: [(0, 8), (1, 8), (2, 8)],
-    }
+    assert token_map(tokenized_dataset, tokenizer_size=9) == [
+        [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
+        [(0, 1), (1, 1), (2, 1)],
+        [(0, 2), (1, 2), (2, 2)],
+        [(0, 3), (1, 3), (2, 3)],
+        [(0, 4), (1, 4), (2, 4)],
+        [(0, 5), (1, 5), (2, 5)],
+        [(0, 7), (1, 7), (2, 7)],
+        [(0, 8), (1, 8), (2, 8)],
+        [],  # token 8 is not present in the dataset
+    ]

     # fmt: off
     tokenized_dataset = Dataset.from_dict(
@@ -35,14 +35,13 @@
         }
     )
     # fmt: on
-    mapping = token_map(tokenized_dataset)
-    assert mapping == {
-        0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
-        1: [(0, 1), (0, 10), (0, 19)],
-        2: [(0, 2), (0, 11), (0, 20)],
-        3: [(0, 3), (0, 12), (0, 21)],
-        4: [(0, 4), (0, 13), (0, 22)],
-        5: [(0, 5), (0, 14), (0, 23)],
-        6: [(0, 7), (0, 16), (0, 25)],
-        7: [(0, 8), (0, 17), (0, 26)],
-    }
+    assert token_map(tokenized_dataset, tokenizer_size=8) == [
+        [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
+        [(0, 1), (0, 10), (0, 19)],
+        [(0, 2), (0, 11), (0, 20)],
+        [(0, 3), (0, 12), (0, 21)],
+        [(0, 4), (0, 13), (0, 22)],
+        [(0, 5), (0, 14), (0, 23)],
+        [(0, 7), (0, 16), (0, 25)],
+        [(0, 8), (0, 17), (0, 26)],
+    ]
