From 077c1fe27dc32ede6b0d7e86cdd6aacf04bf6b3e Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Tue, 21 May 2024 10:32:58 -0700
Subject: [PATCH] add warmup steps to throughput logger (#840)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/840

Adding warmup steps to different performance loggers

Reviewed By: diego-urgell

Differential Revision: D57596034

fbshipit-source-id: ceeb60ae08b7bae33f69525816407f36d0510bfc
---
 .../callbacks/test_throughput_logger.py      | 37 ++++++++++++++++---
 .../framework/callbacks/throughput_logger.py | 14 ++++++-
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/tests/framework/callbacks/test_throughput_logger.py b/tests/framework/callbacks/test_throughput_logger.py
index 30ff071fd1..61282d5200 100644
--- a/tests/framework/callbacks/test_throughput_logger.py
+++ b/tests/framework/callbacks/test_throughput_logger.py
@@ -32,7 +32,7 @@ class ThroughputLoggerTest(unittest.TestCase):
     def test_maybe_log_for_step(self) -> None:
         logger = MagicMock(spec=MetricLogger)
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32}, 1)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32})
         phase_state = PhaseState(dataloader=[])
         phase_state.iteration_timer.recorded_durations = {
             "data_wait_time": [1, 4],
             "train_iteration_time": [3],
@@ -75,7 +75,7 @@ def test_maybe_log_for_step(self) -> None:
     def test_maybe_log_for_step_early_return(self) -> None:
         logger = MagicMock(spec=MetricLogger)
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 1)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1})
         phase_state = PhaseState(dataloader=[])
         recorded_durations_dict = {
             "data_wait_time": [0.0, 4.0],
@@ -101,7 +101,9 @@ def test_maybe_log_for_step_early_return(self) -> None:

         # step_logging_for % log_every_n_steps != 0
         recorded_durations_dict["data_wait_time"] = [1.0, 2.0]
-        throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 2)
+        throughput_logger = ThroughputLogger(
+            logger, {"Batches": 1}, log_every_n_steps=2
+        )
         throughput_logger._maybe_log_for_step(state, step_logging_for=1)
         logger.log.assert_not_called()

@@ -330,17 +332,40 @@ def test_epoch_logging_time(self) -> None:
             any_order=True,
         )

+    def test_warmup_steps(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        throughput_logger = ThroughputLogger(
+            logger, {"Batches": 1, "Items": 32}, warmup_steps=1
+        )
+        phase_state = PhaseState(dataloader=[])
+        phase_state.iteration_timer.recorded_durations = {
+            "data_wait_time": [1, 4],
+            "train_iteration_time": [3],
+        }
+        state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state)
+
+        throughput_logger._maybe_log_for_step(state, 1)
+        logger.log.assert_not_called()
+
+        throughput_logger._maybe_log_for_step(state, 2)
+        self.assertEqual(logger.log.call_count, 2)
+
     def test_input_validation(self) -> None:
         logger = MagicMock(spec=MetricLogger)
         with self.assertRaisesRegex(ValueError, "throughput_per_batch cannot be empty"):
-            ThroughputLogger(logger, {}, 1)
+            ThroughputLogger(logger, {})

         with self.assertRaisesRegex(
             ValueError, "throughput_per_batch item Batches must be at least 1, got -1"
         ):
-            ThroughputLogger(logger, {"Queries": 8, "Batches": -1}, 1)
+            ThroughputLogger(logger, {"Queries": 8, "Batches": -1})

         with self.assertRaisesRegex(
             ValueError, "log_every_n_steps must be at least 1, got 0"
         ):
-            ThroughputLogger(logger, {"Batches": 1}, 0)
+            ThroughputLogger(logger, {"Batches": 1}, log_every_n_steps=0)
+
+        with self.assertRaisesRegex(
+            ValueError, "warmup_steps must be at least 0, got -1"
+        ):
+            ThroughputLogger(logger, {"Batches": 1}, warmup_steps=-1)
diff --git a/torchtnt/framework/callbacks/throughput_logger.py b/torchtnt/framework/callbacks/throughput_logger.py
index e8e381cbf1..d23d4c72c4 100644
--- a/torchtnt/framework/callbacks/throughput_logger.py
+++ b/torchtnt/framework/callbacks/throughput_logger.py
@@ -48,7 +48,8 @@ class ThroughputLogger(Callback):
             For instance, a user can pass in {Batches: 1, Queries: 32} which will visualize two charts - one for Batches per second
             and one for Queries per second. As an example, if each of your batches is of type: {data: torch.Size([16, 8, 8]),
             labels: torch.Size([16,1])}, then you could pass {Queries: 16}.
-        log_every_n_steps: an optional int to control the log frequency.
+        log_every_n_steps: an int to control the log frequency. Default is 1.
+        warmup_steps: an int to control the number of warmup steps. Logging starts only after the warmup steps have completed. Default is 0.

     Note:
         The values reported are only for rank 0.
@@ -59,7 +60,9 @@ def __init__(
         self,
         logger: MetricLogger,
         throughput_per_batch: Mapping[str, int],
+        *,
         log_every_n_steps: int = 1,
+        warmup_steps: int = 0,
     ) -> None:
         self._logger = logger

@@ -80,6 +83,12 @@ def __init__(
             )

         self._log_every_n_steps = log_every_n_steps
+
+        if warmup_steps < 0:
+            raise ValueError(f"warmup_steps must be at least 0, got {warmup_steps}")
+
+        self._warmup_steps = warmup_steps
+
         self._epoch_start_times: Dict[ActivePhase, float] = {}
         self._steps_in_epoch: Dict[ActivePhase, int] = defaultdict(int)

@@ -154,6 +163,9 @@ def _maybe_log_for_step(
         *,
         is_step_end_hook: bool = True,
     ) -> None:
+        if step_logging_for <= self._warmup_steps:
+            return
+
         if step_logging_for % self._log_every_n_steps != 0:
             return