From 5d651e3c78fe152e10a0fcbf4625fc92043d90e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 16:57:55 +0100
Subject: [PATCH] fix: improve gradient accumulation support

---
 changelog.md                  |  1 +
 edsnlp/training/trainer.py    | 85 +++++++++++++++++++++++------------
 tests/training/qlf_config.yml |  2 +-
 3 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/changelog.md b/changelog.md
index 773bb8ff1..4e2f7f5f5 100644
--- a/changelog.md
+++ b/changelog.md
@@ -19,6 +19,7 @@
   1. reproducibility
   2. in multiprocessing mode, ensure that the same data is shuffled in the same way in all workers
 - Bubble BaseComponent instantiation errors correctly
+- Improved support for multi-GPU gradient accumulation (gradients are only synced at the end of the accumulation), now controlled by the optional `sub_batch_size` argument of `TrainingData`.
 
 ## v0.14.0 (2024-11-14)
 
diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py
index 8abb12612..56ac7f0af 100644
--- a/edsnlp/training/trainer.py
+++ b/edsnlp/training/trainer.py
@@ -13,6 +13,7 @@
     Dict,
     Iterable,
     Optional,
+    Sequence,
     Union,
 )
 
@@ -230,7 +231,7 @@ def __init__(
         data: Stream,
         batch_size: BatchSizeArg,
         shuffle: str,
-        accumulation_batch_size: Optional[BatchSizeArg] = None,
+        sub_batch_size: Optional[BatchSizeArg] = None,
         pipe_names: Optional[Collection[str]] = None,
         post_init: bool = True,
     ):
@@ -256,7 +257,7 @@ def __init__(
             datasets), "fragment" to shuffle the fragment-based datasets like parquet
             files, or a batching expression like "2000 words" to shuffle the dataset
             in chunks of 2000 words.
-        accumulation_batch_size: Optional[BatchSizeArg]
+        sub_batch_size: Optional[BatchSizeArg]
             How to split each batch into sub-batches that will be fed to the model
             independently to accumulate gradients over.
         pipe_names: Optional[Collection[str]]
@@ -269,7 +270,7 @@ def __init__(
         self.data = data
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.accumulation_batch_size = accumulation_batch_size
+        self.sub_batch_size = sub_batch_size
         self.pipe_names = set(pipe_names) if pipe_names else None
         self.post_init = post_init
 
@@ -282,12 +283,12 @@ def __call__(self, nlp, device):
         data = data.map(nlp.preprocess, kwargs=dict(supervision=True))
         batcher = stat_batchify(self.batch_size[1] or "docs")
         data = data.batchify(batch_size=self.batch_size[0], batch_by=batcher)
-        if self.accumulation_batch_size:
-            sub_batcher = stat_batchify(self.accumulation_batch_size[1] or "docs")
+        if self.sub_batch_size:
+            sub_batcher = stat_batchify(self.sub_batch_size[1] or "docs")
             data = data.map(
                 lambda batch: [
-                    nlp.collate(sub_batch)
-                    for sub_batch in sub_batcher(batch, self.accumulation_batch_size[0])
+                    nlp.collate(sub_batch, device=device)
+                    for sub_batch in sub_batcher(batch, self.sub_batch_size[0])
                 ]
             )
         else:
@@ -295,6 +296,27 @@ def __call__(self, nlp, device):
         return data
 
 
+class PipeDict(torch.nn.ModuleDict):
+    def __init__(self, pipes, loss_scales):
+        super().__init__(pipes)
+        self.loss_scales = loss_scales
+
+    def forward(self, batch, enable: Optional[Sequence[str]] = None):
+        loss = None
+        all_results = {}
+        for name, pipe in self.items():
+            if enable is None or name in enable:
+                res = pipe(batch[name])
+                all_results[name] = res
+                if "loss" in res:
+                    res["loss"] = res["loss"] * self.loss_scales.get(name, 1)
+                    loss = res["loss"] if loss is None else loss + res["loss"]
+                    if torch.isnan(loss):
+                        raise ValueError(f"NaN loss at component {name}")
+                    res[f"{name}_loss"] = res["loss"]
+        return all_results, loss
+
+
 @validate_arguments(registry=registry)
 def train(
     *,
@@ -316,7 +338,7 @@
     output_model_dir: Optional[Union[Path, str]] = None,
     save_model: bool = True,
     logger: bool = True,
-    config_meta: Dict,
+    config_meta: Optional[Dict] = None,
     **kwargs,
 ):
     """
@@ -411,14 +433,15 @@
     # accelerator.register_for_checkpointing(dataset)
     is_main_process = accelerator.is_main_process
     device = accelerator.device
-    accelerator.print(config_meta["unresolved_config"].to_yaml_str())
 
     output_dir = Path(output_dir or Path.cwd() / "artifacts")
     output_model_dir = output_model_dir or output_dir / "model-last"
     train_metrics_path = output_dir / "train_metrics.json"
     if is_main_process:
         os.makedirs(output_dir, exist_ok=True)
-        config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")
+        if config_meta is not None:  # pragma: no cover
+            print(config_meta["unresolved_config"].to_yaml_str())
+            config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")
 
     validation_interval = validation_interval or max_steps // 10
     checkpoint_interval = checkpoint_interval or validation_interval
@@ -457,8 +480,8 @@
         nlp.post_init(chain_zip([td.data for td in train_data if td.post_init]))
 
     for phase_i, pipe_names in enumerate(phases):
-        trained_pipes = [nlp.get_pipe(name) for name in pipe_names]
-        trained_pipes_params = {p for pipe in trained_pipes for p in pipe.parameters()}
+        trained_pipes = PipeDict({n: nlp.get_pipe(n) for n in pipe_names}, loss_scales)
+        trained_pipes_params = set(trained_pipes.parameters())
         phase_training_data = [
             td
             for td in train_data
@@ -506,7 +529,7 @@
                 )
             )
         )
-        (accel_optim, *trained_pipes) = accelerator.prepare(optim, *trained_pipes)
+        (accel_optim, trained_pipes) = accelerator.prepare(optim, trained_pipes)
 
         if hasattr(accel_optim.optimizer, "initialize"):
             accel_optim.optimizer.initialize()
@@ -578,17 +601,23 @@
                     set_flat_stats(b, batch_stats)
 
                 res_stats = defaultdict(lambda: 0.0)
-                for batch, batch_pipe_names in zip(batches, batches_pipe_names):
-                    loss = torch.zeros((), device=accelerator.device)
-                    with nlp.cache():
-                        for name, pipe in zip(pipe_names, trained_pipes):
-                            if name not in batch_pipe_names:
-                                continue
-                            res = dict(pipe(batch[name]))
-                            if "loss" in res:
-                                res["loss"] = res["loss"] * loss_scales.get(name, 1)
-                                loss += res["loss"]
-                                res[f"{name}_loss"] = res["loss"]
+                for idx, (batch, batch_pipe_names) in enumerate(
+                    zip(batches, batches_pipe_names)
+                ):
+                    cache_ctx = (
+                        nlp.cache() if len(batch_pipe_names) > 1 else nullcontext()
+                    )
+                    no_sync_ctx = (
+                        accelerator.no_sync(trained_pipes)
+                        if idx < len(batches) - 1
+                        else nullcontext()
+                    )
+                    with cache_ctx, no_sync_ctx:
+                        all_res, loss = trained_pipes(
+                            batch,
+                            enable=batch_pipe_names,
+                        )
+                        for name, res in all_res.items():
                             for k, v in res.items():
                                 if (
                                     isinstance(v, (float, int))
@@ -596,10 +625,10 @@
                                     and v.ndim == 0
                                 ):
                                     res_stats[k] += float(v)
-                            if torch.isnan(loss):
-                                raise ValueError(f"NaN loss at component {name}")
-                            del k, v, res, pipe
-                    accelerator.backward(loss)
+                            del k, v
+                        del res
+                        del all_res
+                        accelerator.backward(loss)
                     del loss
 
                 # Sync output stats after forward such as losses, supports, etc.
diff --git a/tests/training/qlf_config.yml b/tests/training/qlf_config.yml
index 884a8e349..afd6da65e 100644
--- a/tests/training/qlf_config.yml
+++ b/tests/training/qlf_config.yml
@@ -81,7 +81,7 @@
 train_data:
   shuffle: dataset
   batch_size: 4 docs
   pipe_names: [ "qualifier" ]
-  accumulation_batch_size: 10 words
+  sub_batch_size: 10 words
 val_data:
   "@readers": json
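
With this change, gradient accumulation is driven by the data configuration: each `TrainingData` batch of `batch_size` is split by `stat_batchify` into sub-batches of at most `sub_batch_size`, each sub-batch gets its own forward and backward pass, and `accelerator.no_sync` defers gradient synchronization until the last sub-batch of the accumulation. As a sketch, a `train_data` block following the layout of `tests/training/qlf_config.yml` would look like the following (reader and data keys omitted, values are illustrative only):

    train_data:
      shuffle: dataset
      batch_size: 4 docs          # one optimizer step per 4-document batch
      sub_batch_size: 10 words    # forward/backward runs on sub-batches of roughly 10 words each
      pipe_names: [ "qualifier" ]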