Skip to content

Commit

Permalink
add unittest for dpp gather autograd compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
cw-tan committed Oct 8, 2024
1 parent 8122e9f commit c2b6d19
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,62 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO
assert (val == torch.ones_like(val)).all()


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.
This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test only considers tensors of the same shape across different ranks.
Note that this test only works for torch>=2.0.
"""
tensor = torch.ones(50, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.
This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test considers tensors of different shapes across different ranks.
Note that this test only works for torch>=2.0.
"""
tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None:
dummy = DummyMetricSum()
dummy._reductions = {"x": torch.sum}
Expand Down Expand Up @@ -105,6 +161,76 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.
This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test only considers tensors of the same shape across different ranks.
Note that this test only works for torch>=2.0.
"""
tensor = torch.ones(50, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.
This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test considers tensors of different shapes across different ranks.
Note that this test only works for torch>=2.0.
"""
tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


@pytest.mark.DDP()
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.parametrize(
"process",
[
_test_ddp_gather_autograd_same_shape,
_test_ddp_gather_autograd_different_shape,
],
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True
Expand Down

0 comments on commit c2b6d19

Please sign in to comment.