add ddp test
ghstack-source-id: 18074e0f320faf1065e1ca101bbb820d9644b7a7
Pull Request resolved: #1114
H-Huang committed May 14, 2024
1 parent cdc0ac6 commit 2aa360f
Showing 1 changed file with 44 additions and 25 deletions.
69 changes: 44 additions & 25 deletions test/test_composability.py
@@ -18,6 +18,7 @@
 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.nn.parallel import DistributedDataParallel as DDP

 # torch.testing._internal.common_distributed requies "expecttest"
 from torch.testing._internal.common_distributed import MultiProcessTestCase
@@ -198,10 +199,12 @@ def test_manual_pipeline_with_manual_allreduce(self):
         ddp_pp_model.all_reduce(num_microbatches)
         print(f"{self.rank} finished all_reduce")

+    # pytest test/test_composability.py -vsk test_manual_pipeline_with_data_parallel_dp_type_DDP
+    @parametrize("dp_type", ["DDP", "FSDP"])
     @parametrize(
         "schedule_name", ["gpipe", "1f1b", "looped_bfs", "interleaved_1f1b"]
     )
-    def test_manual_pipeline_with_fsdp(self, schedule_name):
+    def test_manual_pipeline_with_data_parallel(self, dp_type, schedule_name):
         device_mesh, device = self._init_device_mesh(
             mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
         )
@@ -229,25 +232,32 @@ def build_stage(stage_idx, num_stages):
             )
             partial_model.to(device)

-            # apply FSDP
-            mp_policy = MixedPrecisionPolicy(
-                # TODO(whc) need to fix PP + FSDP-mixed-precision
-                # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
-                # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-            )
-            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-            for layer in partial_model.children():
-                fully_shard(
-                    layer,
-                    **fsdp_config,
-                    reshard_after_forward=False,
+            if dp_type == "FSDP":
+                # apply FSDP
+                mp_policy = MixedPrecisionPolicy(
+                    # TODO(whc) need to fix PP + FSDP-mixed-precision
+                    # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+                    # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+                    param_dtype=torch.float32,
+                    reduce_dtype=torch.float32,
+                )
+                fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+                for layer in partial_model.children():
+                    fully_shard(
+                        layer,
+                        **fsdp_config,
+                        reshard_after_forward=False,
+                    )
+                data_parallel_model = fully_shard(partial_model, **fsdp_config)
+            elif dp_type == "DDP":
+                data_parallel_model = DDP(
+                    partial_model, process_group=dp_mesh.get_group()
                 )
-            fsdp_model = fully_shard(partial_model, **fsdp_config)
+            else:
+                raise RuntimeError(f"unsupported dp type {dp_type}")

             stage = self._create_manual_pipeline_stage(
-                fsdp_model,
+                data_parallel_model,
                 stage_idx,
                 num_stages,
                 device,
@@ -339,14 +349,23 @@ def build_stage(stage_idx, num_stages):
         # Validate that whichever weights we have locally match that part of our local/full ref model
         # (we force FSDP's grads to be all-gathered (.full_tensor) to make it simpler)
         ref_parameters = dict(ref_model.named_parameters())
-        for partial_model, offset in zip(partial_models, offsets):
-            for name, p in partial_model.named_parameters():
-                parts = name.split(".")
-                parts[0] = str(int(parts[0]) + offset)
-                name = ".".join(parts)
-                ref_p = ref_parameters[name]
-                self.assertTrue(isinstance(p.grad, DTensor))
-                self.assertEqual(ref_p.grad, p.grad.full_tensor())
+        if dp_type == "FSDP":
+            for partial_model, offset in zip(partial_models, offsets):
+                for name, p in partial_model.named_parameters():
+                    parts = name.split(".")
+                    parts[0] = str(int(parts[0]) + offset)
+                    name = ".".join(parts)
+                    ref_p = ref_parameters[name]
+                    self.assertTrue(isinstance(p.grad, DTensor))
+                    self.assertEqual(ref_p.grad, p.grad.full_tensor())
+        elif dp_type == "DDP":
+            for partial_model, offset in zip(partial_models, offsets):
+                for name, p in partial_model.named_parameters():
+                    parts = name.split(".")[1:]  # remove the "module." prefix
+                    parts[0] = str(int(parts[0]) + offset)
+                    name = ".".join(parts)
+                    ref_p = ref_parameters[name]
+                    self.assertEqual(ref_p.grad, p.grad)


 instantiate_parametrized_tests(TestPipelineComposability)
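A note on the parametrization added above: the stacked @parametrize decorators plus instantiate_parametrized_tests expand the test into one generated case per (dp_type, schedule_name) combination, with the parameter values appended to the test name, which is why the new comment in the diff selects a single variant via pytest -k. Below is a minimal, self-contained sketch of the same mechanism; the TestDemo class and its parameter values are made up for illustration and are not part of this commit.

# Minimal sketch of PyTorch's internal test parametrization helpers; TestDemo
# and its parameter values are hypothetical, not code from this commit.
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class TestDemo(TestCase):
    @parametrize("dp_type", ["DDP", "FSDP"])
    @parametrize("schedule_name", ["gpipe", "1f1b"])
    def test_demo(self, dp_type, schedule_name):
        # Each (dp_type, schedule_name) pair becomes its own generated test,
        # with the values appended to the name (e.g. "...dp_type_DDP..."),
        # so one variant can be selected with `pytest -k`.
        self.assertIn(dp_type, ("DDP", "FSDP"))


instantiate_parametrized_tests(TestDemo)

if __name__ == "__main__":
    run_tests()
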
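The DDP branch in the diff wraps each pipeline-stage submodule so that its gradients are all-reduced only across the "dp" dimension of the 2x2 device mesh. A standalone sketch of that pattern follows, assuming torch.distributed is already initialized with four ranks; the nn.Linear stage and the function name are stand-ins, not code from this commit.

# Illustrative sketch of data-parallel wrapping over a device-mesh sub-group;
# assumes a 4-rank job with one CUDA device per rank, already initialized.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_stage_with_ddp():
    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "pp"))
    dp_mesh = mesh_2d["dp"]  # 1-D sub-mesh of this rank's data-parallel peers
    stage_module = nn.Linear(16, 16).cuda()  # stand-in for one pipeline stage
    # Gradients are synchronized only within the "dp" process group; ranks
    # along the "pp" dimension hold different stages and do not participate.
    return DDP(stage_module, process_group=dp_mesh.get_group())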

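On the validation side, a DDP wrapper reports its parameters under a leading "module." prefix, which is why the new DDP branch strips the first name component before shifting the layer index by the stage's offset and looking the name up in the reference model. A small worked example of that remapping, with made-up values:

# Worked example of the name remapping done in the DDP validation branch;
# the parameter name and offset below are hypothetical.
name = "module.3.net1.weight"   # as reported by the DDP-wrapped stage module
offset = 4                      # index of this stage's first layer in the full model

parts = name.split(".")[1:]     # drop the "module." prefix -> ["3", "net1", "weight"]
parts[0] = str(int(parts[0]) + offset)
ref_name = ".".join(parts)      # -> "7.net1.weight", the key into ref_parameters

assert ref_name == "7.net1.weight"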
