diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index d69e79a..1f4711a 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -249,6 +249,9 @@ class TrainingConfig(Config):
     loss_tracker_params: LossTrackerParameters = dataclasses.field(default_factory=LossTrackerParameters)
     """Arguments of the loss tracker."""
 
+    automatic_mixed_precision: bool = False
+    """Automatic mixed precision and gradient scaling are enabled if True."""
+
 
 @dataclasses.dataclass
 class PretrainingConfig(TrainingConfig):
diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index 5ba9885..d79a7aa 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -273,6 +273,8 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.loss_criterion = self.configs.loss_function
         self.iterations_per_epoch = 0
         self.current_epoch = 0
+        self.use_torch_amp = False
+        self.grad_scaler = None
         self.gatekeeper = MultirankGateKeeper(0, 1)
 
         self.debug = self.configs.debug
@@ -383,6 +385,7 @@ def build(self):
         self.build_optimizer()
         self.build_scheduler()
         self.load_state_checkpoint()
+        self.build_amp()
 
         self.build_dir()
 
@@ -607,8 +610,9 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs):
         :return: loss_buffer, total_loss_tensor, preds, labels
         """
         data, labels = self.process_data_loader_yield(data_and_labels)
-        preds = self.model(*data)
-        losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
+        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_torch_amp):
+            preds = self.model(*data)
+            losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
         return loss_buffer, total_loss_tensor, preds, labels
 
     def run_training_epoch(self):
@@ -677,8 +681,9 @@ def run_validation(self):
 
     def run_model_update_step(self, loss_node):
         self.optimizer.zero_grad()
-        loss_node.backward()
-        self.optimizer.step()
+        self.grad_scaler.scale(loss_node).backward()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
 
     def get_model_object(self):
         if isinstance(self.model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
@@ -728,6 +733,17 @@ def build_scheduler(self):
                                                            max_lr=self.learning_rate, step_size_up=step_size,
                                                            cycle_momentum=False, mode='triangular2')
 
+    def build_amp(self):
+        # Do not use torch.autocast and torch.GradScaler() in trainers using other backends like HuggingFace
+        # Accelerate or PyTorch Lightning. These backends have their own AMP routines.
+        self.use_torch_amp = False
+        if (self.configs.automatic_mixed_precision and
+                self.__class__ not in [HuggingFaceAccelerateTrainer,
+                                       HuggingFaceAcceleratePretrainer,
+                                       PyTorchLightningTrainer]):
+            self.use_torch_amp = True
+        self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.use_torch_amp)
+
     def build_model(self):
         self.model_class_handle = self.configs.model_class
         self.model = self.configs.model_class(**self.configs.model_params.__dict__)
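
Reviewer note (not part of the diff): the snippet below is a self-contained sketch of the torch.autocast / GradScaler pattern this change wires into the trainer; the model, optimizer, loss, and data names are placeholders, not code from this repository. It also illustrates why run_model_update_step can call the scaler unconditionally: with enabled=False, GradScaler.scale()/step()/update() reduce to a plain backward() followed by optimizer.step().

    # Illustration only; mirrors the pattern added in load_data_and_get_loss and
    # run_model_update_step. None of these names come from the repository.
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'  # stands in for configs.automatic_mixed_precision

    model = torch.nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # behaves as a no-op when disabled

    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 4, device=device)

    for _ in range(3):
        optimizer.zero_grad()
        # Forward pass and loss run in float16 where safe; when disabled, the
        # context manager does nothing and everything stays in float32.
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            preds = model(x)
            loss = loss_fn(preds, y)
        # Scaled backward pass; with enabled=False these three calls behave like
        # loss.backward() followed by optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()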