Automatic mixed precision
mdw771 committed Apr 11, 2024
1 parent 0503b22 commit 279d4f5
Showing 2 changed files with 23 additions and 4 deletions.
3 changes: 3 additions & 0 deletions generic_trainer/configs.py
@@ -249,6 +249,9 @@ class TrainingConfig(Config):
     loss_tracker_params: LossTrackerParameters = dataclasses.field(default_factory=LossTrackerParameters)
     """Arguments of the loss tracker."""
 
+    automatic_mixed_precision: bool = False
+    """Automatic mixed precision and gradient scaling are enabled if True."""
+
 
 @dataclasses.dataclass
 class PretrainingConfig(TrainingConfig):
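For reference, enabling the new option from user code could look like the sketch below. TrainingConfig and build() appear in this commit, but the Trainer class name, the model arguments, and the surrounding calls are illustrative assumptions, not part of the diff.

# Hypothetical usage sketch (not from this commit): turn on AMP via the new config field.
configs = TrainingConfig(
    model_class=MyModel,                # assumed user-defined model class
    automatic_mixed_precision=True,     # field added in this commit; defaults to False
)
trainer = Trainer(configs)              # Trainer assumed to be the base class in trainer.py
trainer.build()                         # build() now wires up AMP via build_amp() (see below)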
24 changes: 20 additions & 4 deletions generic_trainer/trainer.py
@@ -273,6 +273,8 @@ def __init__(self, configs: Union[TrainingConfig, Config], rank=None, num_proces
         self.loss_criterion = self.configs.loss_function
         self.iterations_per_epoch = 0
         self.current_epoch = 0
+        self.use_torch_amp = False
+        self.grad_scaler = None
         self.gatekeeper = MultirankGateKeeper(0, 1)
 
         self.debug = self.configs.debug
@@ -383,6 +385,7 @@ def build(self):
         self.build_optimizer()
         self.build_scheduler()
         self.load_state_checkpoint()
+        self.build_amp()
 
         self.build_dir()
 
@@ -607,8 +610,9 @@ def load_data_and_get_loss(self, data_and_labels, loss_buffer, *args, **kwargs):
         :return: loss_buffer, total_loss_tensor, preds, labels
         """
         data, labels = self.process_data_loader_yield(data_and_labels)
-        preds = self.model(*data)
-        losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
+        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_torch_amp):
+            preds = self.model(*data)
+            losses, total_loss_tensor = self.compute_losses(loss_buffer, preds, labels)
         return loss_buffer, total_loss_tensor, preds, labels
 
     def run_training_epoch(self):
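The wrapped forward pass follows the standard torch.autocast pattern: ops inside the context run in float16 where that is numerically safe, and enabled=False turns the context into a no-op so the non-AMP path is unchanged. A self-contained sketch of the same pattern, with illustrative names that are not from this repository:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)
use_torch_amp = (device.type == 'cuda')  # assumption: only enable AMP on GPU

with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_torch_amp):
    preds = model(x)              # runs in float16 under autocast
    loss = (preds ** 2).mean()    # stand-in for compute_losses()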
@@ -677,8 +681,9 @@ def run_validation(self):

     def run_model_update_step(self, loss_node):
         self.optimizer.zero_grad()
-        loss_node.backward()
-        self.optimizer.step()
+        self.grad_scaler.scale(loss_node).backward()
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
 
     def get_model_object(self):
         if isinstance(self.model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
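This is the canonical GradScaler sequence: the loss is scaled before backward() so small float16 gradients do not underflow, step() unscales the gradients and skips the optimizer update if they contain inf/NaN, and update() retunes the scale factor. Continuing the illustrative sketch above (model, loss, use_torch_amp as defined there):

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_torch_amp)

optimizer.zero_grad()
scaler.scale(loss).backward()   # backprop through the scaled loss
scaler.step(optimizer)          # unscales grads; skips the step on inf/NaN
scaler.update()                 # adjusts the scale factor for the next step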
@@ -728,6 +733,17 @@ def build_scheduler(self):
                                                            max_lr=self.learning_rate, step_size_up=step_size,
                                                            cycle_momentum=False, mode='triangular2')
 
+    def build_amp(self):
+        # Do not use torch.autocast and torch.GradScaler() in trainers using other backends like HuggingFace
+        # Accelerate or PyTorch Lightning. These backends have their own AMP routines.
+        self.use_torch_amp = False
+        if (self.configs.automatic_mixed_precision and
+                self.__class__ not in [HuggingFaceAccelerateTrainer,
+                                       HuggingFaceAcceleratePretrainer,
+                                       PyTorchLightningTrainer]):
+            self.use_torch_amp = True
+        self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.use_torch_amp)
+
     def build_model(self):
         self.model_class_handle = self.configs.model_class
         self.model = self.configs.model_class(**self.configs.model_params.__dict__)
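A note on the design: because the scaler is constructed with enabled=self.use_torch_amp rather than guarding run_model_update_step with an if, one update path serves both modes. With GradScaler(enabled=False), scale() returns the loss unchanged, step() simply calls optimizer.step(), and update() does nothing. A small standalone sketch of that pass-through behavior:

import torch

scaler = torch.cuda.amp.GradScaler(enabled=False)   # AMP off
loss = torch.tensor(2.0, requires_grad=True)
assert scaler.scale(loss) is loss                   # disabled scaler returns the loss as-is
# scaler.step(optimizer) would call optimizer.step() directly; update() is a no-op.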
