# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Callback for catching loss NaNs."""

from typing import Dict, Sequence

import torch
from composer import Callback, Logger, State


class NaNCatcher(Callback):
    """Catches NaNs in the loss and raises an error if one is found.

    The loss read from ``state.loss`` may be a single tensor, a sequence of tensors,
    or a dict of tensors. Any other layout is rejected with a ``TypeError``.
    """

    def after_loss(self, state: State, logger: Logger):
        """Check the train loss for NaNs and raise an error if any are found.

        Args:
            state (State): The training state; only ``state.loss`` is read.
            logger (Logger): Unused. Present to match the callback signature.

        Raises:
            RuntimeError: If any element of the loss is NaN.
            TypeError: If the loss is not a tensor, a sequence of tensors, or a dict of
                tensors, or if a sequence/dict entry is not a tensor.
        """
        loss = state.loss
        if isinstance(loss, torch.Tensor):
            if torch.isnan(loss).any():
                raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, Sequence):
            for index, element in enumerate(loss):
                if self._has_nan(element, f'index {index}'):
                    raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, Dict):
            for name, element in loss.items():
                if self._has_nan(element, f'key {name!r}'):
                    raise RuntimeError(f'Train loss {name} contains a NaN.')
        else:
            raise TypeError(f'Loss is of type {type(loss)}, but should be a tensor, a sequence of tensors, '
                            'or a dict of tensors')

    @staticmethod
    def _has_nan(element, where: str) -> bool:
        """Return whether a single loss entry contains a NaN.

        Args:
            element: One entry taken from a sequence or dict loss.
            where (str): Description of the entry's position, used in the error message.

        Raises:
            TypeError: If ``element`` is not a tensor. Raised here so the message names the
                offending loss entry instead of surfacing torch's generic argument error.
        """
        if not isinstance(element, torch.Tensor):
            raise TypeError(f'Loss element at {where} is of type {type(element)}, but should be a tensor')
        return bool(torch.isnan(element).any())