Fix type-safety of torch.nn.Module instances
Summary:
X-link: facebookresearch/generative-recommenders#129

X-link: pytorch/FBGEMM#3387

X-link: facebookresearch/FBGEMM#476

X-link: pytorch/torchrec#2562

As laid out in pytorch/pytorch#81462 (comment), the change in pytorch/pytorch#104321 was not necessary and largely destroys the type-safety of `torch.nn.Module` instances.
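
For context, a minimal sketch of the trade-off described above (my illustration, not part of the original commit; the `get_running_mean` helper is hypothetical): assuming pytorch/pytorch#104321 loosened `Module.__getattr__` to an effectively untyped return, every dynamically resolved attribute access silently type-checks, whereas the stricter `Union[Tensor, Module]` signature keeps the checker involved but forces call sites to narrow.

import torch
from torch import nn


def get_running_mean(m: nn.Module) -> torch.Tensor:
    # `m` is only known to be an nn.Module, so this attribute lookup is
    # resolved through Module.__getattr__. With an "anything goes" return
    # type a typo like `m.runing_mean` would also type-check silently; with
    # `Union[Tensor, Module]` the result must be narrowed explicitly (or
    # suppressed, as the pyre-fixme comments in this diff do).
    mean = m.running_mean
    assert isinstance(mean, torch.Tensor)
    return mean


print(get_running_mean(nn.BatchNorm1d(4)))  # tensor([0., 0., 0., 0.])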

As far as I can see, the underlying issue of pytorch/pytorch#81462 in `torch.nn.parallel.DistributedDataParallel` has been fixed in the meantime by actually typing `register_comm_hook` correctly.
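
As an aside, a sketch of what that looks like from user code (mine, not from this commit; the hook name is made up, and registering it requires an initialized process group): with `register_comm_hook` typed upstream, a gradient-communication hook can be written against the documented `(state, GradBucket) -> Future[Tensor]` shape and checked statically, with no `__getattr__` workaround involved.

import torch
import torch.distributed as dist


def averaging_allreduce_hook(
    state: object, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # All-reduce the bucket's flattened gradients asynchronously, then divide
    # by the world size once the result is available.
    fut = dist.all_reduce(bucket.buffer(), async_op=True).get_future()
    return fut.then(lambda f: f.value()[0] / dist.get_world_size())


# Inside a distributed job:
#   ddp_model.register_comm_hook(state=None, hook=averaging_allreduce_hook)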

The proper solution to issues like pytorch/pytorch#81462 is to give the underlying field/method a proper type annotation; then there is no need for a type-system-disabling `__getattr__`.
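
To make that concrete, a sketch of the "annotate the underlying field" approach (my illustration; `EmaModule` is a made-up example class): an explicit class-level annotation lets the checker resolve the attribute directly instead of falling back to `__getattr__`.

from typing import Optional

import torch
from torch import nn


class EmaModule(nn.Module):
    # Explicit annotations: the checker knows the precise types of these
    # attributes without consulting Module.__getattr__.
    decay: float
    shadow: Optional[torch.Tensor]

    def __init__(self, decay: float = 0.99) -> None:
        super().__init__()
        self.decay = decay
        self.shadow = None

    def update(self, value: torch.Tensor) -> torch.Tensor:
        # Exponential moving average of `value`; `self.decay` is a plain
        # float here, not `Union[Tensor, Module]`.
        if self.shadow is None:
            self.shadow = value.detach().clone()
        else:
            self.shadow = self.decay * self.shadow + (1.0 - self.decay) * value
        return self.shadow


print(EmaModule().update(torch.ones(3)))  # tensor([1., 1., 1.])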

(I'll probably be offline for a while, not able to react here...)

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 XilunWu rec mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse tianyu-l kiukchung lucasllc

Original PR: pytorch/pytorch#115074
Updated testing PR: pytorch/pytorch#141240

Reviewed By: malfet, aorenste, gineshidalgo99, larryliu0820

Differential Revision: D52890934

Pulled By: ezyang

fbshipit-source-id: 23af4111a80b471d810e0bf828f4d49a19b4ba80
bluenote10 authored and facebook-github-bot committed Nov 21, 2024
1 parent 14ebfea commit dc5bafd
Showing 4 changed files with 26 additions and 4 deletions.
10 changes: 8 additions & 2 deletions tests/utils/test_distributed.py
@@ -187,12 +187,18 @@ def test_revert_sync_batchnorm(self) -> None:
         self.assertNotIsInstance(batch_norm, torch.nn.SyncBatchNorm)
         self.assertTrue(
             torch.equal(
-                batch_norm.running_mean, none_throws(original_batchnorm.running_mean)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                # `Union[Tensor, Module]`.
+                batch_norm.running_mean,
+                none_throws(original_batchnorm.running_mean),
             )
         )
         self.assertTrue(
             torch.equal(
-                batch_norm.running_var, none_throws(original_batchnorm.running_var)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                # `Union[Tensor, Module]`.
+                batch_norm.running_var,
+                none_throws(original_batchnorm.running_var),
             )
         )

18 changes: 16 additions & 2 deletions torchtnt/framework/_loop_utils.py
@@ -94,16 +94,23 @@ def _set_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            # `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if mode
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            # `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = mode
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = mode
         else:
             module.train(mode)
@@ -122,16 +129,23 @@ def _reset_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            # `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if prior_modes[name]
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            # `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = prior_modes[name]
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = prior_modes[name]
         else:
             module.train(prior_modes[name])
1 change: 1 addition & 0 deletions torchtnt/framework/auto_unit.py
@@ -638,6 +638,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
             and (isinstance(module, DDP) or _is_fsdp_module(module))
1 change: 1 addition & 0 deletions torchtnt/utils/distributed.py
@@ -436,6 +436,7 @@ def revert_sync_batchnorm(
         module_output.running_var = module.running_var
         module_output.num_batches_tracked = module.num_batches_tracked
         if hasattr(module, "qconfig"):
+            # pyre-fixme[16]: `_BatchNormXd` has no attribute `qconfig`.
             module_output.qconfig = module.qconfig
     for name, child in module.named_children():
         module_output.add_module(name, revert_sync_batchnorm(child, device))
