Commit

Merge pull request #305 from VectorInstitute/nnunet-amp-integration
Nnunet amp integration
jewelltaylor authored Dec 19, 2024
2 parents 1ec2707 + 2513384 commit 97a829e
Showing 1 changed file with 78 additions and 4 deletions.
82 changes: 78 additions & 4 deletions fl4health/clients/nnunet_client.py
@@ -15,6 +15,7 @@
from flwr.common.logger import FLOWER_LOGGER, console_handler, log
from flwr.common.typing import Config, Scalar
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -188,6 +189,9 @@ def __init__(
# Used to redirect stdout to logger
self.stream2debug = StreamToLogger(FLOWER_LOGGER, DEBUG)

# Used to scale gradients if using mixed precision training (true if device is cuda)
self.grad_scaler: GradScaler | None = GradScaler() if self.device.type == "cuda" else None

# nnunet specific attributes to be initialized in setup_client
self.nnunet_trainer_class = nnunet_trainer_class
self.nnunet_trainer_class_kwargs = nnunet_trainer_class_kwargs
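For readers less familiar with PyTorch mixed precision: the GradScaler created above is one half of the standard torch AMP recipe (autocast for the forward pass, a GradScaler for the backward pass). The following is a minimal, generic sketch of that recipe, not code from this PR; the toy model, optimizer, and data are placeholders for illustration only.

import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Toy stand-ins purely for illustration (assumes a CUDA device is available)
model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(4, 10, device="cuda"), torch.randint(0, 2, (4,), device="cuda"))]

scaler = GradScaler()
for inputs, targets in loader:
    optimizer.zero_grad()
    # Forward pass and loss run in mixed precision
    with torch.autocast("cuda", enabled=True):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    # Scale the loss so fp16 gradients do not underflow, then backprop
    scaler.scale(loss).backward()
    # Restore true gradient magnitudes before any norm-based clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=12)
    # step() skips the update if gradients contain inf/NaN; update() adapts the scale factor
    scaler.step(optimizer)
    scaler.update()

This mirrors the structure of the train_step override added below.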
@@ -203,6 +207,58 @@ def __init__(
log(INFO, "Switching pytorch model jit compile to OFF")
os.environ["nnUNet_compile"] = str("false")

def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]:
"""
Given a single batch of input and target data, generate predictions, compute loss, update parameters and
optionally update metrics if they exist (i.e., backprop on a single batch of data).
Assumes self.model is in train mode already.
Overrides parent to include mixed precision training (autocasting and corresponding gradient scaling)
as per the original nnUNetTrainer.
Args:
input (TorchInputType): The input to be fed into the model.
target (TorchTargetType): The target corresponding to the input.
Returns:
Tuple[TrainingLosses, TorchPredType]: The losses object from the train step along with
a dictionary of any predictions produced by the model.
"""
# If the device type is not cuda, we don't use mixed precision training
# So we are safe to use the BasicClient train_step method
# Note that transform_gradients is defined for the NnunetClient
if self.device.type != "cuda":
return super().train_step(input, target)

# If performing mixed precision training, scaler should be defined
assert self.grad_scaler is not None

# Clear gradients from optimizer if they exist
self.optimizers["global"].zero_grad()

# Call user defined methods to get predictions and compute loss
preds, features = self.predict(input)
target = self.transform_target(target)
losses = self.compute_training_loss(preds, features, target)

# Custom backward pass logic with gradient scaling adapted from nnUNetTrainer:
# https://github.com/MIC-DKFZ/nnUNet/blob/43349fa5f0680e8109a78dca7215c19e258c9dd7/ \
# nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py#L999

# Compute scaled loss and perform backward pass
scaled_backward_loss = self.grad_scaler.scale(losses.backward["backward"])
scaled_backward_loss.backward()

# Rescale gradients then clip based on specified norm
self.grad_scaler.unscale_(self.optimizers["global"])
self.transform_gradients(losses)

# Update parameters and scaler
self.grad_scaler.step(self.optimizers["global"])
self.grad_scaler.update()

return losses, preds
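transform_gradients itself is not part of this diff; it is defined elsewhere on the client. As a rough sketch only, assuming it mirrors the norm-based gradient clipping that nnUNetTrainer applies after unscaling (nnUNet clips to a max norm of 12), it might look like the method below; the exact signature, attribute names, and clip value are assumptions, not taken from this PR.

def transform_gradients(self, losses: TrainingLosses) -> None:
    # Gradients were already unscaled via grad_scaler.unscale_ in train_step,
    # so clipping here operates on the true gradient magnitudes.
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=12)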

@use_default_signal_handlers # Dataloaders use multiprocessing
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
"""
@@ -514,7 +570,8 @@ def setup_client(self, config: Config) -> None:
def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch.Tensor]]:
"""
Generate model outputs. Overridden because nnunets output lists when
deep supervision is on so we have to reformat the output into dicts.
If device type is cuda, the forward pass is computed in mixed precision.
Args:
input (TorchInputType): The model inputs
@@ -525,7 +582,14 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch.Tensor]]:
unused by this subclass and therefore is always an empty dict
"""
if isinstance(input, torch.Tensor):
# If device type is cuda, nnUNet defaults to mixed precision forward pass
# https://github.com/MIC-DKFZ/nnUNet/blob/43349fa5f0680e8109a78dca7215c19e258c9dd7/ \
# nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py#L993
if self.device.type == "cuda":
with torch.autocast(self.device.type, enabled=True):
output = self.model(input)
else:
output = self.model(input)
else:
raise TypeError('"input" must be of type torch.Tensor for nnUNetClient')
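For context, when deep supervision is enabled nnUNet returns one output tensor per resolution level as a plain list rather than a single tensor, which is why this override repackages outputs into a dictionary. A tiny standalone illustration of that repackaging step follows; the key names are hypothetical, not necessarily the ones fl4health uses.

import torch

# Hypothetical deep supervision output: one prediction per resolution level
outputs = [torch.randn(2, 3, 64, 64), torch.randn(2, 3, 32, 32), torch.randn(2, 3, 16, 16)]

# Wrap the list in a dict so downstream loss and metric code can address each head by key
preds = {f"prediction-{i}": out for i, out in enumerate(outputs)}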

@@ -547,7 +611,8 @@ def compute_loss_and_additional_losses(
target: TorchTargetType,
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
"""
Checks the pred and target types and computes the loss.
If device type is cuda, the loss is computed in mixed precision.
Args:
preds (TorchPredType): Dictionary of model output tensors indexed
@@ -577,7 +642,16 @@
f"Got {len(loss_preds)} predictions and {len(loss_targets)} targets."
)

# If device type is cuda, nnUNet defaults to compute loss in mixed precision
# https://github.com/MIC-DKFZ/nnUNet/blob/43349fa5f0680e8109a78dca7215c19e258c9dd7/ \
# nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py#L993
if self.device.type == "cuda":
with torch.autocast(self.device.type, enabled=True):
loss = self.criterion(loss_preds, loss_targets)
else:
loss = self.criterion(loss_preds, loss_targets)

return loss, None
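As additional context on what self.criterion receives here: with deep supervision on, nnUNet computes a weighted sum of the per-resolution losses, with the weight roughly halving at each coarser level, the coarsest level zeroed out, and the weights normalized to sum to one. A small standalone sketch of that weighting convention is below; it uses CrossEntropyLoss as a stand-in for nnUNet's compound Dice + CE loss and is an illustration, not the repository's implementation.

import torch
from torch import nn

def deep_supervision_loss(preds: list[torch.Tensor], targets: list[torch.Tensor]) -> torch.Tensor:
    # Weights: 1, 1/2, 1/4, ... with the coarsest level ignored, then normalized
    criterion = nn.CrossEntropyLoss()
    weights = torch.tensor([1.0 / 2**i for i in range(len(preds))])
    weights[-1] = 0.0
    weights = weights / weights.sum()
    return sum(w * criterion(p, t) for w, p, t in zip(weights, preds, targets))

# Example with multi-resolution outputs and integer-class targets
preds = [torch.randn(2, 3, 64, 64), torch.randn(2, 3, 32, 32), torch.randn(2, 3, 16, 16)]
targets = [torch.randint(0, 3, (2, 64, 64)), torch.randint(0, 3, (2, 32, 32)), torch.randint(0, 3, (2, 16, 16))]
loss = deep_supervision_loss(preds, targets)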

def mask_data(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
