Better timeouts #129

Merged · 10 commits · Jan 27, 2025
21 changes: 17 additions & 4 deletions fast_llm/core/distributed.py
@@ -7,6 +7,7 @@
"""

import contextlib
import datetime
import logging
import typing

@@ -25,12 +26,21 @@
logger = logging.getLogger(__name__)


def broadcast(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False) -> Work | None:
def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
if group is not None and timeout is not None:
# TODO: Only works for nccl?
group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout))


def broadcast(
tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None
) -> Work | None:
"""Same as torch.distributed.broadcast, but without the complication of going through the global rank."""
assert group is not None
opts = torch.distributed.BroadcastOptions()
opts.rootRank = src
opts.rootTensor = 0
add_ephemeral_timeout(group, timeout)
work = group.broadcast([tensor], opts)
if async_op:
return work
@@ -53,10 +63,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(group: ProcessGroup | None, value: int | str = 1) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
out = allreduce_scalar(hashed, dtype=torch.int64, group=group)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")

@@ -66,9 +76,11 @@ def allreduce_scalar(
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
op=ReduceOp.SUM,
timeout: float | None = None,
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
add_ephemeral_timeout(group, timeout)
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
else:
@@ -80,13 +92,14 @@ def broadcast_scalar(
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
if group.rank() == src:
tensor.fill_(value)
broadcast(tensor, src, group)
broadcast(tensor, src, group, timeout=timeout)
return tensor.item()


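A minimal usage sketch for the new timeout arguments (illustrative call sites, not code from this PR): the caller passes a per-operation timeout in seconds, add_ephemeral_timeout converts it to a datetime.timedelta, and the group's timeout is temporarily extended for the collective(s) that follow (per the TODO, this may be NCCL-specific).

# Sketch only: guard a potentially slow synchronization point with the
# new per-call timeouts instead of the process group's default.
from fast_llm.core.distributed import broadcast_scalar, safe_barrier

def sync_after_setup(group, value: int) -> int:
    # Rank 0 may take a while to produce `value`; allow 5 minutes for
    # this broadcast only.
    result = broadcast_scalar(value, group=group, src=0, timeout=300.0)
    # Same grace period for the barrier; raises on desync or timeout
    # instead of hanging indefinitely.
    safe_barrier(group, "after_setup", timeout=300.0)
    return int(result)
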
1 change: 1 addition & 0 deletions fast_llm/data/data/abstract.py
@@ -26,6 +26,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
self._distributed = distributed
self._samples_per_phase = samples_per_phase
4 changes: 4 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -6,6 +6,7 @@
import torch
import torch.utils.data

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
@@ -53,6 +54,7 @@ def setup(
distributed: "Distributed",
samples_per_phase: dict[PhaseType, int],
cache_directory: pathlib.Path,
timeout: float | None = None,
) -> None:
"""
Load the datasets, and prepare or load the samplings.
@@ -83,6 +85,8 @@
)
dataset = self._config.datasets[phase].build_and_sample(sampling_config)
self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True

@property
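Downstream, the trainer can forward a timeout into data setup so that a rank stuck preparing datasets surfaces as an error on the "data_preparation" barrier rather than a silent hang. A rough sketch of such a call (the surrounding objects and the cache path are assumptions):

# Sketch only: `data`, `distributed` and `samples_per_phase` come from the
# surrounding training setup and are assumptions here.
import pathlib

data.setup(
    distributed,
    samples_per_phase,
    cache_directory=pathlib.Path("run/dataset_cache"),
    timeout=600.0,  # give dataset preparation up to 10 minutes
)
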
23 changes: 8 additions & 15 deletions fast_llm/data/dataset/blended.py
@@ -4,10 +4,8 @@

import numpy as np

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingConfig
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert, normalize_probabilities

try:
@@ -44,7 +42,7 @@ def __init__(

if sampling_config.cache_directory is None:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = self._build_blending_indices(len(self._datasets) <= 20)
self._dataset_index, self._sample_index = self._build_blending_indices()
else:
group = sampling_config.distributed.world_group
self._dataset_idx_filename = sampling_config.cache_directory / (self._name + "_blending_dataset_idx.npy")
@@ -55,14 +53,11 @@
if (group is None or group.rank() == 0) and not (
self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
):
dataset_index, sample_index = self._build_blending_indices(len(self._datasets) <= 20)
dataset_index, sample_index = self._build_blending_indices()
sampling_config.cache_directory.mkdir(exist_ok=True, parents=True)
np.save(self._dataset_idx_filename, dataset_index)
np.save(self._sample_idx_filename, sample_index)

safe_barrier(group, self._name)
self._load_mappings(True)

def __getstate__(self) -> tuple[typing.Any, ...]:
return (
self._datasets,
Expand All @@ -84,23 +79,20 @@ def __setstate__(self, state: tuple[typing.Any, ...]):
) = state
if isinstance(dataset_index, pathlib.Path):
self._dataset_idx_filename, self._sample_idx_filename = dataset_index, sample_index
self._load_mappings(False)
else:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = dataset_index, sample_index

def _load_mappings(self, verbose: bool) -> None:
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._dataset_idx_filename}")
def _load_mappings(self) -> None:
if hasattr(self, "_dataset_index") and hasattr(self, "_sample_index"):
return
self._dataset_index = np.load(self._dataset_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" > loading blending dataset index mapping from {self._sample_idx_filename}")
self._sample_index = np.load(self._sample_idx_filename, mmap_mode="r")

def __len__(self) -> int:
return self._num_samples

def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
def _build_blending_indices(self) -> tuple[np.ndarray, np.ndarray]:
assert _extension_available, (
"The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
)
@@ -113,7 +105,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
self._weights,
len(self._datasets),
self._num_samples,
verbose,
True, # Verbose
)
available_samples_per_dataset = np.array([len(dataset) for dataset in self._datasets])
sampled_per_dataset = np.bincount(dataset_index)
@@ -133,6 +125,7 @@ def _build_blending_indices(self, verbose: bool) -> tuple[np.ndarray, np.ndarray]:
return dataset_index, dataset_sample_index

def __getitem__(self, idx: int) -> typing.Any:
self._load_mappings()
return self._datasets[self._dataset_index[idx]][self._sample_index[idx].item()]

@property
25 changes: 24 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -2,6 +2,7 @@
import enum
import json
import pathlib
import time
import typing

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
@@ -21,7 +22,7 @@
if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset
from fast_llm.data.dataset.gpt.random import GPTRandomDataset, GPTRandomSampledDataset
from fast_llm.data.tokenizer import Tokenizer


@@ -364,3 +365,25 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset:
)

return dataset_config.build_and_sample(config)


@config_class()
class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
Collaborator: Why can't this be moved to a module in the tests?

Collaborator (author): Because it has to be run in a separate process (for distributed), which doesn't import the test file.

Collaborator: I see, so you're calling the fast-llm CLI in a subprocess in the tests.

"""
A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
"""

# TODO: This belongs to a testing plugin.
_abstract: typing.ClassVar[bool] = False
type_: typing.ClassVar[str | None] = "test_slow"
sleep: float = Field(
default=1,
desc="Sleep time during build, in seconds.",
hint=FieldHint.core,
)

def build_and_sample(self, config: SamplingConfig) -> "GPTRandomSampledDataset":
assert config.distributed.config.world_size > 1
if config.distributed.config.rank == 0:
time.sleep(self.sleep)
return GPTRandomDatasetConfig().build_and_sample(config)
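
A rough sketch of how a distributed test could exercise the timeout path with this config (the config keys and the test helper below are assumptions, not code from this PR; as discussed above, the real tests launch the fast-llm CLI in a subprocess):

import pytest

# Hypothetical test sketch: rank 0 sleeps for 20 s while the data-preparation
# timeout is only 5 s, so setup should fail fast instead of hanging.
SLOW_DATASET = {"type": "test_slow", "sleep": 20}

def test_data_preparation_timeout(run_distributed_training):
    # `run_distributed_training` is an assumed helper that launches the
    # fast-llm CLI in a subprocess with the given dataset config and timeout.
    with pytest.raises(Exception):  # exact error type/message depends on the runner
        run_distributed_training(
            dataset_config=SLOW_DATASET,
            data_timeout=5,  # seconds; much shorter than the 20 s sleep on rank 0
            world_size=2,    # GPTTestSlowDatasetConfig asserts world_size > 1
        )
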
59 changes: 23 additions & 36 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -5,7 +5,6 @@

import numpy as np

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
@@ -37,59 +36,56 @@ def __init__(
):
assert isinstance(sampling_config, GPTSamplingConfig)
self._indexed_dataset = indexed_dataset

group = sampling_config.distributed.world_group
self._num_samples = sampling_config.num_samples
self._sequence_length = sampling_config.sequence_length
self._seed = sampling_config.seed

if sampling_config.cache_directory is None:
log_main_rank(
" > No dataset cache directory provided, building the index map on all ranks."
"This may be very inefficient...",
log_fn=logger.warning,
)
self._doc_idx, self._sample_idx, self._shuffle_idx = self._sample(sampling_config)
self._doc_idx, self._sample_idx, self._shuffle_idx = self._sample()
else:
cache_prefix = (
f"{self.name}_ns_{sampling_config.num_samples}_sl_{sampling_config.sequence_length}"
f"_s_{sampling_config.seed}"
)
cache_prefix = f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" f"_s_{self._seed}"
# TODO: Any way to combine into a single file? (Memmap is harder)
self._doc_idx_filename = sampling_config.cache_directory / (cache_prefix + "_doc_idx.npy")
self._sample_idx_filename = sampling_config.cache_directory / (cache_prefix + "_sample_idx.npy")
self._shuffle_idx_filename = sampling_config.cache_directory / (cache_prefix + "_shuffle_idx.npy")

# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
if (
sampling_config.distributed.world_group is None or sampling_config.distributed.world_group.rank() == 0
) and not (
self._doc_idx_filename.is_file()
and self._sample_idx_filename.is_file()
and self._shuffle_idx_filename.is_file()
):
log_main_rank(" > Building the index map on rank 0 ...")
doc_idx, sample_idx, shuffle_idx = self._sample(sampling_config)
doc_idx, sample_idx, shuffle_idx = self._sample()
sampling_config.cache_directory.mkdir(parents=True, exist_ok=True)
np.save(self._doc_idx_filename, doc_idx)
np.save(self._sample_idx_filename, sample_idx)
np.save(self._shuffle_idx_filename, shuffle_idx)

safe_barrier(group, self._indexed_dataset.name)
self._load_mappings(True)

def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _sample(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Create a `GPTSampledDataset` with the requested parameters.
"""
document_sizes = self._indexed_dataset.get_document_sizes()
num_documents = len(document_sizes)
num_tokens = document_sizes.sum()
np_rng = np.random.RandomState(seed=sampling_config.seed)
np_rng = np.random.RandomState(seed=self._seed)

num_epochs = math.ceil((sampling_config.sequence_length * sampling_config.num_samples + 1) / num_tokens)
num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / num_tokens)
# For the last epoch, decide whether include the entire epoch
# in the global shuffle or not.
# Get the number of samples for the last epoch
main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sampling_config.sequence_length
last_epoch_samples = sampling_config.num_samples - main_epochs_samples
samples_per_epoch = (num_tokens - 1) // sampling_config.sequence_length
main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._sequence_length
last_epoch_samples = self._num_samples - main_epochs_samples
samples_per_epoch = (num_tokens - 1) // self._sequence_length
# If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently.
# Note: the 80% number is just based on common sense and can be adjusted if needed.
separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch
@@ -108,7 +104,7 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
sample_idx = build_sample_idx(
document_sizes,
doc_idx,
sampling_config.sequence_length,
self._sequence_length,
num_epochs,
num_tokens,
True,
@@ -128,9 +124,9 @@ def _sample(self, sampling_config: GPTSamplingConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
else:
np_rng.shuffle(shuffle_idx)

Assert.geq(len(shuffle_idx), sampling_config.num_samples)
Assert.geq(len(shuffle_idx), self._num_samples)
# TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch.
return doc_idx, sample_idx, shuffle_idx[: sampling_config.num_samples]
return doc_idx, sample_idx, shuffle_idx[: self._num_samples]

def __getstate__(
self,
@@ -165,34 +161,25 @@ def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Pat
self._sample_idx,
self._shuffle_idx,
) = state
self._load_mappings(False)

def _load_mappings(self, verbose: bool) -> None:
if hasattr(self, "_doc_idx"):
def _load_mappings(self) -> None:
if hasattr(self, "_doc_idx") and hasattr(self, "_sample_idx") and hasattr(self, "_shuffle_idx"):
return
if verbose:
log_main_rank(lambda: f" > loading doc-idx mapping from {self._doc_idx_filename}")
self._doc_idx = np.load(self._doc_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" > loading sample-idx mapping from {self._sample_idx_filename}")
self._sample_idx = np.load(self._sample_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" > loading shuffle-idx mapping from {self._shuffle_idx_filename}")
self._shuffle_idx = np.load(self._shuffle_idx_filename, mmap_mode="r")
if verbose:
log_main_rank(lambda: f" loaded dataset with {len(self)} samples.")

def __len__(self) -> int:
# -1 is due to data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self._shuffle_idx.shape[0]
return self._num_samples

def __getitem__(self, idx: int) -> typing.Any:
"""
Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents)
with the requested sampling index.
The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
"""
# Lazy load indexes.
self._load_mappings()
# Get the shuffled index.
shuffled_idx = self._shuffle_idx[idx]
# Start and end documents and offsets.
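
The net effect of dropping the eager _load_mappings call: constructing or unpickling a GPTSampledIndexedDataset no longer touches the cached .npy files, __len__ now reports the configured number of samples, and the index maps are memory-mapped on first access. A minimal sketch of the behaviour a consumer sees (the dataset instance below is assumed to have been built with a cache directory):

import pickle

# Sketch only: `sampled_dataset` is an existing GPTSampledIndexedDataset
# built with a cache directory.
payload = pickle.dumps(sampled_dataset)  # the pickled state carries the .npy paths, not the arrays
worker_copy = pickle.loads(payload)      # no index files are opened here
# The first access triggers _load_mappings(), which memory-maps the
# doc/sample/shuffle index files with np.load(..., mmap_mode="r").
sample = worker_copy[0]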