Dataset tests (#121)
jlamypoirier authored Jan 22, 2025
1 parent 5ba311c commit 7660ba4
Showing 11 changed files with 516 additions and 94 deletions.
3 changes: 2 additions & 1 deletion fast_llm/data/config.py
@@ -1,4 +1,5 @@
import enum
import pathlib

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert
@@ -28,7 +29,7 @@ class TokenizerConfig(Config):
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: str | None = Field(
path: pathlib.Path | None = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
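Typing the tokenizer path as `pathlib.Path` lets downstream code use filesystem operations directly instead of converting the string first. A minimal sketch of the difference, independent of the `Field` machinery (the path below is made up for illustration):

```python
import pathlib

# Hypothetical tokenizer location, purely illustrative.
tokenizer_path = pathlib.Path("/data/tokenizers/tokenizer.json")

# Path objects compose with filesystem helpers without manual wrapping.
assert tokenizer_path.suffix == ".json"
if not tokenizer_path.is_file():
    raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
```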
20 changes: 8 additions & 12 deletions fast_llm/data/data/gpt/data.py
@@ -12,11 +12,11 @@
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import CopySplitDataset, PhaseSplits, SampledSplitDataset
from fast_llm.data.dataset.blended import BlendedDataset
from fast_llm.data.dataset.gpt.config import DatasetSource, GPTSamplingConfig
from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset
from fast_llm.data.dataset.gpt.fim import FimDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, LegacyDatasetSource
from fast_llm.data.dataset.gpt.fim import GPTFimDataset
from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
@@ -63,7 +63,7 @@ def __init__(
}

data_base_path = None
if self._config.format == DatasetSource.file:
if self._config.format == LegacyDatasetSource.file:
Assert.eq(len(self._config.path), 1)
data_path = pathlib.Path(self._config.path[0])
dataset_defs = json.load(data_path.open("r"))
@@ -73,7 +73,7 @@ def __init__(
[dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]
)
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.format == DatasetSource.list:
elif self._config.format == LegacyDatasetSource.list:
Assert.geq(len(self._config.path), 1)
if len(self._config.path) == 1:
dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0]
@@ -83,11 +83,7 @@ def __init__(
assert len(dataset_prefixes) == len(set(dataset_prefixes))
dataset_weights = normalize_probabilities([float(x) for x in self._config.path[::2]])
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.format == DatasetSource.sample:
Assert.eq(len(self._config.path), 1)
dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0]
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset
elif self._config.format == DatasetSource.random:
elif self._config.format == LegacyDatasetSource.random:
Assert.eq(len(self._config.path), 0)
dataset_prefixes, dataset_weights = [None], [1.0]
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset
@@ -245,7 +241,7 @@ def _build_and_sample_gpt_dataset(self, name: str, sampling_configs: PhaseSplits
datasets = SampledSplitDataset[GPTDatasetSlice](
"fim",
{
phase: FimDataset(self.config.fim, dataset, sampling_configs[phase])
phase: GPTFimDataset(self.config.fim, dataset, sampling_configs[phase])
for phase, dataset in datasets.items()
},
)
@@ -254,6 +250,6 @@ def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits
def _build_and_sample_dummy_dataset(self, name: str, sampling_configs: PhaseSplits[GPTSamplingConfig]):
return CopySplitDataset(
f"{name}_split",
GPTDummyDataset(name),
GPTRandomDataset(name),
list(sampling_configs),
).sample(sampling_configs)
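The `LegacyDatasetSource.list` branch above interleaves weights and prefixes in a single `path` list, with the weights at the even positions. A small self-contained sketch of that slicing (hypothetical helper name, plain normalization standing in for `normalize_probabilities`):

```python
def parse_legacy_list(path: list[str]) -> tuple[list[str], list[float]]:
    """Split an interleaved legacy dataset list, e.g. ["0.3", "wiki", "0.7", "code"]."""
    if len(path) == 1:
        return [path[0].strip()], [1.0]
    weights = [float(x) for x in path[::2]]      # even positions: weights
    prefixes = [x.strip() for x in path[1::2]]   # odd positions: dataset prefixes
    assert len(prefixes) == len(set(prefixes)), "duplicate dataset prefixes"
    total = sum(weights)
    return prefixes, [w / total for w in weights]


prefixes, weights = parse_legacy_list(["0.3", "wiki/train", "0.7", "code/train"])
assert prefixes == ["wiki/train", "code/train"]
assert abs(sum(weights) - 1.0) < 1e-9
```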
33 changes: 20 additions & 13 deletions fast_llm/data/dataset/gpt/config.py
@@ -68,25 +68,32 @@ class FimConfig(Config):
desc="TODO.",
hint=FieldHint.feature,
)
prefix_token: str = Field(
default="<fim_prefix>",
desc="TODO.",
hint=FieldHint.feature,
)
middle_token: str = Field(
default="<fim_middle>",
desc="TODO.",
hint=FieldHint.feature,
)
pad_token: str = Field(
default="<fim_pad>",
desc="TODO.",
hint=FieldHint.feature,
)
suffix_token: str = Field(
default="<fim_suffix>",
desc="TODO.",
hint=FieldHint.feature,
)

def _validate(self):
super()._validate()
Assert.in_range_incl(self.rate, 0, 1)


class DatasetSource(str, enum.Enum):
"""
An enum for the different ways to load datasets.
TODO: Reduce the diversity?
TODO: Is this specific to GPT data?
"""

list = "list"
file = "file"
sample = "sample"
random = "random"


class LegacyDatasetSource(str, enum.Enum):
"""
An enum for the different ways to load datasets.
10 changes: 3 additions & 7 deletions fast_llm/data/dataset/gpt/fim.py
@@ -4,13 +4,8 @@
from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig
from fast_llm.engine.distributed.config import MAX_SEED

FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_PAD = "<fim_pad>"
FIM_SUFFIX = "<fim_suffix>"


class FimDataset(SampledDataset):
class GPTFimDataset(SampledDataset):
"""
An implementation of FIM (fill in the middle) post-processing of GPT datasets.
Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py
@@ -27,7 +22,8 @@ def __init__(
self._sampling_config = sampling_config
self._tokenizer = sampling_config.tokenizer
self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = (
self._tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
self._tokenizer.vocab[tok]
for tok in [config.suffix_token, config.prefix_token, config.middle_token, config.pad_token]
)
self.fim_split_sample = (
self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None
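With the special tokens now coming from `FimConfig` instead of module-level constants, `GPTFimDataset` only requires that the configured strings exist in the tokenizer vocabulary. The core fill-in-the-middle transformation itself is a permutation of the document; a simplified sketch of the PSM (prefix-suffix-middle) ordering, omitting the rate, truncation and padding handling of the real implementation:

```python
import numpy as np


def apply_fim_psm(
    tokens: np.ndarray, lo: int, hi: int, prefix_id: int, middle_id: int, suffix_id: int
) -> np.ndarray:
    """Reorder a document as <prefix> P <suffix> S <middle> M (PSM fill-in-the-middle)."""
    prefix, middle, suffix = tokens[:lo], tokens[lo:hi], tokens[hi:]
    return np.concatenate(
        [[prefix_id], prefix, [suffix_id], suffix, [middle_id], middle]
    ).astype(tokens.dtype)


sample = np.arange(10, dtype=np.int64)
out = apply_fim_psm(sample, lo=3, hi=7, prefix_id=100, middle_id=101, suffix_id=102)
# -> [100, 0, 1, 2, 102, 7, 8, 9, 101, 3, 4, 5, 6]
```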
6 changes: 4 additions & 2 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -28,7 +28,7 @@ def sample(self, config: GPTSamplingConfig) -> "GPTSampledIndexedDataset":
return GPTSampledIndexedDataset(self, config)


class GPTDatasetSlice(DatasetSlice, GPTIndexedDataset):
class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
"""
A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset.
"""
@@ -56,7 +56,9 @@ def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, fl
)


class GPTConcatenatedDataset(ConcatenatedDataset, GPTIndexedDataset):
class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
):
_datasets: list[GPTIndexedDataset]

def get_document_sizes(self) -> np.ndarray:
fast_llm/data/dataset/gpt/dummy.py → fast_llm/data/dataset/gpt/random.py
@@ -4,32 +4,32 @@
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig


class GPTDummyDataset(SamplableDataset):
class GPTRandomDataset(SamplableDataset):
"""
A dummy dataset that always returns the same random sample, for debugging purposes.
"""

def __init__(self, name: str):
self._name = name

def sample(self, config: GPTSamplingConfig) -> "GPTDummySampledDataset":
return GPTDummySampledDataset(f"{self.name}_sampled", config)
def sample(self, config: GPTSamplingConfig) -> "GPTRandomSampledDataset":
return GPTRandomSampledDataset(config, f"{self.name}_sampled")

@property
def name(self) -> str:
return self._name


class GPTDummySampledDataset(SampledDataset):
def __init__(self, name: str, config: GPTSamplingConfig):
class GPTRandomSampledDataset(SampledDataset):
def __init__(self, config: GPTSamplingConfig, name: str):
self._name = name
self._config = config

def __len__(self) -> int:
return self._config.num_samples

def __getitem__(self, idx) -> np.ndarray:
return np.random.RandomState(self._config.seed + 4857643).randint(
return np.random.RandomState(self._config.seed + 48576439 + 74593 * idx).randint(
0, self._config.vocab_size, size=(self._config.sequence_length + 1,), dtype=np.int64
)

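The per-index seed offset is what makes `GPTRandomSampledDataset` return a different, yet reproducible, sample for every index instead of the same sample everywhere. A standalone sketch of that behaviour, reusing the constants from the diff:

```python
import numpy as np


def random_sample(seed: int, idx: int, sequence_length: int, vocab_size: int) -> np.ndarray:
    # Same seeding scheme as above: a fixed offset plus a per-index stride.
    rng = np.random.RandomState(seed + 48576439 + 74593 * idx)
    return rng.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64)


a = random_sample(seed=1234, idx=0, sequence_length=8, vocab_size=100)
b = random_sample(seed=1234, idx=1, sequence_length=8, vocab_size=100)
assert not np.array_equal(a, b)                           # distinct samples per index
assert np.array_equal(a, random_sample(1234, 0, 8, 100))  # reproducible across calls
```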
107 changes: 69 additions & 38 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -1,3 +1,4 @@
import logging
import math
import pathlib
import typing
@@ -18,6 +19,8 @@
except ImportError:
_extension_available = False

logger = logging.getLogger(__name__)


class GPTSampledIndexedDataset(SampledDataset):
"""
@@ -35,33 +38,41 @@ def __init__(
assert isinstance(sampling_config, GPTSamplingConfig)
self._indexed_dataset = indexed_dataset

cache_prefix = (
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
f"_s_{sampling_config.seed}"
)
# TODO: Any way to combine into a single file? (Memmap is harder)
self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")

group = sampling_config.distributed.world_group
# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
self._doc_idx_filename.is_file()
and self._sample_idx_filename.is_file()
and self._shuffle_idx_filename.is_file()
):
if sampling_config.verbose:

if sampling_config.cache_directory is None:
log_main_rank(
" > No dataset cache directory provided, building the index map on all ranks."
"This may be very inefficient...",
log_fn=logger.warning,
)
self._doc_idx, self._sample_idx, self._shuffle_idx = self._sample(sampling_config)
else:
cache_prefix = (
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
f"_s_{sampling_config.seed}"
)
# TODO: Any way to combine into a single file? (Memmap is harder)
self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")

# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
self._doc_idx_filename.is_file()
and self._sample_idx_filename.is_file()
and self._shuffle_idx_filename.is_file()
):
log_main_rank(" > Building the index map on rank 0 ...")
doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
np.save(self._doc_idx_filename, doc_idx)
np.save(self._sample_idx_filename, sample_idx)
np.save(self._shuffle_idx_filename, shuffle_idx)
doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
np.save(self._doc_idx_filename, doc_idx)
np.save(self._sample_idx_filename, sample_idx)
np.save(self._shuffle_idx_filename, shuffle_idx)

safe_barrier(group, self._indexed_dataset.name)
self._load_mappings(sampling_config.verbose)
self._load_mappings(True)

def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
@@ -100,7 +111,7 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
sampling_config.sequence_length,
num_epochs,
num_tokens,
sampling_config.verbose,
True,
)

# shuffle-idx.
@@ -121,24 +132,44 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.nd
# TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch.
return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples]

def __getstate__(self) -> tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]:
return (
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
)
def __getstate__(
self,
) -> tuple[GPTIndexedDataset, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray]:
if hasattr(self, "_doc_idx_filename"):
return (
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
)
else:
return (
self._indexed_dataset,
self._doc_idx,
self._sample_idx,
self._shuffle_idx,
)

def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]) -> None:
(
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
) = state
if isinstance(state[1], pathlib.Path):
(
self._indexed_dataset,
self._doc_idx_filename,
self._sample_idx_filename,
self._shuffle_idx_filename,
) = state
else:
(
self._indexed_dataset,
self._doc_idx,
self._sample_idx,
self._shuffle_idx,
) = state
self._load_mappings(False)

def _load_mappings(self, verbose: bool) -> None:
if hasattr(self, "_doc_idx"):
return
if verbose:
log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}")
self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r")
@@ -169,7 +200,7 @@ def __getitem__(self, idx: int) -> typing.Any:
doc_l, offset_l = self._sample_idx[shuffled_idx + 1]
sample_list = [
self._indexed_dataset.get(
self._doc_idx[doc],
self._doc_idx[doc].item(),
offset=(doc == doc_f) * offset_f,
length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None,
)
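The reworked `__getstate__`/`__setstate__` keep worker processes cheap: when a cache directory exists, only the `.npy` paths are pickled and the arrays are memory-mapped after unpickling, while the no-cache fallback carries the arrays themselves. A condensed sketch of that pattern, reduced to a single index array and a hypothetical class name:

```python
import pathlib

import numpy as np


class LazyIndexMap:
    """Pickle a path when the index is cached on disk, otherwise the array itself."""

    def __init__(self, doc_idx: np.ndarray, cache_file: pathlib.Path | None = None):
        if cache_file is not None:  # expected to end in ".npy"
            np.save(cache_file, doc_idx)
            self._doc_idx_filename = cache_file
        self._doc_idx = doc_idx

    def __getstate__(self):
        if hasattr(self, "_doc_idx_filename"):
            return ("path", self._doc_idx_filename)
        return ("array", self._doc_idx)

    def __setstate__(self, state):
        kind, value = state
        if kind == "path":
            self._doc_idx_filename = value
            self._doc_idx = np.load(value, mmap_mode="r")  # memory-mapped, not copied
        else:
            self._doc_idx = value
```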
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/indexed.py
@@ -24,7 +24,7 @@ def __len__(self) -> int:
"""


class DatasetSlice(IndexedDataset):
class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset):

def __init__(
self,
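`DatasetSlice[IndexedDatasetType: IndexedDataset]` uses the PEP 695 generic-class syntax (Python 3.12+), which lets subclasses such as `GPTDatasetSlice` bind the type parameter to `GPTIndexedDataset` and keep the more specific element type. A minimal, self-contained illustration with toy classes (not the real signatures):

```python
class IndexedDataset:
    def get(self, index: int) -> list[int]:
        raise NotImplementedError


class DatasetSlice[T: IndexedDataset]:  # T is bounded by IndexedDataset
    def __init__(self, dataset: T, begin: int, end: int):
        self._dataset = dataset
        self._begin, self._end = begin, end

    def get(self, index: int) -> list[int]:
        return self._dataset.get(self._begin + index)


class GPTIndexedDataset(IndexedDataset):
    def get(self, index: int) -> list[int]:
        return [index, index + 1]


# A slice of a GPTIndexedDataset is typed in terms of GPTIndexedDataset.
class GPTDatasetSlice[T: GPTIndexedDataset](DatasetSlice[T]):
    pass


slice_ = GPTDatasetSlice(GPTIndexedDataset(), begin=10, end=20)
```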
10 changes: 9 additions & 1 deletion fast_llm/utils.py
@@ -144,6 +144,10 @@ def rms_close(x, y, threshold):
def all_equal(x, y):
import torch

# Make it work for numpy arrays.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

neq = x != y
if neq.any().item(): # noqa
index = torch.where(neq) # noqa
@@ -156,9 +160,13 @@ def all_equal(x, y):
def all_different(x, y):
import torch

# Make it work for numpy arrays.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

eq = x == y
if eq.any().item(): # noqa
index = torch.where(eq) # noqa
index = torch.where(torch.as_tensor(eq)) # noqa
raise AssertionError(
f"Tensors have {index[0].numel()} unexpected matching entries out of "
f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}"
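`torch.as_tensor` accepts NumPy arrays (wrapping them without a copy when the dtype allows), which is what lets `all_equal` and `all_different` be reused on NumPy data, presumably by the new dataset tests. A quick illustration:

```python
import numpy as np
import torch

x = np.array([1, 2, 3])        # NumPy input, as produced by the datasets
y = torch.tensor([1, 2, 4])

neq = torch.as_tensor(x) != y  # elementwise comparison across the two types
print(torch.where(neq))        # (tensor([2]),) -> index of the single mismatch
```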