From 6177d73ada0569120d2d80bb97647fa19a941094 Mon Sep 17 00:00:00 2001 From: transcendingvictor <92531256+transcendingvictor@users.noreply.github.com> Date: Wed, 28 Feb 2024 07:04:41 +0100 Subject: [PATCH] map tokens to hf (#44) * map_tokens from risky pickle to safe hf * mapping sorted and uploaded correctly * using a list, not a dict * (unimportant) cleaning up some lines * change token_map to return list * added split_slice arg to load_validation_dataset() --- scripts/map_tokens.py | 41 ++++++++++++++++++++++++++++------ src/delphi/eval/token_map.py | 10 ++++----- src/delphi/eval/utils.py | 4 ++-- tests/eval/test_token_map.py | 43 ++++++++++++++++++------------------ 4 files changed, 62 insertions(+), 36 deletions(-) diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py index 832ac0da..5bafbffe 100755 --- a/scripts/map_tokens.py +++ b/scripts/map_tokens.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 import argparse -import os -import pickle + +import pandas as pd +from datasets import Dataset from delphi.constants import STATIC_ASSETS_DIR from delphi.eval.token_map import token_map @@ -10,15 +11,41 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="") + + parser.add_argument( + "dataset_name", + type=str, + help="Dataset from huggingface to run token_map on", + ) + parser.add_argument( + "--username", + type=str, + help="Hugging Face API username", + ) parser.add_argument( - "dataset_name", help="Dataset from huggingface to run token_map on" + "--token", + type=str, + help="Hugging Face API token", + ) + parser.add_argument( + "--tokenizer-size", + type=int, + default=4096, + help="Size of the tokenizer", ) - parser.add_argument("--output", help="Output file name", default="token_map.pkl") args = parser.parse_args() dataset = load_validation_dataset(args.dataset_name) - mapping = token_map(dataset) + hf_dataset = Dataset.from_dict( + {"prompt_pos_idx": token_map(dataset, args.tokenizer_size)} + ) - with open(f"{STATIC_ASSETS_DIR}/{args.output}", "wb") as f: - pickle.dump(mapping, file=f) + repo_id = f"{args.username}/v0-token-map" # location in to hf + + hf_dataset.push_to_hub( + repo_id=repo_id, + split="validation", + private=False, + token=args.token, + ) diff --git a/src/delphi/eval/token_map.py b/src/delphi/eval/token_map.py index 17e8e2ca..4ac7b0df 100644 --- a/src/delphi/eval/token_map.py +++ b/src/delphi/eval/token_map.py @@ -6,13 +6,13 @@ def token_map( tokenized_dataset: Dataset, -) -> dict[int, list[tuple[int, int]]]: + tokenizer_size: int, +) -> list[list[tuple[int, int]]]: """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset.""" - mapping = {} + mapping = [[] for _ in range(tokenizer_size)] for prompt_idx, prompt in enumerate(tokenized_dataset): prompt = cast(dict, prompt) - for token_idx, token in enumerate(prompt["tokens"]): - mapping.setdefault(token, []).append((prompt_idx, token_idx)) - + for position_idx, token in enumerate(prompt["tokens"]): + mapping[token].append((prompt_idx, position_idx)) return mapping diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index d0611afb..0c4a8a6f 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -66,7 +66,7 @@ def get_next_and_top_k_probs( return next_probs, top_k -def load_validation_dataset(dataset_name: str) -> Dataset: +def load_validation_dataset(dataset_name: str, split_slice: str = "") -> Dataset: if "/" not in dataset_name: dataset_name = f"delphi-suite/{dataset_name}" data_str = f"data/validation-*.parquet" @@ -76,7 
+76,7 @@ def load_validation_dataset(dataset_name: str) -> Dataset: verification_mode="no_checks", # this seems to be the only split when using data_files # regardless of the files we're actually loading - split="train", + split=f"train{split_slice}", ) return cast(Dataset, dataset) diff --git a/tests/eval/test_token_map.py b/tests/eval/test_token_map.py index 5a276080..2f896326 100644 --- a/tests/eval/test_token_map.py +++ b/tests/eval/test_token_map.py @@ -14,17 +14,17 @@ def test_token_map(): ] } ) - mapping = token_map(tokenized_dataset) - assert mapping == { - 0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)], - 1: [(0, 1), (1, 1), (2, 1)], - 2: [(0, 2), (1, 2), (2, 2)], - 3: [(0, 3), (1, 3), (2, 3)], - 4: [(0, 4), (1, 4), (2, 4)], - 5: [(0, 5), (1, 5), (2, 5)], - 6: [(0, 7), (1, 7), (2, 7)], - 7: [(0, 8), (1, 8), (2, 8)], - } + assert token_map(tokenized_dataset, tokenizer_size=9) == [ + [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)], + [(0, 1), (1, 1), (2, 1)], + [(0, 2), (1, 2), (2, 2)], + [(0, 3), (1, 3), (2, 3)], + [(0, 4), (1, 4), (2, 4)], + [(0, 5), (1, 5), (2, 5)], + [(0, 7), (1, 7), (2, 7)], + [(0, 8), (1, 8), (2, 8)], + [], # token 8 is not present in the dataset + ] # fmt: off tokenized_dataset = Dataset.from_dict( @@ -35,14 +35,13 @@ def test_token_map(): } ) # fmt: on - mapping = token_map(tokenized_dataset) - assert mapping == { - 0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)], - 1: [(0, 1), (0, 10), (0, 19)], - 2: [(0, 2), (0, 11), (0, 20)], - 3: [(0, 3), (0, 12), (0, 21)], - 4: [(0, 4), (0, 13), (0, 22)], - 5: [(0, 5), (0, 14), (0, 23)], - 6: [(0, 7), (0, 16), (0, 25)], - 7: [(0, 8), (0, 17), (0, 26)], - } + assert token_map(tokenized_dataset, tokenizer_size=8) == [ + [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)], + [(0, 1), (0, 10), (0, 19)], + [(0, 2), (0, 11), (0, 20)], + [(0, 3), (0, 12), (0, 21)], + [(0, 4), (0, 13), (0, 22)], + [(0, 5), (0, 14), (0, 23)], + [(0, 7), (0, 16), (0, 25)], + [(0, 8), (0, 17), (0, 26)], + ]
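Usage note (not part of the patch): a minimal sketch of the new call signatures introduced here, assuming the package layout shown above (delphi.eval.token_map, delphi.eval.utils); the dataset name and slice are placeholders, and the dataset is assumed to carry the "tokens" column that token_map iterates over.

# Sketch only -- "some-tokenized-dataset" and "[:100]" are placeholders, not values from the patch.
from delphi.eval.token_map import token_map
from delphi.eval.utils import load_validation_dataset

# split_slice is appended to the "train" split string, so "[:100]" requests
# split="train[:100]", i.e. only the first 100 validation rows.
dataset = load_validation_dataset("some-tokenized-dataset", split_slice="[:100]")

# token_map now returns a list indexed by token id rather than a dict:
# entry i holds every (prompt_idx, position_idx) where token i occurs in the
# dataset's "tokens" column, and stays empty for token ids that never appear.
positions = token_map(dataset, tokenizer_size=4096)
print(len(positions))     # 4096: one entry per token id
print(positions[42][:5])  # first few occurrences of token id 42, if any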
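Usage note (not part of the patch): once map_tokens.py has pushed the mapping, it can be read back from the Hub like any other dataset; "someuser" below is a placeholder username, not a value from the patch.

# Sketch only -- "someuser" is a placeholder username.
from datasets import load_dataset

token_map_ds = load_dataset("someuser/v0-token-map", split="validation")
# One row per token id; the "prompt_pos_idx" column holds that token's list of
# (prompt_idx, position_idx) pairs, which come back as lists of two-element lists.
occurrences_of_token_5 = token_map_ds[5]["prompt_pos_idx"]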