Skip to content

Commit

Permalink
Add module summary callback (#284)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #284

added callback to print module summary at start of train

Reviewed By: ananthsub

Differential Revision: D42137915

fbshipit-source-id: b0f9f1e0fe914cf3fc42d2094b7dbaf7fd6fc04b
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Dec 21, 2022
1 parent ecb5d5b commit ad1593d
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 0 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
parameterized
pytest
pytest-cov
torcheval-nightly
torchsnapshot-nightly
127 changes: 127 additions & 0 deletions tests/framework/callbacks/test_module_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Tuple
from unittest.mock import MagicMock

import torch
from torchtnt.framework import AutoUnit
from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks.module_summary import ModuleSummary
from torchtnt.framework.state import EntryPoint, PhaseState, State


class ModuleSummaryTest(unittest.TestCase):
def test_module_summary_max_depth(self) -> None:
"""
Test ModuleSummary callback with train entry point
"""
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 1

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = State(
entry_point=EntryPoint.TRAIN,
train_state=PhaseState(
dataloader=dataloader,
max_epochs=max_epochs,
),
)

my_unit = MagicMock(spec=DummyTrainUnit)
module_summary_callback = ModuleSummary(max_depth=2)
module_summary_callback.on_train_epoch_start(state, my_unit)
self.assertEqual(module_summary_callback._max_depth, 2)

def test_module_summary_retrieve_module_summaries(self) -> None:
"""
Test ModuleSummary callback in train
"""

class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.b1 = torch.nn.BatchNorm1d(2)
self.l2 = torch.nn.Linear(2, 2)

def forward(self, x):
x = self.l1(x)
x = self.b1(x)
x = self.l2(x)
return x

my_module = Net()
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
)

module_summary_callback = ModuleSummary()
summaries = module_summary_callback._retrieve_module_summaries(auto_unit)

self.assertEqual(len(summaries), 1)
ms = summaries[0]
self.assertEqual(ms.module_name, "module")
self.assertEqual(ms.module_type, "Net")
self.assertTrue("l1" in ms.submodule_summaries)
self.assertTrue("b1" in ms.submodule_summaries)
self.assertTrue("l2" in ms.submodule_summaries)

def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
"""
Test ModuleSummary callback in train
"""

class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.b1 = torch.nn.BatchNorm1d(2)
self.l2 = torch.nn.Linear(2, 2)

def forward(self, x):
x = self.l1(x)
x = self.b1(x)
x = self.l2(x)
return x

my_module = Net()
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
)

module_inputs = {"module": ((torch.rand(2, 2),), {})}
module_summary_callback = ModuleSummary(module_inputs=module_inputs)
summaries = module_summary_callback._retrieve_module_summaries(auto_unit)

self.assertEqual(len(summaries), 1)
ms = summaries[0]
self.assertEqual(ms.flops_forward, 16)
self.assertEqual(ms.flops_backward, 24)
self.assertEqual(ms.in_size, [2, 2])
self.assertTrue(ms.forward_elapsed_time_ms != "?")


Batch = Tuple[torch.tensor, torch.tensor]


class DummyAutoUnit(AutoUnit[Batch]):
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)

return loss, outputs
2 changes: 2 additions & 0 deletions torchtnt/framework/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .garbage_collector import GarbageCollector
from .lambda_callback import Lambda
from .learning_rate_monitor import LearningRateMonitor
from .module_summary import ModuleSummary
from .pytorch_profiler import PyTorchProfiler
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
from .torchsnapshot_saver import TorchSnapshotSaver
Expand All @@ -19,6 +20,7 @@
"GarbageCollector",
"Lambda",
"LearningRateMonitor",
"ModuleSummary",
"PyTorchProfiler",
"TensorBoardParameterMonitor",
"TorchSnapshotSaver",
Expand Down
96 changes: 96 additions & 0 deletions torchtnt/framework/callbacks/module_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

try:
from torcheval.tools import (
get_module_summary,
get_summary_table,
ModuleSummary as ModuleSummaryObj,
prune_module_summary,
)

_TORCHEVAL_AVAILABLE = True
except Exception:
_TORCHEVAL_AVAILABLE = False

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.rank_zero_log import rank_zero_info


def _log_module_summary_tables(module_summaries: List[ModuleSummaryObj]) -> None:
for ms in module_summaries:
rank_zero_info("\n" + get_summary_table(ms))


class ModuleSummary(Callback):
"""
A callback which generates and logs a summary of the modules. Requires torcheval:
https://pytorch.org/torcheval/stable/
Args:
max_depth: The maximum depth of module summaries to keep.
process_fn: Function to print the module summaries. Default is to log all module summary tables.
module_inputs: A mapping from module name to (args, kwargs) for that module. Useful when wanting FLOPS, activation sizes, etc.
Raises:
RuntimeError:
If torcheval is not installed.
"""

def __init__(
self,
max_depth: Optional[int] = None,
process_fn: Callable[
[List[ModuleSummaryObj]], None
] = _log_module_summary_tables,
# pyre-ignore
module_inputs: Optional[
MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
] = None,
) -> None:
if not _TORCHEVAL_AVAILABLE:
raise RuntimeError(
"ModuleSummary support requires torcheval. "
"Please make sure ``torcheval`` is installed. "
"Installation: https://github.com/pytorch/torcheval#installing-torcheval"
)
self._max_depth = max_depth
self._process_fn = process_fn
self._module_inputs = module_inputs

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
self._get_and_process_summaries(unit)

def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
if state.entry_point != EntryPoint.EVALUATE:
return
self._get_and_process_summaries(unit)

def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
self._get_and_process_summaries(unit)

def _retrieve_module_summaries(self, unit: AppStateMixin) -> List[ModuleSummaryObj]:
module_summaries = []
for module_name, module in unit.tracked_modules().items():
args, kwargs = (), {}
if self._module_inputs and module_name in self._module_inputs:
args, kwargs = self._module_inputs[module_name]
module_summary = get_module_summary(
module, module_args=args, module_kwargs=kwargs
)
module_summary._module_name = module_name
if self._max_depth:
prune_module_summary(module_summary, max_depth=self._max_depth)
module_summaries.append(module_summary)
return module_summaries

def _get_and_process_summaries(self, unit: AppStateMixin) -> None:
module_summaries = self._retrieve_module_summaries(unit)
self._process_fn(module_summaries)

0 comments on commit ad1593d

Please sign in to comment.