Commit

use load_dataset_split_* utils
jettjaniak committed May 22, 2024
1 parent 3e2af53 commit 784c8f2
Showing 3 changed files with 13 additions and 20 deletions.
8 changes: 3 additions & 5 deletions scripts/tokenize_dataset.py
@@ -4,7 +4,7 @@
 import os
 from pathlib import Path

-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset
 from huggingface_hub import HfApi
 from transformers import AutoTokenizer

@@ -81,10 +81,8 @@
), "You need to provide --out-repo-id or --out-dir"

print(f"Loading dataset '{args.in_repo_id}'...")
in_dataset_split = load_dataset(
args.in_repo_id,
split=args.split,
features=Features({args.feature: Value("string")}),
in_dataset_split = utils.load_dataset_split_string_feature(
args.in_repo_id, args.split, args.feature
)
assert isinstance(in_dataset_split, Dataset)
print(f"Loading tokenizer from '{args.tokenizer}'...")
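For reference, the new helper itself is not part of this diff. Judging from the inline code it replaces here (and in train_tokenizer.py below), a minimal sketch of delphi.utils.load_dataset_split_string_feature could look like the following; the name matches the call sites, but the implementation details are assumed:

    # Sketch only: reconstructed from the load_dataset() call removed above;
    # not the actual delphi.utils source.
    from datasets import Dataset, Features, Value, load_dataset


    def load_dataset_split_string_feature(
        repo_id: str, split: str, feature_name: str
    ) -> Dataset:
        # Load one split, constraining the schema to a single string column.
        ds = load_dataset(
            repo_id,
            split=split,
            features=Features({feature_name: Value("string")}),
        )
        # load_dataset() is typed as returning a union of container types;
        # narrow it to Dataset for type checkers.
        assert isinstance(ds, Dataset)
        return ds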
10 changes: 5 additions & 5 deletions scripts/train_tokenizer.py
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3
 import argparse

-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset, Features, Value
 from tokenizers import ByteLevelBPETokenizer  # type: ignore
 from transformers import PreTrainedTokenizerFast

+from delphi import utils
+

 def train_byte_level_bpe(
     dataset: Dataset, feature: str, vocab_size: int
@@ -75,10 +77,8 @@ def train_byte_level_bpe(
), "You need to provide out_repo_id or out_dir"

print(f"Loading dataset '{args.in_repo_id}'...")
in_dataset_split = load_dataset(
args.in_repo_id,
split=args.split,
features=Features({args.feature: Value("string")}),
in_dataset_split = utils.load_dataset_split_string_feature(
args.repo_id, args.split, args.feature
)
assert isinstance(in_dataset_split, Dataset)
tokenizer = train_byte_level_bpe(
15 changes: 5 additions & 10 deletions src/delphi/train/config/dataset_config.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass, field
-from typing import cast

-import datasets
 from beartype import beartype
-from datasets import Dataset, load_dataset
+from datasets import Dataset
+
+from delphi import utils


 @beartype
@@ -28,14 +28,9 @@ class DatasetConfig:
     )

     def _load(self, split) -> Dataset:
-        ds = load_dataset(
-            self.name,
-            split=split,
-            features=datasets.Features(
-                {self.feature: datasets.Sequence(datasets.Value("int32"))}
-            ),
+        ds = utils.load_dataset_split_sequence_int32_feature(
+            self.name, split, self.feature
         )
-        ds = cast(Dataset, ds)
         ds.set_format("torch")
         return ds

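Likewise, a plausible sketch of delphi.utils.load_dataset_split_sequence_int32_feature, reconstructed from the inline call removed above (implementation details assumed):

    # Sketch only: reconstructed from the load_dataset() call removed above;
    # not the actual delphi.utils source.
    from datasets import Dataset, Features, Sequence, Value, load_dataset


    def load_dataset_split_sequence_int32_feature(
        repo_id: str, split: str, feature_name: str
    ) -> Dataset:
        # Load one split whose named column is a sequence of int32 values
        # (e.g. token ids).
        ds = load_dataset(
            repo_id,
            split=split,
            features=Features({feature_name: Sequence(Value("int32"))}),
        )
        # Returning a narrowed Dataset is what lets _load drop its cast.
        assert isinstance(ds, Dataset)
        return ds

If the helper narrows its return type like this, the cast(Dataset, ds) deleted above becomes redundant, while the caller keeps ds.set_format("torch") so retrieved columns come back as PyTorch tensors.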
