From 2aa360fc40060f33ad3b1571d3fae69343e55f69 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Mon, 13 May 2024 10:31:17 -0700
Subject: [PATCH] add ddp test

ghstack-source-id: 18074e0f320faf1065e1ca101bbb820d9644b7a7
Pull Request resolved: https://github.com/pytorch/PiPPy/pull/1114
---
 test/test_composability.py | 69 ++++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 25 deletions(-)

diff --git a/test/test_composability.py b/test/test_composability.py
index 97f95d582..f2dae40f7 100644
--- a/test/test_composability.py
+++ b/test/test_composability.py
@@ -18,6 +18,7 @@
 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 # torch.testing._internal.common_distributed requies "expecttest"
 from torch.testing._internal.common_distributed import MultiProcessTestCase
@@ -198,10 +199,12 @@ def test_manual_pipeline_with_manual_allreduce(self):
             ddp_pp_model.all_reduce(num_microbatches)
         print(f"{self.rank} finished all_reduce")
 
+    # pytest test/test_composability.py -vsk test_manual_pipeline_with_data_parallel_dp_type_DDP
+    @parametrize("dp_type", ["DDP", "FSDP"])
     @parametrize(
         "schedule_name", ["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"]
     )
-    def test_manual_pipeline_with_fsdp(self, schedule_name):
+    def test_manual_pipeline_with_data_parallel(self, dp_type, schedule_name):
         device_mesh, device = self._init_device_mesh(
             mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
         )
@@ -229,25 +232,32 @@ def build_stage(stage_idx, num_stages):
             )
             partial_model.to(device)
 
-            # apply FSDP
-            mp_policy = MixedPrecisionPolicy(
-                # TODO(whc) need to fix PP + FSDP-mixed-precision
-                # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
-                # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-            )
-            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-            for layer in partial_model.children():
-                fully_shard(
-                    layer,
-                    **fsdp_config,
-                    reshard_after_forward=False,
+            if dp_type == "FSDP":
+                # apply FSDP
+                mp_policy = MixedPrecisionPolicy(
+                    # TODO(whc) need to fix PP + FSDP-mixed-precision
+                    # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+                    # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+                    param_dtype=torch.float32,
+                    reduce_dtype=torch.float32,
                 )
-            fsdp_model = fully_shard(partial_model, **fsdp_config)
+                fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+                for layer in partial_model.children():
+                    fully_shard(
+                        layer,
+                        **fsdp_config,
+                        reshard_after_forward=False,
+                    )
+                data_parallel_model = fully_shard(partial_model, **fsdp_config)
+            elif dp_type == "DDP":
+                data_parallel_model = DDP(
+                    partial_model, process_group=dp_mesh.get_group()
                 )
+            else:
+                raise RuntimeError(f"unsupported dp type {dp_type}")
 
             stage = self._create_manual_pipeline_stage(
-                fsdp_model,
+                data_parallel_model,
                 stage_idx,
                 num_stages,
                 device,
@@ -339,14 +349,23 @@ def build_stage(stage_idx, num_stages):
         # Validate that whichever weights we have locally match that part of our local/full ref model
         # (we force FSDP's grads to be all-gathered (.full_tensor) to make it simpler)
         ref_parameters = dict(ref_model.named_parameters())
-        for partial_model, offset in zip(partial_models, offsets):
-            for name, p in partial_model.named_parameters():
-                parts = name.split(".")
-                parts[0] = str(int(parts[0]) + offset)
-                name = ".".join(parts)
-                ref_p = ref_parameters[name]
-                self.assertTrue(isinstance(p.grad, DTensor))
-                self.assertEqual(ref_p.grad, p.grad.full_tensor())
+        if dp_type == "FSDP":
+            for partial_model, offset in zip(partial_models, offsets):
+                for name, p in partial_model.named_parameters():
+                    parts = name.split(".")
+                    parts[0] = str(int(parts[0]) + offset)
+                    name = ".".join(parts)
+                    ref_p = ref_parameters[name]
+                    self.assertTrue(isinstance(p.grad, DTensor))
+                    self.assertEqual(ref_p.grad, p.grad.full_tensor())
+        elif dp_type == "DDP":
+            for partial_model, offset in zip(partial_models, offsets):
+                for name, p in partial_model.named_parameters():
+                    parts = name.split(".")[1:]  # remove the "module." prefix
+                    parts[0] = str(int(parts[0]) + offset)
+                    name = ".".join(parts)
+                    ref_p = ref_parameters[name]
+                    self.assertEqual(ref_p.grad, p.grad)
 
 
 instantiate_parametrized_tests(TestPipelineComposability)
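
Note (not part of the patch): the core pattern the new dp_type == "DDP" branch exercises is wrapping one pipeline stage's submodule in DistributedDataParallel restricted to the "dp" dimension of a 2D device mesh, so gradient all-reduce stays within the data-parallel ranks while the "pp" dimension is left to the pipeline schedule. The standalone sketch below illustrates that pattern under stated assumptions: it is not the test itself, "ToyStage" and the launch command are placeholders, and it assumes the ranks are started with torchrun (e.g. torchrun --nproc-per-node 4 sketch.py).

# Illustrative sketch only: DDP over the "dp" dimension of a 2x2 DeviceMesh.
# Assumes 4 ranks launched via torchrun; "ToyStage" is a hypothetical stand-in
# for the layers owned by one pipeline stage, not a name from the test file.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyStage(nn.Module):
    """Stand-in for one pipeline stage's layers."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def main() -> None:
    use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= 4
    backend = "nccl" if use_cuda else "gloo"
    dist.init_process_group(backend)  # reads env vars set by torchrun
    rank = dist.get_rank()
    if use_cuda:
        torch.cuda.set_device(rank % torch.cuda.device_count())
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    # 2D mesh: ranks that share a "pp" coordinate form one data-parallel group.
    mesh = init_device_mesh(
        "cuda" if use_cuda else "cpu", (2, 2), mesh_dim_names=("dp", "pp")
    )
    dp_group = mesh["dp"].get_group()  # ProcessGroup of this rank's "dp" peers

    stage_module = ToyStage().to(device)
    # Restrict DDP's gradient all-reduce to the "dp" ranks only; the pipeline
    # ("pp") dimension is coordinated separately by the pipeline schedule.
    ddp_stage = DDP(stage_module, process_group=dp_group)

    out = ddp_stage(torch.randn(8, 16, device=device))
    out.sum().backward()  # grads are averaged across the 2 "dp" ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

This also motivates the DDP branch of the gradient check in the patch: DDP leaves parameter names prefixed with "module." and gradients as plain tensors, so the test strips that prefix and compares p.grad directly instead of calling full_tensor() as in the FSDP branch.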