Skip to content

Commit

Permalink
Tokenizer comparison script
Browse files Browse the repository at this point in the history
  • Loading branch information
gbenson committed May 20, 2024
1 parent 4918be0 commit a133002
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dom-tokenizers"
version = "0.0.7"
version = "0.0.8"
authors = [{ name = "Gary Benson" }]
description = "DOM-aware tokenizers for 🤗 Hugging Face language models"
readme = "README.md"
Expand Down Expand Up @@ -49,6 +49,7 @@ train = [
[project.scripts]
train-tokenizer = "dom_tokenizers.train:main"
dump-tokenizations = "dom_tokenizers.dump:main"
tokenizer-diff = "dom_tokenizers.diff:main"

[build-system]
requires = ["setuptools>=61.0"]
Expand Down
42 changes: 42 additions & 0 deletions src/dom_tokenizers/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
import warnings

from argparse import ArgumentParser
from difflib import unified_diff

from .tokenizers import DOMSnapshotTokenizer

# Issue tracker URL; main() interpolates it into the argparse epilog.
SEND_BUGS_TO = "https://github.com/gbenson/dom-tokenizers/issues"


def main():
    """Compare saved tokenizations against a tokenizer's current output.

    Command-line entry point.  Takes two positional arguments: the path
    of a reference file holding one JSON object per line, and the name
    or path of the tokenizer to load.  Each row's ``dom_snapshot`` is
    re-serialized as compact JSON and tokenized; when the result differs
    from the row's saved ``tokenized`` list, a unified diff of the two
    token lists is printed, labelled with the row's ``source_index``.

    Exits via ``parser.error`` (status 2) if the loaded tokenizer's
    normalizer does not have ``strip_accents`` set.
    """
    parser = ArgumentParser(
        description="Compare saved tokenizations with specified tokenizer's.",
        epilog=f"Report bugs to: <{SEND_BUGS_TO}>.")

    parser.add_argument(
        "reference", metavar="FILENAME",
        help="output from dump-tokenizations")
    parser.add_argument(
        "tokenizer", metavar="TOKENIZER",
        help="tokenizer model name or path")
    args = parser.parse_args()

    warnings.filterwarnings("ignore", message=r".*resume_download.*")

    tokenizer = DOMSnapshotTokenizer.from_pretrained(args.tokenizer)
    # Explicit check instead of `assert`, which is stripped under -O.
    if not tokenizer.backend_tokenizer.normalizer.strip_accents:
        parser.error(
            f"{args.tokenizer}: tokenizer's normalizer does not have"
            " strip_accents set")

    # Stream the file line by line and close it deterministically,
    # rather than reading every line up front and leaking the handle.
    with open(args.reference, encoding="utf-8") as reference:
        for record in reference:
            if not record.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            row = json.loads(record)
            source_index = row["source_index"]
            serialized = json.dumps(
                row["dom_snapshot"], separators=(",", ":"))
            b = tokenizer.tokenize(serialized)
            a = row["tokenized"]
            if b == a:
                continue
            for diff_line in unified_diff(
                    a, b,
                    fromfile=f"a/{source_index}",
                    tofile=f"b/{source_index}"):
                # Control lines from unified_diff carry a trailing
                # newline by default; strip it before printing.
                print(diff_line.rstrip())

0 comments on commit a133002

Please sign in to comment.