
Commit

fix: improve gradient accumulation support
percevalw committed Nov 29, 2024
1 parent c9e2a63 commit 5d651e3
Showing 3 changed files with 59 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -19,6 +19,7 @@
1. reproducibility
2. in multiprocessing mode, ensure that the same data is shuffled in the same way in all workers
- Bubble BaseComponent instantiation errors correctly
- Improved support for multi-GPU gradient accumulation (gradients are only synced at the end of the accumulation), now controlled by the optional `sub_batch_size` argument of `TrainingData` (see the sketch after this file's diff).

## v0.14.0 (2024-11-14)

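For readers skimming the changelog: the sketch below illustrates the general gradient-accumulation pattern the entry above refers to, using plain PyTorch on a toy model. The model, tensors, and `sub_batch_size` value are illustrative placeholders, not edsnlp's `TrainingData`/`Stream` API.

```python
# Illustrative sketch only: plain PyTorch, not edsnlp's training API.
import torch

model = torch.nn.Linear(8, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

# One "logical" batch of 4 samples, split into sub-batches of at most 2 samples,
# mimicking the batch_size / sub_batch_size distinction.
batch = torch.randn(4, 8)
targets = torch.randn(4, 1)
sub_batch_size = 2

optim.zero_grad()
for start in range(0, len(batch), sub_batch_size):
    sub_x = batch[start : start + sub_batch_size]
    sub_y = targets[start : start + sub_batch_size]
    loss = torch.nn.functional.mse_loss(model(sub_x), sub_y)
    # Gradients from each sub-batch accumulate in the .grad buffers.
    # (Real trainers often also scale each sub-batch loss so the total
    # matches the full-batch objective.)
    loss.backward()
optim.step()  # a single optimizer step for the whole logical batch
```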
85 changes: 57 additions & 28 deletions edsnlp/training/trainer.py
@@ -13,6 +13,7 @@
Dict,
Iterable,
Optional,
Sequence,
Union,
)

@@ -230,7 +231,7 @@ def __init__(
data: Stream,
batch_size: BatchSizeArg,
shuffle: str,
accumulation_batch_size: Optional[BatchSizeArg] = None,
sub_batch_size: Optional[BatchSizeArg] = None,
pipe_names: Optional[Collection[str]] = None,
post_init: bool = True,
):
@@ -256,7 +257,7 @@ def __init__(
datasets), "fragment" to shuffle the fragment-based datasets
like parquet files, or a batching expression like "2000 words"
to shuffle the dataset in chunks of 2000 words.
accumulation_batch_size: Optional[BatchSizeArg]
sub_batch_size: Optional[BatchSizeArg]
How to split each batch into sub-batches that will be fed to
the model independently to accumulate gradients over.
pipe_names: Optional[Collection[str]]
@@ -269,7 +270,7 @@ def __init__(
self.data = data
self.batch_size = batch_size
self.shuffle = shuffle
self.accumulation_batch_size = accumulation_batch_size
self.sub_batch_size = sub_batch_size
self.pipe_names = set(pipe_names) if pipe_names else None
self.post_init = post_init

@@ -282,19 +283,40 @@ def __call__(self, nlp, device):
data = data.map(nlp.preprocess, kwargs=dict(supervision=True))
batcher = stat_batchify(self.batch_size[1] or "docs")
data = data.batchify(batch_size=self.batch_size[0], batch_by=batcher)
if self.accumulation_batch_size:
sub_batcher = stat_batchify(self.accumulation_batch_size[1] or "docs")
if self.sub_batch_size:
sub_batcher = stat_batchify(self.sub_batch_size[1] or "docs")
data = data.map(
lambda batch: [
nlp.collate(sub_batch)
for sub_batch in sub_batcher(batch, self.accumulation_batch_size[0])
nlp.collate(sub_batch, device=device)
for sub_batch in sub_batcher(batch, self.sub_batch_size[0])
]
)
else:
data = data.map(nlp.collate, kwargs=dict(device=device))
return data


class PipeDict(torch.nn.ModuleDict):
def __init__(self, pipes, loss_scales):
super().__init__(pipes)
self.loss_scales = loss_scales

def forward(self, batch, enable: Optional[Sequence[str]] = None):
loss = None
all_results = {}
for name, pipe in self.items():
if enable is None or name in enable:
res = pipe(batch[name])
all_results[name] = res
if "loss" in res:
res["loss"] = res["loss"] * self.loss_scales.get(name, 1)
loss = res["loss"] if loss is None else loss + res["loss"]
if torch.isnan(loss):
raise ValueError(f"NaN loss at component {name}")
res[f"{name}_loss"] = res["loss"]
return all_results, loss


@validate_arguments(registry=registry)
def train(
*,
@@ -316,7 +338,7 @@ def train(
output_model_dir: Optional[Union[Path, str]] = None,
save_model: bool = True,
logger: bool = True,
config_meta: Dict,
config_meta: Optional[Dict] = None,
**kwargs,
):
"""
@@ -411,14 +433,15 @@ def train(
# accelerator.register_for_checkpointing(dataset)
is_main_process = accelerator.is_main_process
device = accelerator.device
accelerator.print(config_meta["unresolved_config"].to_yaml_str())

output_dir = Path(output_dir or Path.cwd() / "artifacts")
output_model_dir = output_model_dir or output_dir / "model-last"
train_metrics_path = output_dir / "train_metrics.json"
if is_main_process:
os.makedirs(output_dir, exist_ok=True)
config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")
if config_meta is not None: # pragma: no cover
print(config_meta["unresolved_config"].to_yaml_str())
config_meta["unresolved_config"].to_disk(output_dir / "training_config.yml")

validation_interval = validation_interval or max_steps // 10
checkpoint_interval = checkpoint_interval or validation_interval
@@ -457,8 +480,8 @@ def train(
nlp.post_init(chain_zip([td.data for td in train_data if td.post_init]))

for phase_i, pipe_names in enumerate(phases):
trained_pipes = [nlp.get_pipe(name) for name in pipe_names]
trained_pipes_params = {p for pipe in trained_pipes for p in pipe.parameters()}
trained_pipes = PipeDict({n: nlp.get_pipe(n) for n in pipe_names}, loss_scales)
trained_pipes_params = set(trained_pipes.parameters())
phase_training_data = [
td
for td in train_data
@@ -506,7 +529,7 @@ def train(
)
)
)
(accel_optim, *trained_pipes) = accelerator.prepare(optim, *trained_pipes)
(accel_optim, trained_pipes) = accelerator.prepare(optim, trained_pipes)
if hasattr(accel_optim.optimizer, "initialize"):
accel_optim.optimizer.initialize()

@@ -578,28 +601,34 @@ def train(
set_flat_stats(b, batch_stats)

res_stats = defaultdict(lambda: 0.0)
for batch, batch_pipe_names in zip(batches, batches_pipe_names):
loss = torch.zeros((), device=accelerator.device)
with nlp.cache():
for name, pipe in zip(pipe_names, trained_pipes):
if name not in batch_pipe_names:
continue
res = dict(pipe(batch[name]))
if "loss" in res:
res["loss"] = res["loss"] * loss_scales.get(name, 1)
loss += res["loss"]
res[f"{name}_loss"] = res["loss"]
for idx, (batch, batch_pipe_names) in enumerate(
zip(batches, batches_pipe_names)
):
cache_ctx = (
nlp.cache() if len(batch_pipe_names) > 1 else nullcontext()
)
no_sync_ctx = (
accelerator.no_sync(trained_pipes)
if idx < len(batches) - 1
else nullcontext()
)
with cache_ctx, no_sync_ctx:
all_res, loss = trained_pipes(
batch,
enable=batch_pipe_names,
)
for name, res in all_res.items():
for k, v in res.items():
if (
isinstance(v, (float, int))
or isinstance(v, torch.Tensor)
and v.ndim == 0
):
res_stats[k] += float(v)
if torch.isnan(loss):
raise ValueError(f"NaN loss at component {name}")
del k, v, res, pipe
accelerator.backward(loss)
del k, v
del res
del all_res
accelerator.backward(loss)
del loss

# Sync output stats after forward such as losses, supports, etc.
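The rewritten inner loop above only performs the cross-process gradient synchronization on the last sub-batch, wrapping the earlier backward passes in `accelerator.no_sync(trained_pipes)`. Below is a rough, self-contained sketch of that control flow, assuming a generic `no_sync` context-manager factory (such as `ddp_model.no_sync` or a wrapper around Accelerate's `no_sync`); `nullcontext` stands in here so the snippet runs without any distributed setup.

```python
# Sketch of "sync gradients only on the last sub-batch"; in real multi-GPU
# training, no_sync_factory would be ddp_model.no_sync or
# lambda: accelerator.no_sync(model).
from contextlib import nullcontext

import torch


def accumulate(model, optim, sub_batches, no_sync_factory=nullcontext):
    optim.zero_grad()
    for idx, (x, y) in enumerate(sub_batches):
        is_last = idx == len(sub_batches) - 1
        # Skip the gradient all-reduce for every sub-batch except the last one.
        ctx = nullcontext() if is_last else no_sync_factory()
        with ctx:
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
    optim.step()


model = torch.nn.Linear(8, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
sub_batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(3)]
accumulate(model, optim, sub_batches)
```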
2 changes: 1 addition & 1 deletion tests/training/qlf_config.yml
@@ -81,7 +81,7 @@ train_data:
shuffle: dataset
batch_size: 4 docs
pipe_names: [ "qualifier" ]
accumulation_batch_size: 10 words
sub_batch_size: 10 words

val_data:
"@readers": json
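In the config above, the outer batch is capped at 4 docs while each sub-batch is capped at 10 words, so a single batch can yield an uneven number of sub-batches depending on document lengths. The toy splitter below illustrates that kind of stat-based grouping with a simple greedy rule; it is an illustrative stand-in, not edsnlp's `stat_batchify`.

```python
# Toy greedy splitter: group docs into sub-batches of at most `max_words` words.
# Not edsnlp's stat_batchify; it only illustrates the "10 words" style threshold.
def split_by_words(docs, max_words):
    sub_batches, current, count = [], [], 0
    for doc in docs:
        n = len(doc.split())
        if current and count + n > max_words:
            sub_batches.append(current)
            current, count = [], 0
        current.append(doc)
        count += n
    if current:
        sub_batches.append(current)
    return sub_batches


docs = [
    "patient admitted with fever",
    "no history of diabetes",
    "discharged after three days of observation",
    "follow up in two weeks",
]
print(split_by_words(docs, max_words=10))
# [['patient admitted with fever', 'no history of diabetes'],
#  ['discharged after three days of observation'],
#  ['follow up in two weeks']]
```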
