diff --git a/scripts/tokenize_dataset.py b/scripts/tokenize_dataset.py index c0c6d13f..8c326278 100755 --- a/scripts/tokenize_dataset.py +++ b/scripts/tokenize_dataset.py @@ -4,7 +4,7 @@ import os from pathlib import Path -from datasets import Dataset, Features, Value, load_dataset +from datasets import Dataset from huggingface_hub import HfApi from transformers import AutoTokenizer @@ -81,10 +81,8 @@ ), "You need to provide --out-repo-id or --out-dir" print(f"Loading dataset '{args.in_repo_id}'...") - in_dataset_split = load_dataset( - args.in_repo_id, - split=args.split, - features=Features({args.feature: Value("string")}), + in_dataset_split = utils.load_dataset_split_string_feature( + args.in_repo_id, args.split, args.feature ) assert isinstance(in_dataset_split, Dataset) print(f"Loading tokenizer from '{args.tokenizer}'...") diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py index 83e071ae..bc9fa10a 100755 --- a/scripts/train_tokenizer.py +++ b/scripts/train_tokenizer.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 import argparse -from datasets import Dataset, Features, Value, load_dataset +from datasets import Dataset, Features, Value from tokenizers import ByteLevelBPETokenizer # type: ignore from transformers import PreTrainedTokenizerFast +from delphi import utils + def train_byte_level_bpe( dataset: Dataset, feature: str, vocab_size: int @@ -75,10 +77,8 @@ def train_byte_level_bpe( ), "You need to provide out_repo_id or out_dir" print(f"Loading dataset '{args.in_repo_id}'...") - in_dataset_split = load_dataset( - args.in_repo_id, - split=args.split, - features=Features({args.feature: Value("string")}), + in_dataset_split = utils.load_dataset_split_string_feature( + args.repo_id, args.split, args.feature ) assert isinstance(in_dataset_split, Dataset) tokenizer = train_byte_level_bpe( diff --git a/src/delphi/train/config/dataset_config.py b/src/delphi/train/config/dataset_config.py index 0c1e356e..bc2d9892 100644 --- a/src/delphi/train/config/dataset_config.py +++ b/src/delphi/train/config/dataset_config.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, field -from typing import cast -import datasets from beartype import beartype -from datasets import Dataset, load_dataset +from datasets import Dataset + +from delphi import utils @beartype @@ -28,14 +28,9 @@ class DatasetConfig: ) def _load(self, split) -> Dataset: - ds = load_dataset( - self.name, - split=split, - features=datasets.Features( - {self.feature: datasets.Sequence(datasets.Value("int32"))} - ), + ds = utils.load_dataset_split_sequence_int32_feature( + self.name, split, self.feature ) - ds = cast(Dataset, ds) ds.set_format("torch") return ds