diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py
index f4585671..403b8489 100755
--- a/scripts/train_tokenizer.py
+++ b/scripts/train_tokenizer.py
@@ -1,84 +1,68 @@
 #!/usr/bin/env python3
 import argparse
-import tempfile
 
-from delphi.train.tokenizer import (
-    hf_bpe_tokenizer_to_llama_tokenizer,
-    hf_dataset_to_text,
-    sp_processor_to_hf_bpe_tokenizer,
-    train_sentence_piece,
-)
+from datasets import Dataset, Features, Value, load_dataset
+from tokenizers import ByteLevelBPETokenizer  # type: ignore
+from tqdm.auto import tqdm
+from transformers import PreTrainedTokenizerFast
 
 
-def main(
-    *,
-    vocab_size: int,
-    dataset_name: str,
-    split: str,
-    column: str,
-    repo_id: str,
-    hf_token: str,
-):
-    """Trains a SentencePiece tokenizer, converts it to LlamaTokenizerFast and pushes it to the Hugging Face Hub."""
-    with tempfile.TemporaryFile(mode="w+") as text_file:
-        print("Loading and writing dataset to text file...")
-        hf_dataset_to_text(
-            dataset_name=dataset_name,
-            split=split,
-            column=column,
-            text_file=text_file,
-        )
-        text_file.seek(0)
-        print("Training SentencePiece tokenizer...\n")
-        sp_processor = train_sentence_piece(
-            vocab_size=vocab_size,
-            sentence_iterator=text_file,
-        )
-    print("\nConverting SentencePiece tokenizer Llama tokenizer...")
-    hf_bpe_tokenizer = sp_processor_to_hf_bpe_tokenizer(sp_processor)
-    llama_tokenizer = hf_bpe_tokenizer_to_llama_tokenizer(hf_bpe_tokenizer)
-    print("Pushing Llama tokenizer to Hugging Face Hub...")
-    llama_tokenizer.push_to_hub(
-        repo_id=repo_id,
-        token=hf_token,
+def train_byte_level_bpe(
+    dataset: Dataset, feature: str, vocab_size: int
+) -> PreTrainedTokenizerFast:
+    tokenizer = ByteLevelBPETokenizer()
+    text_generator = (example[feature] for example in dataset)  # type: ignore
+    tokenizer.train_from_iterator(
+        text_generator,
+        vocab_size=vocab_size,
+        special_tokens=["<bos>", "<eos>"],
+        show_progress=True,
+        length=len(dataset),
+    )
+    return PreTrainedTokenizerFast(
+        tokenizer_object=tokenizer,
+        bos_token="<bos>",
+        eos_token="<eos>",
     )
-    print("Done.")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Train a SentencePiece tokenizer and convert to HF"
-    )
+    parser = argparse.ArgumentParser(description="", allow_abbrev=False)
+
     parser.add_argument(
-        "--vocab-size",
-        "-v",
-        type=int,
-        help="Vocabulary size of the tokenizer",
+        "--in-repo-id",
+        "-i",
+        type=str,
+        required=True,
+        help="Input dataset",
     )
     parser.add_argument(
-        "--dataset-name",
-        "-d",
+        "--feature",
+        "-f",
         type=str,
-        help="Dataset name with or without delphi-suite/ prefix",
+        required=True,
+        help="Name of the feature (column) containing text documents in the input dataset",
     )
     parser.add_argument(
         "--split",
         "-s",
         type=str,
-        default="train",
-        help="Split of the dataset to be used for training, supports slicing like 'train[:10%%]'",
+        required=True,
+        help="Split of the dataset to be used for tokenizer training, supports slicing like 'train[:10%%]'",
     )
     parser.add_argument(
-        "--column",
-        "-c",
-        type=str,
-        help="Column of the dataset to be used for training",
+        "--vocab-size",
+        "-v",
+        type=int,
+        required=True,
+        help="Vocabulary size of the tokenizer",
     )
     parser.add_argument(
-        "--repo-id",
-        "-r",
+        "--out-repo-id",
+        "-o",
         type=str,
-        help="Hugging Face repository ID",
+        required=True,
+        help="Where to push the resulting tokenizer",
     )
     parser.add_argument(
         "--hf-token",
@@ -87,11 +71,20 @@ def main(
         help="Hugging Face API token",
     )
     args = parser.parse_args()
-    main(
-        vocab_size=args.vocab_size,
-        dataset_name=args.dataset_name,
+
+    print(f"Loading dataset '{args.in_repo_id}'...")
+    in_dataset_split = load_dataset(
+        args.in_repo_id,
         split=args.split,
-        column=args.column,
-        repo_id=args.repo_id,
-        hf_token=args.hf_token,
+        features=Features({args.feature: Value("string")}),
+    )
+    assert isinstance(in_dataset_split, Dataset)
+    tokenizer = train_byte_level_bpe(
+        dataset=in_dataset_split,
+        feature=args.feature,
+        vocab_size=args.vocab_size,
+    )
+    tokenizer.push_to_hub(
+        repo_id=args.out_repo_id,
+        token=args.hf_token,
     )
diff --git a/src/delphi/interp/__init__.py b/src/delphi/interp/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/delphi/train/tokenizer.py b/src/delphi/train/tokenizer.py
deleted file mode 100644
index 190ccdee..00000000
--- a/src/delphi/train/tokenizer.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import io
-import os
-import tempfile
-from typing import cast
-
-from datasets import Dataset, load_dataset
-from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
-from tokenizers import SentencePieceBPETokenizer  # type: ignore
-from transformers import LlamaTokenizerFast
-
-
-def hf_dataset_to_text(
-    dataset_name: str, split: str, column: str, text_file: io.TextIOBase
-):
-    dataset = cast(Dataset, load_dataset(dataset_name, split=split))
-    for text in dataset[column]:
-        text = text.strip()
-        text_file.write(text + "\n")
-
-
-def train_sentence_piece(
-    vocab_size: int,
-    sentence_iterator: io.TextIOBase,
-) -> SentencePieceProcessor:
-    """Trains a custom SentencePiece tokenizer."""
-    model = io.BytesIO()
-    SentencePieceTrainer.train(  # type: ignore
-        sentence_iterator=sentence_iterator,
-        model_writer=model,
-        model_type="bpe",
-        vocab_size=vocab_size,
-        self_test_sample_size=0,
-        character_coverage=1.0,
-        num_threads=os.cpu_count(),
-        split_digits=True,
-        allow_whitespace_only_pieces=True,
-        byte_fallback=True,
-        unk_surface=r" \342\201\207 ",
-        normalization_rule_name="identity",
-    )
-    return SentencePieceProcessor(model_proto=model.getvalue())  # type: ignore
-
-
-def sp_processor_to_hf_bpe_tokenizer(
-    sp_processor: SentencePieceProcessor,
-) -> SentencePieceBPETokenizer:
-    """Converts a SentencePieceProcessor to a SentencePieceBPETokenizer."""
-    vocab = {
-        sp_processor.id_to_piece(index): index  # type: ignore
-        for index in range(sp_processor.GetPieceSize())
-    }
-    merges = []
-    for piece_l in vocab.keys():
-        for piece_r in vocab.keys():
-            merge = f"{piece_l}{piece_r}"
-            piece_id = vocab.get(merge, None)
-            if piece_id:
-                merges += [(piece_l, piece_r, piece_id)]
-    merges = sorted(merges, key=lambda val: val[2])
-    merges = [(val[0], val[1]) for val in merges]
-
-    return SentencePieceBPETokenizer(vocab, merges)
-
-
-def hf_bpe_tokenizer_to_llama_tokenizer(
-    hf_bpe_tokenizer: SentencePieceBPETokenizer,
-) -> LlamaTokenizerFast:
-    with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as tmp_json_file:
-        hf_bpe_tokenizer.save(tmp_json_file.name)
-        return LlamaTokenizerFast(
-            tokenizer_file=tmp_json_file.name,
-            unk_token="<unk>",
-            unk_token_id=0,
-            bos_token="<s>",
-            bos_token_id=1,
-            eos_token="</s>",
-            eos_token_id=2,
-            pad_token="<pad>",
-            pad_token_id=3,
-            padding_side="right",
-        )
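
For context (not part of the diff): a minimal sketch of how the tokenizer produced by the new scripts/train_tokenizer.py could be sanity-checked once it has been pushed to the Hub. The repo id below is a placeholder standing in for whatever --out-repo-id is passed; it is not defined in this change.

```python
# Illustrative only: "delphi-suite/example-tokenizer" is a hypothetical repo id,
# not one created by this diff.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/example-tokenizer")
ids = tokenizer.encode("Once upon a time")  # encode text to token ids
print(ids)
print(tokenizer.decode(ids))  # byte-level BPE should round-trip the input exactly
```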