Skip nnmodule hook guards by default (pytorch#98371)
This PR makes basic nn.Module forward hooks work by default, without any overhead. However, it leaves a silent correctness issue if users modify or remove their hooks after compiling, so it also emits a warning.

- the usual case is to not use hooks, so we avoid guard overhead there
- registering any hook before compile will trigger a warning about hook support
- registering a hook later (or removing one) requires user knowledge and opting in;
  currently this isn't warnable (but perhaps we can observe compiled nn.Modules to make it
  warnable).

Why skip hook guards by default instead of not tracing __call__/hooks by default?
- it avoids a mode flag that alters dynamo tracing behavior (harder to test both codepaths
  in CI with full coverage)
- the most basic hook use case (registering a hook before compile and never removing it)
  works by default with this PR, whereas it would require explicit enablement and incur overhead
  under the 'not tracing __call__' proposal.

Pull Request resolved: pytorch#98371
Approved by: https://github.com/jansel
wconstab authored and pytorchmergebot committed Apr 7, 2023
1 parent 46d765c commit 390c51b
Showing 6 changed files with 113 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/source/compile/index.rst
@@ -78,3 +78,4 @@ please check out the references below.

   get-started
   technical-overview
+   nn-module
47 changes: 47 additions & 0 deletions docs/source/compile/nn-module.rst
@@ -0,0 +1,47 @@
PyTorch 2.0 NNModule Support
============================

**Author**: `Will Constable <https://github.com/wconstab>`_

`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces
arbitrary Python classes, with the intent of producing faster code by making assumptions about their structure.

This doc describes some of the tradeoffs and edge cases that arise from this specialization.

NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on nn.Modules; if hooks were registered,
they would simply be ignored in the compiled program. Indeed, many users do not
use nn.Module hooks at all, or use them only for debugging workflows, but there are valid use cases
for composing nn.Module hooks with `torch.compile`.

Hooks that are orchestrated via the nn.Module.__call__ implementation include `_forward_pre_hooks`,
`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`; these will be referred to as 'call hooks'.
These hooks are partially supported by `torch.compile`, with limitations described below. As an illustrative
sketch, registering hooks through the public nn.Module API is what populates these dictionaries:
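
.. code-block:: python

    import torch

    model = torch.nn.Linear(2, 2)
    handle = model.register_forward_pre_hook(lambda mod, args: None)
    assert len(model._forward_pre_hooks) == 1  # one of the 'call hook' dicts
    handle.remove()
    assert len(model._forward_pre_hooks) == 0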

Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants; these are still
unsupported by `torch.compile`.

`nn.Module.__call__` Hooks Usage and Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` will trace the contents of `nn.Module.__call__`, which means it will encounter
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and do not remove
or alter them later, your use case should be supported by default.
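
A minimal sketch of this supported-by-default pattern (the toy module and hook below are illustrative):

.. code-block:: python

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x)

    model = ToyModel()

    # Install the hook before compiling, and never remove or alter it afterwards.
    def double_output_hook(module, args, output):
        return output * 2

    model.register_forward_hook(double_output_hook)

    compiled_model = torch.compile(model)
    out = compiled_model(torch.randn(4))  # the hook runs as part of the compiled program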

**skip_nnmodule_hook_guards**

By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
on the nn.Module hook dictionaries. This improves runtime by reducing guard execution time, at the cost of not
noticing if any hook dict is changed after compilation.

If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.
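
For example, opting back into hook guards is a one-line config change (a minimal sketch):

.. code-block:: python

    import torch._dynamo

    # Guard on nn.Module hook dictionaries so that adding/removing hooks
    # after compilation triggers recompilation, at some guard-execution cost.
    torch._dynamo.config.skip_nnmodule_hook_guards = False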

TODO: confirm if backward/pre_backward hooks are working or not and document accordingly

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks are not yet supported in `torch.compile`.


TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.
2 changes: 2 additions & 0 deletions test/dynamo/test_modules.py
@@ -1308,6 +1308,7 @@ def fn(x):
            )
        )

+    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_outer(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1354,6 +1355,7 @@ def guard_fail_fn(failure):
        the eval_frame entrypoint to Module.__call__?
        """

+    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_inner(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
2 changes: 1 addition & 1 deletion torch/_dynamo/config.py
@@ -201,7 +201,7 @@
# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
-skip_nnmodule_hook_guards = False
+skip_nnmodule_hook_guards = True

# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
40 changes: 18 additions & 22 deletions torch/_dynamo/output_graph.py
@@ -56,6 +56,8 @@
    dynamo_timed,
    format_graph_code,
    format_graph_tabular,
+    nnmodule_doc_url_msg,
+    nnmodule_has_hooks,
    same,
)
from .variables.base import VariableTracker
@@ -382,26 +384,6 @@ def update_co_names(self, name):
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)

-    @staticmethod
-    def module_has_hooks(mod, only_check_unsupported=False):
-        supported_hooks = [
-            "_forward_pre_hooks",
-            "_forward_hooks",
-        ]
-        unsupported_hooks = [
-            "_backward_pre_hooks",
-            "_backward_hooks",
-            "_state_dict_pre_hooks",
-            "_state_dict_hooks",
-            "_load_state_dict_pre_hooks",
-            "_load_state_dict_post_hooks",
-        ]
-        check_hooks = unsupported_hooks
-        if not only_check_unsupported:
-            check_hooks += supported_hooks
-
-        return any(len(getattr(mod, x)) > 0 for x in check_hooks if hasattr(mod, x))

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
@@ -433,10 +415,24 @@ def wrap_name(module_key):

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)
-            if self.module_has_hooks(target, only_check_unsupported=True):
+            if nnmodule_has_hooks(target, check_forward_hooks=True):
+                torch._logging.warning_once(
+                    log,
+                    "nn.Module forward/_pre hooks are only partially supported, and were detected in your model. "
+                    "In particular, if you do not change/remove hooks after calling .compile(), you can disregard this "
+                    "warning, and otherwise you may need to set torch._dynamo.config.skip_nnmodule_hook_guards=False "
+                    "to ensure recompiling after changing hooks."
+                    f"{nnmodule_doc_url_msg} ",
+                )
+            if nnmodule_has_hooks(
+                target, check_backward_hooks=True, check_state_dict_hooks=True
+            ):
                torch._logging.warning_once(
-                    log, "nn.Module hooks are not fully supported, they may be ignored"
+                    log,
+                    "nn.Module state_dict and backward hooks are not yet supported by torch.compile, "
+                    f"but were detected in your model and will be silently ignored. {nnmodule_doc_url_msg}",
                )

            options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key):
44 changes: 44 additions & 0 deletions torch/_dynamo/utils.py
@@ -49,6 +49,8 @@

counters = collections.defaultdict(collections.Counter)
troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
+nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
+nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."

log = logging.getLogger(__name__)

@@ -1439,3 +1441,45 @@ def format_graph_tabular(fn_name, gm):

def format_bytecode(prefix, name, filename, line_no, code):
    return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n"


def nnmodule_has_hooks(
    mod,
    check_forward_hooks=False,
    check_backward_hooks=False,
    check_state_dict_hooks=False,
):
    """
    Sometimes it's useful to differentiate between types of hooks, such as forward/backward/pre
    hooks executed during module.__call__, and state_dict hooks which are executed separately.
    """
    hook_dicts_to_check = []
    check_all_hooks = (
        not check_forward_hooks
        and not check_backward_hooks
        and not check_state_dict_hooks
    )
    if check_forward_hooks or check_all_hooks:
        hook_dicts_to_check.extend(
            [
                "_forward_pre_hooks",
                "_forward_hooks",
            ]
        )
    if check_backward_hooks or check_all_hooks:
        hook_dicts_to_check.extend(
            [
                "_backward_pre_hooks",
                "_backward_hooks",
            ]
        )
    if check_state_dict_hooks:
        hook_dicts_to_check.extend(
            [
                "_state_dict_pre_hooks",
                "_state_dict_hooks",
                "_load_state_dict_pre_hooks",
                "_load_state_dict_post_hooks",
            ]
        )
    return any(len(getattr(mod, x)) > 0 for x in hook_dicts_to_check if hasattr(mod, x))
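
A short usage sketch of nnmodule_has_hooks (illustrative; any nn.Module with a registered hook behaves the same way):

    mod = torch.nn.Linear(2, 2)
    mod.register_forward_hook(lambda m, args, out: out)
    nnmodule_has_hooks(mod, check_forward_hooks=True)     # True: a forward hook is registered
    nnmodule_has_hooks(mod, check_state_dict_hooks=True)  # False: no state_dict hooks registered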
