From 3f2ecbf2f081767a1a59fe34dca2b1841fdd6bcd Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Tue, 23 Jan 2024 17:16:56 -0800
Subject: [PATCH] Support None metric value in best checkpointing (#688)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/688

Reviewed By: galrotem

Differential Revision: D53023423

fbshipit-source-id: 8627deff9b57c032b8a2cf345fb67bb9efc61bc8
---
 .../callbacks/test_base_checkpointer.py       |  9 +++-
 .../framework/callbacks/base_checkpointer.py  | 51 ++++++++++---------
 2 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index ad11e0a711..ea746ab7da 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -672,6 +672,13 @@ def test_best_checkpoint_no_top_k(self) -> None:
 
             state = get_dummy_train_state()
             my_train_unit = MyTrainLossUnit()
+            my_train_unit.train_loss = None
+            bcs.on_train_epoch_end(state, my_train_unit)
+            # a None metric value should not be added to the checkpoint dirpaths
+            self.assertEqual(bcs._ckpt_dirpaths, [])
+            self.assertEqual(os.listdir(temp_dir), ["epoch_0_step_0"])
+
+            my_train_unit.train_loss = 0.01
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
                 bcs._ckpt_dirpaths,
@@ -810,7 +817,7 @@ def train_step(self, state: State, data: Batch) -> None:
 class MyTrainLossUnit(TrainUnit[Batch]):
     def __init__(self) -> None:
         super().__init__()
-        self.train_loss = 0.01
+        self.train_loss: Optional[float] = 0.01
 
     def train_step(self, state: State, data: Batch) -> None:
         return None
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index 6c89a07124..e292f45a17 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -69,7 +69,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
 
     Note:
         If best_checkpoint_config is enabled, the attribute must be on the unit upon checkpoint time, and must be castable to "float". This value must be maintained by the unit, and updated
-        appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
+        appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends. If the metric value is None,
+        the checkpoint will still be saved, but without the metric value in the checkpoint name.
     """
 
     metadata_fname: Optional[str] = None
@@ -224,17 +225,18 @@ def _generate_checkpoint_and_upkeep(
                 )
 
             metric_value = getattr(unit, best_checkpoint_config.monitored_metric)
-            try:
-                metric_value_f = float(metric_value)
-            except Exception as e:
-                raise RuntimeError(
-                    f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
-                ) from e
-
-            # update checkpoint path to include the metric value info
-            checkpoint_path += (
-                f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
-            )
+            if metric_value is not None:
+                try:
+                    metric_value_f = float(metric_value)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
+                    ) from e
+
+                # update checkpoint path to include the metric value info
+                checkpoint_path += (
+                    f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
+                )
 
         should_checkpoint = self._should_save_checkpoint(metric_value_f)
         if not should_checkpoint:
@@ -256,18 +258,19 @@ def _generate_checkpoint_and_upkeep(
             self._remove_checkpoint(state)
 
         if best_checkpoint_config:
-            # insert the checkpoint path at the right index to preserve ordering
-            keys = [
-                float(os.path.basename(x).split("=")[-1])
-                for x in self._ckpt_dirpaths
-            ]
-            if best_checkpoint_config.mode == "min":
-                keys.reverse()
-            # Use bisect.bisect() to find the insertion point
-            idx = bisect.bisect(keys, none_throws(metric_value_f))
-            if best_checkpoint_config.mode == "min":
-                idx = len(self._ckpt_dirpaths) - idx
-            self._ckpt_dirpaths.insert(idx, checkpoint_path)
+            if metric_value_f is not None:
+                # insert the checkpoint path at the right index to preserve ordering
+                keys = [
+                    float(os.path.basename(x).split("=")[-1])
+                    for x in self._ckpt_dirpaths
+                ]
+                if best_checkpoint_config.mode == "min":
+                    keys.reverse()
+                # Use bisect.bisect() to find the insertion point
+                idx = bisect.bisect(keys, metric_value_f)
+                if best_checkpoint_config.mode == "min":
+                    idx = len(self._ckpt_dirpaths) - idx
+                self._ckpt_dirpaths.insert(idx, checkpoint_path)
         else:
             self._ckpt_dirpaths.append(checkpoint_path)
 
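
Usage sketch (illustrative only, not part of the patch): after this change, a unit may type its monitored attribute as Optional[float] and leave it None until the first real value is available; checkpoints taken in that window are still written, but their names carry no metric suffix and they do not participate in the best-checkpoint ordering. The snippet below is a minimal sketch, assuming the TorchSnapshotSaver callback and BestCheckpointConfig exported from torchtnt.framework.callbacks; MyUnit and the paths are invented for illustration.

    # Hypothetical sketch, not part of this patch. Assumes TorchSnapshotSaver
    # and BestCheckpointConfig as exported from torchtnt.framework.callbacks.
    from typing import Optional

    import torch
    from torchtnt.framework.callbacks import BestCheckpointConfig, TorchSnapshotSaver
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = torch.Tensor

    class MyUnit(TrainUnit[Batch]):
        def __init__(self) -> None:
            super().__init__()
            # None until the first loss is computed; with this patch the
            # checkpointer no longer fails on float(None) -- it saves the
            # checkpoint and simply omits the metric suffix from its name.
            self.train_loss: Optional[float] = None

        def train_step(self, state: State, data: Batch) -> None:
            # stand-in for a real forward/backward pass
            self.train_loss = float((data ** 2).mean())

    saver = TorchSnapshotSaver(
        dirpath="/tmp/checkpoints",
        save_every_n_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="train_loss",
            mode="min",
        ),
    )
    # pass `saver` via the `callbacks` argument of the torchtnt.framework
    # train() entry point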