Skip to content

Commit

Permalink
Support None metric value in best checkpointing (#688)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #688

Reviewed By: galrotem

Differential Revision: D53023423

fbshipit-source-id: 8627deff9b57c032b8a2cf345fb67bb9efc61bc8
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Jan 24, 2024
1 parent 918313d commit 3f2ecbf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
9 changes: 8 additions & 1 deletion tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,13 @@ def test_best_checkpoint_no_top_k(self) -> None:
state = get_dummy_train_state()
my_train_unit = MyTrainLossUnit()

my_train_unit.train_loss = None
bcs.on_train_epoch_end(state, my_train_unit)
# none metric-value will not be updated in checkpoint dirpaths
self.assertEqual(bcs._ckpt_dirpaths, [])
self.assertEqual(os.listdir(temp_dir), ["epoch_0_step_0"])

my_train_unit.train_loss = 0.01
bcs.on_train_epoch_end(state, my_train_unit)
self.assertEqual(
bcs._ckpt_dirpaths,
Expand Down Expand Up @@ -810,7 +817,7 @@ def train_step(self, state: State, data: Batch) -> None:
class MyTrainLossUnit(TrainUnit[Batch]):
def __init__(self) -> None:
super().__init__()
self.train_loss = 0.01
self.train_loss: Optional[float] = 0.01

def train_step(self, state: State, data: Batch) -> None:
return None
51 changes: 27 additions & 24 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
Note:
If best_checkpoint_config is enabled, the attribute must be on the unit upon checkpoint time, and must be castable to "float". This value must be maintained by the unit, and updated
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends. If the metric value is None, the
checkpoint will be saved, without the metric value in the checkpoint name
"""

metadata_fname: Optional[str] = None
Expand Down Expand Up @@ -224,17 +225,18 @@ def _generate_checkpoint_and_upkeep(
)

metric_value = getattr(unit, best_checkpoint_config.monitored_metric)
try:
metric_value_f = float(metric_value)
except Exception as e:
raise RuntimeError(
f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
) from e

# update checkpoint path to include the metric value info
checkpoint_path += (
f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
)
if metric_value is not None:
try:
metric_value_f = float(metric_value)
except Exception as e:
raise RuntimeError(
f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
) from e

# update checkpoint path to include the metric value info
checkpoint_path += (
f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
)

should_checkpoint = self._should_save_checkpoint(metric_value_f)
if not should_checkpoint:
Expand All @@ -256,18 +258,19 @@ def _generate_checkpoint_and_upkeep(
self._remove_checkpoint(state)

if best_checkpoint_config:
# insert the checkpoint path at the right index to preserve ordering
keys = [
float(os.path.basename(x).split("=")[-1])
for x in self._ckpt_dirpaths
]
if best_checkpoint_config.mode == "min":
keys.reverse()
# Use bisect.bisect() to find the insertion point
idx = bisect.bisect(keys, none_throws(metric_value_f))
if best_checkpoint_config.mode == "min":
idx = len(self._ckpt_dirpaths) - idx
self._ckpt_dirpaths.insert(idx, checkpoint_path)
if metric_value_f:
# insert the checkpoint path at the right index to preserve ordering
keys = [
float(os.path.basename(x).split("=")[-1])
for x in self._ckpt_dirpaths
]
if best_checkpoint_config.mode == "min":
keys.reverse()
# Use bisect.bisect() to find the insertion point
idx = bisect.bisect(keys, metric_value_f)
if best_checkpoint_config.mode == "min":
idx = len(self._ckpt_dirpaths) - idx
self._ckpt_dirpaths.insert(idx, checkpoint_path)
else:
self._ckpt_dirpaths.append(checkpoint_path)

Expand Down

0 comments on commit 3f2ecbf

Please sign in to comment.