Commit

Better timeouts
jlamypoirier committed Jan 24, 2025
1 parent 2814e92 commit 3cc977f
Showing 17 changed files with 144 additions and 103 deletions.
12 changes: 11 additions & 1 deletion fast_llm/core/distributed.py
@@ -7,6 +7,7 @@
"""

import contextlib
import datetime
import logging
import typing

@@ -25,6 +26,12 @@
logger = logging.getLogger(__name__)


def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
if timeout is not None:
# TODO: Only works for nccl?
group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout))


def broadcast(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False) -> Work | None:
"""Same as torch.distributed.broadcast, but without the complication of going through the global rank."""
assert group is not None
@@ -53,9 +60,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(group: ProcessGroup | None, value: int | str = 1) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
add_ephemeral_timeout(group, timeout)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")
@@ -80,12 +88,14 @@ def broadcast_scalar(
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
if group.rank() == src:
tensor.fill_(value)
add_ephemeral_timeout(group, timeout)
broadcast(tensor, src, group)
return tensor.item()

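Both new `timeout` arguments feed `add_ephemeral_timeout`, which grants the process group extra time for an upcoming collective (per the TODO, possibly NCCL-only) without changing its default timeout. A minimal usage sketch, assuming an initialized process group; the function name and the timeout values are illustrative, not part of this commit:

import torch

from fast_llm.core.distributed import broadcast_scalar, safe_barrier

def sync_after_slow_setup(group, num_samples: int) -> int:
    # Give this one barrier a longer ephemeral timeout (seconds) so a slow
    # preparation step on one rank does not trip the group's default timeout.
    safe_barrier(group, "dataset_preparation", timeout=600.0)
    # Later collectives can pass their own timeout, or omit it to keep the default.
    return int(broadcast_scalar(num_samples, dtype=torch.int64, group=group, src=0, timeout=60.0))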
22 changes: 2 additions & 20 deletions fast_llm/csrc/data.cpp
@@ -104,9 +104,7 @@ void build_blending_indices(py::array_t<int16_t>& dataset_index,
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch,
const bool verbose) {
const int64_t num_samples) {
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
@@ -115,29 +113,14 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,

// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);

// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();

// Mapping and it's length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];

if (verbose) {
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
}

// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
@@ -152,7 +135,7 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
while (remaining_seq_length > 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
@@ -164,7 +147,6 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
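With this change the caller passes the precomputed `num_samples` directly instead of `num_epochs`, `tokens_per_epoch` and `verbose`. A hedged sketch of the Python-side call; the extension import path and the sample-count formula are assumptions, only the argument order mirrors the new signature:

import numpy as np

from fast_llm.csrc.data import build_sample_idx  # assumed binding location

sizes = np.array([5, 7, 3], dtype=np.int32)      # tokens per document
doc_idx = np.arange(len(sizes), dtype=np.int32)  # document visit order
seq_length = 4
# The caller now owns the sample-count computation (illustrative formula).
num_samples = (int(sizes.sum()) - 1) // seq_length
sample_idx = build_sample_idx(sizes, doc_idx, seq_length, num_samples)  # shape [num_samples + 1, 2]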
1 change: 1 addition & 0 deletions fast_llm/data/data/abstract.py
@@ -26,6 +26,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._samples_per_phase = samples_per_phase
4 changes: 4 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -6,6 +6,7 @@
import torch
import torch.utils.data

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
@@ -53,6 +54,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
Expand Down Expand Up @@ -83,6 +85,8 @@ def setup(
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True

@property
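`setup` now ends with a `safe_barrier` on the world group tagged "data_preparation", so ranks that finish dataset preparation early wait for slow ones using the caller-supplied timeout instead of failing on the next collective. A sketch of a call site, assuming `data`, `distributed`, `samples_per_phase` and `experiment_directory` come from the surrounding trainer setup; the 30-minute value is illustrative:

import pathlib

data.setup(
    distributed=distributed,
    samples_per_phase=samples_per_phase,  # dict[PhaseType, int]
    cache_directory=pathlib.Path(experiment_directory) / "dataset_cache",
    timeout=1800.0,  # seconds granted to the "data_preparation" barrier above
)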
23 changes: 8 additions & 15 deletions fast_llm/data/dataset/blended.py
@@ -4,10 +4,8 @@

import numpy as np

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert, normalize_probabilities

try:
@@ -44,7 +42,7 @@ def __init__(

if sampling_config.cache_directory is None:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = self._build_blending_indices(len(self._datasets) <= 20)
self._dataset_index, self._sample_index = self._build_blending_indices()
else:
group = sampling_config.distributed.world_group
self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy")
@@ -55,14 +53,11 @@
if (group is None or group.rank() == 0) and not (
self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
):
dataset_index, sample_index = self._build_blending_indices(len(self._datasets) <= 20)
dataset_index, sample_index = self._build_blending_indices()
sampling_config.cache_directory.mkdir(exist_ok=True, parents=True)
np.save(self._dataset_idx_filename, dataset_index)
np.save(self._sample_idx_filename, sample_index)

safe_barrier(group, self._name)
self._load_mappings(True)

def __getstate__(self) -> tuple[typing.Any, ...]:
return (
self._datasets,
@@ -84,23 +79,20 @@ def __setstate__(self, state: tuple[typing.Any, ...]):
) = state
if isinstance(dataset_index, pathlib.Path):
self._dataset_idx_filename, self._sample_idx_filename = dataset_index, sample_index
self._load_mappings(False)
else:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = dataset_index, sample_index

def _load_mappings(self, verbose: bool) -> None:
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}")
def _load_mappings(self) -> None:
if hasattr(self, "_dataset_index") and hasattr(self, "_sample_index"):
return
self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._sample_idx_filename}")
self._sample_index = np.load(self._sample_idx_filename, mmap_mode="r")

def __len__(self) -> int:
return self._num_samples

def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
def _build_blending_indices(self) -> tuple[np.ndarray, np.ndarray]:
assert _extension_available, (
"The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
)
@@ -113,7 +105,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray
self._weights,
len(self._datasets),
self._num_samples,
verbose,
True, # Verbose
)
available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets])
sampled_per_dataset = np.bincount(dataset_index)
@@ -133,6 +125,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray
return dataset_index, dataset_sample_index

def __getitem__(self, idx: int) -> typing.Any:
self._load_mappings()
return self._datasets[self._dataset_index[idx]][self._sample_index[idx].item()]

@property
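The blending indices are no longer loaded eagerly in `__init__` (the synchronization barrier now happens once in the data setup above rather than per blended dataset): `_load_mappings` is an idempotent lazy loader called from `__getitem__`, so unpickling stays cheap when the dataset is shipped to dataloader workers. A standalone sketch of that pattern, not Fast-LLM code:

import numpy as np

class LazyIndex:
    """Memory-map an index file on first use instead of at construction time."""

    def __init__(self, path):
        self._path = path  # cheap to pickle into dataloader worker processes

    def _load(self):
        # Idempotent: touch the filesystem only on the first lookup per process.
        if not hasattr(self, "_index"):
            self._index = np.load(self._path, mmap_mode="r")

    def __getitem__(self, idx):
        self._load()
        return self._index[idx]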
25 changes: 24 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -2,6 +2,7 @@
import enum
import json
import pathlib
import time
import typing

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
@@ -21,7 +22,7 @@
if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset, GPTRandomSampledDataset
from fast_llm.data.tokenizer import Tokenizer


@@ -364,3 +365,25 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset:
)

return dataset_config.build_and_sample(config)


@config_class()
class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
"""
A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
"""

# TODO: This belongs to a testing plugin.
_abstract: typing.ClassVar[bool] = False
type_: typing.ClassVar[str | None] = "test_slow"
sleep: float = Field(
default=1,
desc="Sleep time during build, in seconds.",
hint=FieldHint.core,
)

def build_and_sample(self, config: SamplingConfig) -> "GPTRandomSampledDataset":
assert config.distributed.config.world_size > 1
if config.distributed.config.rank == 0:
time.sleep(self.sleep)
return GPTRandomDatasetConfig().build_and_sample(config)
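A hedged sketch of how this mock might be exercised in a multi-rank test; the test harness, the `sampling_config` object and the chosen values are assumptions, not part of this commit:

# On every rank of a world_size > 1 job, with `sampling_config` a distributed
# GPTSamplingConfig provided by the (assumed) test harness:
config = GPTTestSlowDatasetConfig(sleep=10.0)
dataset = config.build_and_sample(sampling_config)  # rank 0 sleeps for 10 s first
# With a data-preparation timeout shorter than `sleep`, the other ranks' barrier
# is expected to time out; with a larger timeout, setup should succeed.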