From 5aaa8c9426a7c9fb7d281b522ad86864dad80431 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 23 Oct 2024 09:46:11 -0700
Subject: [PATCH] Add loss generating tokens for loss accumulation (#3677)

---
 composer/core/data_spec.py     | 19 +++++++-
 composer/trainer/trainer.py    | 16 +++++--
 tests/test_simple_nlp.py       | 84 ++++++++++++++++++++++++++++++++++
 tests/trainer/test_dataspec.py | 37 ++++++++++++++-
 4 files changed, 149 insertions(+), 7 deletions(-)

diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index 35aa94f05e..670011e05b 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import collections.abc
+import logging
 import textwrap
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union
@@ -20,6 +21,8 @@
 
 __all__ = ['DataSpec', 'ensure_data_spec']
 
+log = logging.getLogger(__name__)
+
 
 def _split_list(l, microbatch_size: int):
     if len(l) < microbatch_size:
@@ -185,14 +188,14 @@ def __init__(
         device_transforms: Optional[Callable[[Batch], Batch]] = None,
         split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
         get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
-        get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
+        get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
     ) -> None:
         self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
         self.num_tokens = num_tokens
         self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
         self.split_batch = default_split_batch if split_batch is None else split_batch
         self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
-        self.get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
+        self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
 
         if num_samples is not None:
             self.num_samples = num_samples
@@ -295,6 +298,18 @@ def _default_get_num_tokens_in_batch(self, batch: Batch) -> int:
             return self.dataloader.dataset.max_seq_len * samples_per_batch  # type: ignore
         return 0
 
+    def get_num_tokens_in_batch(self, batch: Batch, token_type: str = 'total') -> int:
+        num_tokens = self._get_num_tokens_in_batch(batch)
+
+        if isinstance(num_tokens, int):
+            return num_tokens
+        elif isinstance(num_tokens, dict):
+            if token_type not in num_tokens:
+                raise ValueError(f'Token type {token_type} not found in num_tokens dict.')
+            return num_tokens[token_type]
+        else:
+            raise ValueError(f'Unexpected return type from get_num_tokens_in_batch: {type(num_tokens)}')
+
 
 def ensure_data_spec(dataloader: Union[DataSpec, Iterable, dict]) -> DataSpec:
     """Ensures that the ``dataloader`` is a :class:`.DataSpec`.
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index db7752f879..3241832c13 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1027,7 +1027,10 @@ class Trainer:
                 it into sections of size ``device_train_microbatch_size``. If the batch size of the dataloader
                 is not divisible by ``device_train_microbatch_size``, the last section will be potentially smaller.
         accumulate_train_batch_on_tokens (bool, optional): Whether training loss is accumulated over the number of tokens in a batch,
-            rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`. (default: ``False``)
+            rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`.
+            Note: If you are using this flag, you can optionally have your `get_num_tokens_in_batch` function return a dictionary
+            with two keys (`total` and `loss_generating`). Composer will then accumulate the batch on loss-generating tokens
+            specifically, while total tokens are still used anywhere else a token count is needed. (default: ``False``)
         seed (int, optional): The seed used in randomization. If ``None``, then a random seed will be created. (default: ``None``)
@@ -3061,11 +3064,13 @@ def _train_microbatches(
 
         # Tracker for gradient accumulation
         if self.accumulate_train_batch_on_tokens:
-            current_batch_size = sum([self._train_data_spec.get_num_tokens_in_batch(b) for b in microbatches])
+            current_batch_size = sum([
+                self._train_data_spec.get_num_tokens_in_batch(b, token_type='loss_generating') for b in microbatches
+            ])
             if current_batch_size == 0:
                 raise ValueError(
                     textwrap.dedent(
-                        'Requested loss accumulation based on number of tokens in training batch, '
+                        'Requested loss accumulation based on number of loss generating tokens in training batch, '
                         'but zero tokens found (perhaps due to an improper DataSpec).',
                     ),
                 )
@@ -3124,7 +3129,10 @@ def _train_microbatch(
 
         device_batch = deepcopy(self.state.batch)
         if self.accumulate_train_batch_on_tokens:
-            microbatch_size = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
+            microbatch_size = self._train_data_spec.get_num_tokens_in_batch(
+                self.state.batch,
+                token_type='loss_generating',
+            )
         else:
             microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
         if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
diff --git a/tests/test_simple_nlp.py b/tests/test_simple_nlp.py
index 5fd107aaa5..db47809012 100644
--- a/tests/test_simple_nlp.py
+++ b/tests/test_simple_nlp.py
@@ -217,3 +217,87 @@ def test_simple_nlp_mlm_token_batch(tiny_bert_tokenizer, device):
     trainer2.fit()
     assert trainer2.state.train_metrics is not None
     assert trainer2.state.train_metrics['LanguageCrossEntropy'].compute() == cross_entropy
+
+
+@device('gpu')
+def test_simple_nlp_mlm_loss_gen_token_batch(tiny_bert_tokenizer, device):
+    transformers = pytest.importorskip('transformers')
+
+    vocab_size = tiny_bert_tokenizer.vocab_size
+    sequence_length = 32
+    size = 96
+    batch_size = 8
+    device = get_device(device)
+
+    train_dataset = RandomTextLMDataset(
+        size=size,
+        vocab_size=vocab_size,
+        sequence_length=sequence_length,
+        use_keys=True,
+        pad_token_id=tiny_bert_tokenizer.pad_token_id,
+    )
+    for i in range(size):  # Proactively load dataset for consistent randomization
+        train_dataset[i]
+    collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer)
+
+    # Get the model's state dict before training starts, so we can reproduce results
+    model = SimpleTransformerMaskedLM(vocab_size=vocab_size)
+    state_dict = model.state_dict()
+
+    # Set up the data spec that can count the non-padding tokens in a batch
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        sampler=dist.get_sampler(train_dataset),
+        collate_fn=collator,
+    )
+    data_spec = DataSpec(
+        dataloader=train_dataloader,
+        get_num_tokens_in_batch=lambda b: (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
+    )
+
+    # Arbitrarily divide num tokens by 2 to simulate loss-generating tokens
+    loss_gen_data_spec = DataSpec(
+        dataloader=train_dataloader,
+        get_num_tokens_in_batch=lambda b: {
+            'total': (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
+            'loss_generating': (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item() // 2,
+        },
+    )
+
+    trainer = Trainer(
+        model=model,
+        seed=42,
+        train_dataloader=data_spec,
+        max_duration='2ep',
+        device_train_microbatch_size=batch_size // 2,
+        accumulate_train_batch_on_tokens=False,
+        device=device,
+    )
+    trainer.fit()
+
+    # Check that there is some train cross entropy
+    assert trainer.state.train_metrics is not None
+    cross_entropy = trainer.state.train_metrics['LanguageCrossEntropy'].compute()
+    assert cross_entropy != 0.0
+
+    # Set up a trainer that accumulates train loss based on token counts, after reloading original state dict
+    model.load_state_dict(state_dict)
+    token_trainer = Trainer(
+        model=model,
+        seed=42,
+        train_dataloader=loss_gen_data_spec,
+        max_duration='2ep',
+        device_train_microbatch_size=batch_size // 2,
+        accumulate_train_batch_on_tokens=True,
+        device=device,
+    )
+    token_trainer.fit()
+
+    # Check that there is some train cross entropy
+    assert token_trainer.state.train_metrics is not None
+    token_cross_entropy = token_trainer.state.train_metrics['LanguageCrossEntropy'].compute()
+    assert token_cross_entropy != 0.0
+
+    # Require that the train cross entropies are different between the trainers
+    assert cross_entropy != token_cross_entropy
diff --git a/tests/trainer/test_dataspec.py b/tests/trainer/test_dataspec.py
index a875a8d315..cf7056cd65 100644
--- a/tests/trainer/test_dataspec.py
+++ b/tests/trainer/test_dataspec.py
@@ -1,7 +1,8 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any
+import contextlib
+from typing import Any, Optional
 
 import pytest
 import torch
@@ -71,6 +72,40 @@ def test_get_num_tokens_hf_default(batch_size: int, sequence_length: int, use_ke
     assert actual == expected
 
 
+@pytest.mark.parametrize(
+    'return_dict,requested_key,expected',
+    [
+        [True, None, 8],  # dict with default key
+        [False, None, 8],  # int with default key
+        [False, 'loss_generating', 8],  # int with non-default key
+        [True, 'loss_generating', 4],  # dict with non-default key
+    ],
+)
+def test_get_num_tokens_types(return_dict: bool, requested_key: Optional[str], expected: Optional[int]):
+    should_error = expected is None
+    error_context = pytest.raises(ValueError) if should_error else contextlib.nullcontext()
+
+    def get_num_tokens_in_batch(batch):
+        num_tokens = 8
+        num_loss_generating_tokens = 4
+
+        if return_dict:
+            return {'total': num_tokens, 'loss_generating': num_loss_generating_tokens}
+
+        return num_tokens
+
+    dataspec = DataSpec(dataloader=[], get_num_tokens_in_batch=get_num_tokens_in_batch)
+
+    batch = {}
+    extra_args = {}
+    if requested_key is not None:
+        extra_args['token_type'] = requested_key
+
+    with error_context:
+        actual = dataspec.get_num_tokens_in_batch(batch, **extra_args)
+        assert actual == expected
+
+
 def test_small_batch_at_end_warning():
     batch_size = 4
     dataset_size = 17
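
Usage sketch: with the changes above, a data spec can report both total and loss-generating token counts, and the trainer can be told to weight microbatch losses by the latter. The snippet below is a minimal sketch, not part of the patch. It assumes an existing ComposerModel `model`, a PyTorch `train_dataloader` whose batches contain `input_ids` and `labels`, a pad token id of 0, and the common `-100` ignore index for label positions that generate no loss; adapt these assumptions to your collator.

    from composer import Trainer
    from composer.core import DataSpec

    PAD_TOKEN_ID = 0      # assumed pad token id; use your tokenizer's value
    IGNORE_INDEX = -100   # assumed label value for positions that generate no loss


    def count_tokens(batch):
        # Return both counts; the trainer uses 'loss_generating' for loss accumulation
        # and 'total' anywhere else a token count is needed.
        return {
            'total': int((batch['input_ids'] != PAD_TOKEN_ID).sum()),
            'loss_generating': int((batch['labels'] != IGNORE_INDEX).sum()),
        }


    data_spec = DataSpec(dataloader=train_dataloader, get_num_tokens_in_batch=count_tokens)

    trainer = Trainer(
        model=model,
        train_dataloader=data_spec,
        max_duration='1ep',
        accumulate_train_batch_on_tokens=True,
    )
    trainer.fit()

    # Equivalent lookups performed internally:
    #   data_spec.get_num_tokens_in_batch(batch)                                 # 'total'
    #   data_spec.get_num_tokens_in_batch(batch, token_type='loss_generating')   # 'loss_generating'

Returning a plain int from `get_num_tokens_in_batch` remains supported; in that case the same count is returned for every `token_type`.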