From dc5bafdf328ae49562b75470b2ee8706a953a7bc Mon Sep 17 00:00:00 2001
From: Fabian Keller
Date: Thu, 21 Nov 2024 15:02:49 -0800
Subject: [PATCH] Fix type-safety of `torch.nn.Module` instances

Summary:
X-link: https://github.com/facebookresearch/generative-recommenders/pull/129

X-link: https://github.com/pytorch/FBGEMM/pull/3387

X-link: https://github.com/facebookresearch/FBGEMM/pull/476

X-link: https://github.com/pytorch/torchrec/pull/2562

As laid out in https://github.com/pytorch/pytorch/issues/81462#issuecomment-1838731223, the change in https://github.com/pytorch/pytorch/pull/104321 was not necessary and largely destroys the type-safety of `torch.nn.Module` instances.

As far as I can see, the underlying issue of https://github.com/pytorch/pytorch/issues/81462 in `torch.nn.parallel.DistributedDataParallel` has since been fixed by typing `register_comm_hook` correctly.

The proper solution to issues like https://github.com/pytorch/pytorch/issues/81462 is to give the underlying field/method a proper type annotation; then there is no need to resort to a type-system-disabling `__getattr__`. (A short illustrative sketch follows the diff below.)

(I'll probably be offline for a while, not able to react here...)

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 XilunWu rec mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse tianyu-l kiukchung lucasllc

Original PR: https://github.com/pytorch/pytorch/pull/115074

Updated testing PR: https://github.com/pytorch/pytorch/pull/141240

Reviewed By: malfet, aorenste, gineshidalgo99, larryliu0820

Differential Revision: D52890934

Pulled By: ezyang

fbshipit-source-id: 23af4111a80b471d810e0bf828f4d49a19b4ba80
---
 tests/utils/test_distributed.py   | 10 ++++++++--
 torchtnt/framework/_loop_utils.py | 18 ++++++++++++++++--
 torchtnt/framework/auto_unit.py   |  1 +
 torchtnt/utils/distributed.py     |  1 +
 4 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py
index cb444267cd..61b30cfbdb 100644
--- a/tests/utils/test_distributed.py
+++ b/tests/utils/test_distributed.py
@@ -187,12 +187,18 @@ def test_revert_sync_batchnorm(self) -> None:
         self.assertNotIsInstance(batch_norm, torch.nn.SyncBatchNorm)
         self.assertTrue(
             torch.equal(
-                batch_norm.running_mean, none_throws(original_batchnorm.running_mean)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
+                batch_norm.running_mean,
+                none_throws(original_batchnorm.running_mean),
             )
         )
         self.assertTrue(
             torch.equal(
-                batch_norm.running_var, none_throws(original_batchnorm.running_var)
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
+                batch_norm.running_var,
+                none_throws(original_batchnorm.running_var),
             )
         )

diff --git a/torchtnt/framework/_loop_utils.py b/torchtnt/framework/_loop_utils.py
index 2316da9185..ca1b2444bd 100644
--- a/torchtnt/framework/_loop_utils.py
+++ b/torchtnt/framework/_loop_utils.py
@@ -94,16 +94,23 @@ def _set_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            #  `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if mode
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            #  `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = mode
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = mode
         else:
             module.train(mode)
@@ -122,16 +129,23 @@ def _reset_module_training_mode(
         is_ddp = isinstance(module, DistributedDataParallel)

         if _EXPORT_UTILS_AVAIL and model_is_exported(
-            module.module if is_ddp else module
+            # pyre-fixme[6]: For 1st argument expected `Module` but got
+            #  `Union[Module, Tensor]`.
+            module.module
+            if is_ddp
+            else module
         ):
             move_fn = (
                 torch.ao.quantization.move_exported_model_to_train
                 if prior_modes[name]
                 else torch.ao.quantization.move_exported_model_to_eval
             )
+            # pyre-fixme[6]: For 1st argument expected `GraphModule` but got
+            #  `Union[Module, Tensor]`.
             move_fn(module.module if is_ddp else module)
             module.training = prior_modes[name]
             if is_ddp:
+                # pyre-fixme[16]: `Tensor` has no attribute `training`.
                 module.module.training = prior_modes[name]
         else:
             module.train(prior_modes[name])
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 225e14e3c8..622a21e49d 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -638,6 +638,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
             and (isinstance(module, DDP) or _is_fsdp_module(module))
diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index 37bc1edafc..08c05ac158 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -436,6 +436,7 @@ def revert_sync_batchnorm(
         module_output.running_var = module.running_var
         module_output.num_batches_tracked = module.num_batches_tracked
         if hasattr(module, "qconfig"):
+            # pyre-fixme[16]: `_BatchNormXd` has no attribute `qconfig`.
             module_output.qconfig = module.qconfig
     for name, child in module.named_children():
         module_output.add_module(name, revert_sync_batchnorm(child, device))
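
Illustrative sketch (not part of the patch; the class and attribute names are hypothetical): the annotation-based fix argued for in the summary. Declaring the attribute's type on the owning `nn.Module` subclass lets a type checker resolve `self.head` to a concrete `nn.Linear` instead of falling back to `nn.Module.__getattr__`, whose `Union[Tensor, Module]` return type is what forces the pyre-fixme suppressions in the diff above.

    import torch
    from torch import nn


    class TinyClassifier(nn.Module):
        # Class-level annotation: type checkers resolve `self.head` to
        # `nn.Linear` from this declaration, even though at runtime the
        # submodule is still stored in `_modules` and looked up through
        # `nn.Module.__getattr__`.
        head: nn.Linear

        def __init__(self, in_features: int, num_classes: int) -> None:
            super().__init__()
            self.head = nn.Linear(in_features, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Attribute access such as `self.head.in_features` type-checks
            # without a suppression because `head` is declared as `nn.Linear`.
            assert x.shape[-1] == self.head.in_features
            return self.head(x)


    if __name__ == "__main__":
        model = TinyClassifier(in_features=16, num_classes=4)
        print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 4])

The suppressions added by this patch arise from the opposite situation: `DistributedDataParallel` stores the wrapped model as a plain submodule, so `module.module` is visible to the type checker only through `__getattr__` and, as the fixme messages show, comes back as `Union[Tensor, Module]`.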