Lift some DOMSnapshotPreTokenizer into PreTokenizer
gbenson committed May 23, 2024
1 parent 63360ca commit 8f2130b
Showing 2 changed files with 42 additions and 9 deletions.
src/dom_tokenizers/pre_tokenizers/dom_snapshot.py (7 additions, 8 deletions)
```diff
@@ -12,7 +12,7 @@
 
 import magic
 
-from tokenizers import NormalizedString, PreTokenizedString
+from tokenizers import NormalizedString
 from unidecode import unidecode
 
 from .pre_tokenizer import PreTokenizer
@@ -36,19 +36,18 @@ def special_tokens(self):
             if attr.endswith("token")
         ]
 
-    def pre_tokenize(self, pretok: PreTokenizedString):
-        """Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place.
+    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+        """Transform a serialized DOM into a sequence of tokens.
         """
-        pretok.split(self._split_json)
-
-    def _split_json(self, i: int, s: NormalizedString) -> List[NormalizedString]:
-        snapshot = json.loads(s.normalized)
+        snapshot = json.loads(serialized)
 
         # Unpack the snapshot if what we have is a raw browser response
         if not any(key in snapshot for key in ("documents", "strings")):
            snapshot = snapshot.get("result", snapshot)
 
-        return list(chain.from_iterable(self._split_serialized(snapshot)))
+        return (ns.original
+                for ns in chain.from_iterable(
+                    self._split_serialized(snapshot)))
 
     def _split_serialized(self, snapshot: dict) -> Iterable[List[NormalizedString]]:
         emitter = TokenEmitter(self, snapshot)
```
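After this change, `DOMSnapshotPreTokenizer.pre_tokenize_dom()` is a plain function from a serialized DOM snapshot to a lazy sequence of token strings, with no `tokenizers` types in its signature. A minimal usage sketch follows; the no-argument constructor and the empty snapshot are assumptions for illustration, since neither appears in this diff:

```python
import json

from dom_tokenizers.pre_tokenizers.dom_snapshot import DOMSnapshotPreTokenizer

# Assumption: the pre-tokenizer can be constructed with no arguments.
pretok = DOMSnapshotPreTokenizer()

# Assumption: an (empty) snapshot carrying the "documents" and "strings"
# keys the code checks for; a real one would come from the browser.
serialized = json.dumps({"documents": [], "strings": []})

# pre_tokenize_dom() now returns a generator of plain strings.
tokens = list(pretok.pre_tokenize_dom(serialized))
```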
src/dom_tokenizers/pre_tokenizers/pre_tokenizer.py (35 additions, 1 deletion)
```diff
@@ -1,9 +1,17 @@
+import logging
 import weakref
 
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import List
+
+from tokenizers import NormalizedString, PreTokenizedString
 from tokenizers.pre_tokenizers import PreTokenizer as _PreTokenizer
 
+logger = logging.getLogger(__name__)
+
 
-class PreTokenizer:
+class PreTokenizer(ABC):
     @classmethod
     def hook_into(cls, tokenizer):
         """Reconfigure `tokenizer` for DOM-aware pre-tokenization.
@@ -42,3 +50,29 @@ def bind_to(self, tokenizer):
         if getattr(backend.normalizer, "lowercase", None) is True:
             backend.normalizer.lowercase = False
             self._lowercase_output = True
+
+    # Entry point
+
+    def pre_tokenize(self, pretok: PreTokenizedString):
+        pretok.split(self._pre_tokenize_dom)
+    pre_tokenize.__doc__ = _PreTokenizer.pre_tokenize.__doc__
+
+    def _pre_tokenize_dom(
+            self,
+            index: int,
+            split: NormalizedString,
+    ) -> List[NormalizedString]:
+        try:
+            return [
+                NormalizedString(token)
+                for token in self.pre_tokenize_dom(split.original)
+            ]
+        except Exception as e:
+            logger.exception(f"{type(e).__name__} in pre-tokenizer:")
+            raise
+
+    @abstractmethod
+    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
+        """Transform a serialized DOM into a sequence of tokens.
+        """
+        raise NotImplementedError
```
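The base class now owns all of the `tokenizers` plumbing: `pre_tokenize()` hands `_pre_tokenize_dom` to `PreTokenizedString.split()`, and that helper wraps each string the subclass yields in a `NormalizedString`, logging any exception before re-raising it. A subclass only implements `pre_tokenize_dom()`. A minimal sketch of the new contract (`WhitespacePreTokenizer` is invented for illustration):

```python
from collections.abc import Iterable

from dom_tokenizers.pre_tokenizers.pre_tokenizer import PreTokenizer


class WhitespacePreTokenizer(PreTokenizer):
    """Toy subclass: treat whitespace-separated words as tokens."""

    def pre_tokenize_dom(self, serialized: str) -> Iterable[str]:
        # No NormalizedString or PreTokenizedString handling needed here;
        # the base class's _pre_tokenize_dom() does the wrapping.
        yield from serialized.split()
```

Note the `pre_tokenize.__doc__ = _PreTokenizer.pre_tokenize.__doc__` assignment: it copies the docstring from the real `tokenizers.pre_tokenizers.PreTokenizer`, so the replacement entry point documents itself the same way as the class it stands in for.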
