forked from pytorch/pytorch
Skip nnmodule hook guards by default (pytorch#98371)
This PR makes basic nnmodule forward hooks work by default, without any overhead. But it leaves silent correctness issues if users modify/remove their hooks later, thus also emits a warning.

- the usual case is to not use hooks, so avoid guard overhead here
- registering any hook before compile will trigger a warning about hook support
- registering a hook later (or removing one) requires user knowledge and opting in; currently this isn't warnable (but maybe we can observe compiled nnmodules to make it warnable)

Why skip hook guards by default instead of not tracing __call__/hooks by default?

- avoids having a mode flag that alters dynamo tracing behavior (harder to test both codepaths in CI with full coverage)
- the most basic hook use case (registering a hook before compile, and never removing it) works by default with this PR, while it would require enablement and incur overhead in the 'not tracing __call__' proposal

Pull Request resolved: pytorch#98371
Approved by: https://github.com/jansel
1 parent 46d765c · commit 390c51b · 6 changed files with 113 additions and 23 deletions.
@@ -78,3 +78,4 @@ please check out the references below.
   get-started
   technical-overview
+  nn-module
@@ -0,0 +1,47 @@
PyTorch 2.0 NNModule Support
============================

**Author**: `Will Constable <https://github.com/wconstab>`_

`torch.compile` has special handling for `torch.nn.Module` objects, tracing them differently than it traces
arbitrary Python classes, with the intent of producing faster code by making assumptions about their structure.

This doc describes some of the tradeoffs and edge cases that come up due to this specialization.

NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on `nn.Module`, and if hooks were registered
they would simply be ignored in the compiled program. Many users do not
use `nn.Module` hooks at all, or only use them for debug workflows, but there are valid use cases
for composing `nn.Module` hooks with `torch.compile`.

Hooks that are orchestrated via the `nn.Module.__call__` implementation include `_forward_pre_hooks`,
`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
These hooks are partially supported by `torch.compile`, with limitations described below.
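The four call-hook dictionaries can be inspected directly on any module. A minimal sketch (note these attribute names are private implementation details of `nn.Module`):

```python
import torch.nn as nn

m = nn.Linear(2, 2)
# The 'call hooks' live in OrderedDict attributes on every nn.Module;
# all of them are empty until a hook is registered.
for name in ("_forward_pre_hooks", "_forward_hooks",
             "_backward_pre_hooks", "_backward_hooks"):
    print(name, dict(getattr(m, name)))
```

Registering a hook (e.g. via `register_forward_hook`) populates the corresponding dictionary with a handle-id-to-callable entry.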
Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants; these are still
unsupported by `torch.compile`.

`nn.Module.__call__` Hooks Usage and Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` traces the contents of `nn.Module.__call__`, which means it will encounter
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
or alter the hooks later, your use case should be supported by default.
**skip_nnmodule_hook_guards**
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
on each `nn.Module` hook dictionary. This improves runtime by reducing guard execution time, at the cost of not noticing
if any hook dict is changed after compilation.
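To illustrate the tradeoff, here is a torch-free sketch (all names hypothetical) of what a hook-dict guard conceptually checks before a cached compiled program is reused:

```python
def make_hook_guard(module):
    # Snapshot the hook ids present at "compile" time.
    snapshot = tuple(module._forward_hooks.keys())

    def guard():
        # Cheap check run before every reuse of the compiled artifact:
        # is the hook dict still as it was at compile time?
        return tuple(module._forward_hooks.keys()) == snapshot

    return guard

class ToyModule:
    def __init__(self):
        self._forward_hooks = {}

m = ToyModule()
m._forward_hooks[0] = lambda out: out * 2
guard = make_hook_guard(m)
assert guard()           # hooks unchanged -> cached code is valid
m._forward_hooks.pop(0)  # removing a hook invalidates the snapshot
assert not guard()       # with guards enabled, this triggers a recompile
```

Skipping the guard saves this per-call check, but a removed or added hook then goes undetected.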
If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.
TODO: confirm whether backward/pre-backward hooks are working and document accordingly.

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks are not yet supported in `torch.compile`.

TODO: warn_once if graph-breaking on hooks. warn_once to point to this doc if hooks are present.