diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
index 0f9eeec1..5bafbffe 100755
--- a/scripts/map_tokens.py
+++ b/scripts/map_tokens.py
@@ -27,17 +27,19 @@
         type=str,
         help="Hugging Face API token",
     )
+    parser.add_argument(
+        "--tokenizer-size",
+        type=int,
+        default=4096,
+        help="Size of the tokenizer",
+    )
     args = parser.parse_args()
 
     dataset = load_validation_dataset(args.dataset_name)
 
-    mapping = token_map(
-        dataset
-    )  # outputs the dictionary: dict[int, list[tuple[int, int]]]
-
-    complete_mapping = [mapping.get(key, None) for key in range(4096)]
-
-    hf_dataset = Dataset.from_dict({"prompt_pos_idx": complete_mapping})
+    hf_dataset = Dataset.from_dict(
+        {"prompt_pos_idx": token_map(dataset, args.tokenizer_size)}
+    )
 
     repo_id = f"{args.username}/v0-token-map"  # location in to hf
 
diff --git a/src/delphi/eval/token_map.py b/src/delphi/eval/token_map.py
index 17e8e2ca..4ac7b0df 100644
--- a/src/delphi/eval/token_map.py
+++ b/src/delphi/eval/token_map.py
@@ -6,13 +6,13 @@
 
 def token_map(
     tokenized_dataset: Dataset,
-) -> dict[int, list[tuple[int, int]]]:
+    tokenizer_size: int,
+) -> list[list[tuple[int, int]]]:
     """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset."""
-    mapping = {}
+    mapping = [[] for _ in range(tokenizer_size)]
 
     for prompt_idx, prompt in enumerate(tokenized_dataset):
         prompt = cast(dict, prompt)
-        for token_idx, token in enumerate(prompt["tokens"]):
-            mapping.setdefault(token, []).append((prompt_idx, token_idx))
-
+        for position_idx, token in enumerate(prompt["tokens"]):
+            mapping[token].append((prompt_idx, position_idx))
     return mapping
diff --git a/tests/eval/test_token_map.py b/tests/eval/test_token_map.py
index 5a276080..2f896326 100644
--- a/tests/eval/test_token_map.py
+++ b/tests/eval/test_token_map.py
@@ -14,17 +14,17 @@ def test_token_map():
             ]
         }
     )
-    mapping = token_map(tokenized_dataset)
-    assert mapping == {
-        0: [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
-        1: [(0, 1), (1, 1), (2, 1)],
-        2: [(0, 2), (1, 2), (2, 2)],
-        3: [(0, 3), (1, 3), (2, 3)],
-        4: [(0, 4), (1, 4), (2, 4)],
-        5: [(0, 5), (1, 5), (2, 5)],
-        6: [(0, 7), (1, 7), (2, 7)],
-        7: [(0, 8), (1, 8), (2, 8)],
-    }
+    assert token_map(tokenized_dataset, tokenizer_size=9) == [
+        [(0, 0), (0, 6), (1, 0), (1, 6), (2, 0), (2, 6)],
+        [(0, 1), (1, 1), (2, 1)],
+        [(0, 2), (1, 2), (2, 2)],
+        [(0, 3), (1, 3), (2, 3)],
+        [(0, 4), (1, 4), (2, 4)],
+        [(0, 5), (1, 5), (2, 5)],
+        [(0, 7), (1, 7), (2, 7)],
+        [(0, 8), (1, 8), (2, 8)],
+        [],  # token 8 is not present in the dataset
+    ]
 
     # fmt: off
     tokenized_dataset = Dataset.from_dict(
@@ -35,14 +35,13 @@
         }
     )
     # fmt: on
-    mapping = token_map(tokenized_dataset)
-    assert mapping == {
-        0: [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
-        1: [(0, 1), (0, 10), (0, 19)],
-        2: [(0, 2), (0, 11), (0, 20)],
-        3: [(0, 3), (0, 12), (0, 21)],
-        4: [(0, 4), (0, 13), (0, 22)],
-        5: [(0, 5), (0, 14), (0, 23)],
-        6: [(0, 7), (0, 16), (0, 25)],
-        7: [(0, 8), (0, 17), (0, 26)],
-    }
+    assert token_map(tokenized_dataset, tokenizer_size=8) == [
+        [(0, 0), (0, 6), (0, 9), (0, 15), (0, 18), (0, 24)],
+        [(0, 1), (0, 10), (0, 19)],
+        [(0, 2), (0, 11), (0, 20)],
+        [(0, 3), (0, 12), (0, 21)],
+        [(0, 4), (0, 13), (0, 22)],
+        [(0, 5), (0, 14), (0, 23)],
+        [(0, 7), (0, 16), (0, 25)],
+        [(0, 8), (0, 17), (0, 26)],
+    ]
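
A minimal usage sketch of the new token_map signature (not part of the diff). The toy dataset below is an illustrative assumption, and the import path assumes the src/ layout above is installed as delphi.eval.token_map:

from datasets import Dataset

from delphi.eval.token_map import token_map

# Toy two-prompt dataset; the "tokens" column name matches what token_map reads.
toy = Dataset.from_dict({"tokens": [[0, 1, 2, 1], [2, 0, 3, 3]]})

# The result is indexed directly by token id, so every id below tokenizer_size
# has an entry; ids that never occur map to an empty list.
mapping = token_map(toy, tokenizer_size=5)
assert mapping[1] == [(0, 1), (0, 3)]  # token 1 appears twice in prompt 0
assert mapping[3] == [(1, 2), (1, 3)]  # token 3 appears twice in prompt 1
assert mapping[4] == []                # token 4 never appears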