From 8122e9f3e29efa4e1a4ec623b943740051f430d1 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 17 Sep 2024 13:58:40 -0400 Subject: [PATCH 01/20] propagate rank result to gathered result for autograd compatibility --- src/torchmetrics/utilities/distributed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 455d64c4ae0..4f6eacea866 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -91,6 +91,8 @@ def class_reduce( def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) + # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result @@ -144,4 +146,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens for idx, item_size in enumerate(local_sizes): slice_param = [slice(dim_size) for dim_size in item_size] gathered_result[idx] = gathered_result[idx][slice_param] + # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result From c2b6d19aefc5aa7e0e856b843e8d164a05d6ccb5 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 17 Sep 2024 16:20:01 -0400 Subject: [PATCH 02/20] add unittest for dpp gather autograd compatibility --- tests/unittests/bases/test_ddp.py | 126 ++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index c057d0cbdf8..4219ee52a56 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -77,6 +77,62 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO assert (val == torch.ones_like(val)).all() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test only considers tensors of the same shape across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(50, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test considers tensors of different shapes across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None: dummy = DummyMetricSum() dummy._reductions = {"x": torch.sum} @@ -105,6 +161,76 @@ def test_ddp(process): pytest.pool.map(process, range(NUM_PROCESSES)) +def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test only considers tensors of the same shape across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(50, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test considers tensors of different shapes across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +@pytest.mark.DDP() +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +@pytest.mark.parametrize( + "process", + [ + _test_ddp_gather_autograd_same_shape, + _test_ddp_gather_autograd_different_shape, + ], +) +def test_ddp_autograd(process): + """Test ddp functions for autograd compatibility.""" + pytest.pool.map(process, range(NUM_PROCESSES)) + + def _test_non_contiguous_tensors(rank): class DummyCatMetric(Metric): full_state_update = True From d1e64e4f5b3601581473804d24daf63514bdf19a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 9 Oct 2024 10:06:14 +0200 Subject: [PATCH 03/20] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5700a43a98b..cdca66ce134 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) +- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754)) + + ### Changed - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) From fc366b8d60186dbc5c27cf89766f1ca4bebd5504 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 9 Oct 2024 10:35:56 +0200 Subject: [PATCH 04/20] add to docs --- docs/source/pages/overview.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index 5dabc545e50..bd07b62af9f 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -492,6 +492,10 @@ In practice this means that: A functional metric is differentiable if its corresponding modular metric is differentiable. +For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph +propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as +loss functions in a DDP setting. + *************************************** Metrics and hyperparameter optimization *************************************** From 59c9ced8311b2027b4be8580f79640a5b11cb795 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 08:36:17 +0000 Subject: [PATCH 05/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/pages/overview.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index bd07b62af9f..a5d10377c92 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -492,8 +492,8 @@ In practice this means that: A functional metric is differentiable if its corresponding modular metric is differentiable. -For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph -propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as +For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph +propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as loss functions in a DDP setting. *************************************** From 6f188a85b235a51893aa9ae92edf7b5713d3aade Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 9 Oct 2024 12:20:58 +0200 Subject: [PATCH 06/20] Apply suggestions from code review --- src/torchmetrics/utilities/distributed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 4f6eacea866..84f7e345057 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -92,7 +92,8 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - gathered_result[torch.distributed.get_rank(group)] = result + if _TORCH_GREATER_EQUAL_2_1: + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result @@ -147,5 +148,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens slice_param = [slice(dim_size) for dim_size in item_size] gathered_result[idx] = gathered_result[idx][slice_param] # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - gathered_result[torch.distributed.get_rank(group)] = result + if _TORCH_GREATER_EQUAL_2_1: + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result From ebb4f4cc1f0d156539d8ccb8ed7b8b68cb6b21a9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 9 Oct 2024 12:22:45 +0200 Subject: [PATCH 07/20] add missing import --- src/torchmetrics/utilities/distributed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 84f7e345057..c3a8edd68cc 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -18,6 +18,8 @@ from torch.nn import functional as F # noqa: N812 from typing_extensions import Literal +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 + def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor: """Reduces a given tensor by a given reduction method. From 05b6e96ab4926993f8df8387694cc5f74f0530a1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 9 Oct 2024 12:54:11 +0200 Subject: [PATCH 08/20] remove redundant functions --- tests/unittests/bases/test_ddp.py | 56 ------------------------------- 1 file changed, 56 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 4219ee52a56..09a69ecf4f4 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -77,62 +77,6 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO assert (val == torch.ones_like(val)).all() -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") -def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: - """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. - - This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in - preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained - with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. - This test only considers tensors of the same shape across different ranks. - - Note that this test only works for torch>=2.0. - - """ - tensor = torch.ones(50, requires_grad=True) - result = gather_all_tensors(tensor) - assert len(result) == worldsize - scalar1 = 0 - scalar2 = 0 - for idx in range(worldsize): - if idx == rank: - scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) - else: - scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) - scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) - gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] - gradient2 = torch.autograd.grad(scalar2, [tensor])[0] - assert torch.allclose(gradient1, gradient2) - - -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") -def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: - """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. - - This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in - preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained - with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. - This test considers tensors of different shapes across different ranks. - - Note that this test only works for torch>=2.0. - - """ - tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) - result = gather_all_tensors(tensor) - assert len(result) == worldsize - scalar1 = 0 - scalar2 = 0 - for idx in range(worldsize): - if idx == rank: - scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) - else: - scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) - scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) - gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] - gradient2 = torch.autograd.grad(scalar2, [tensor])[0] - assert torch.allclose(gradient1, gradient2) - - def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None: dummy = DummyMetricSum() dummy._reductions = {"x": torch.sum} From f854bf2dcdfab23bbc6fc23ad9ffd9900cf6fd19 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Thu, 10 Oct 2024 18:21:58 -0400 Subject: [PATCH 09/20] try no_grad for the all gather --- src/torchmetrics/utilities/distributed.py | 28 ++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index c3a8edd68cc..45b736a436c 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -91,8 +91,9 @@ def class_reduce( def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - torch.distributed.all_gather(gathered_result, result, group) + with torch.no_grad(): + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + torch.distributed.all_gather(gathered_result, result, group) # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) if _TORCH_GREATER_EQUAL_2_1: gathered_result[torch.distributed.get_rank(group)] = result @@ -138,17 +139,18 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens return _simple_gather_all_tensors(result, group, world_size) # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate - pad_dims = [] - pad_by = (max_size - local_size).detach().cpu() - for val in reversed(pad_by): - pad_dims.append(0) - pad_dims.append(val.item()) - result_padded = F.pad(result, pad_dims) - gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] - torch.distributed.all_gather(gathered_result, result_padded, group) - for idx, item_size in enumerate(local_sizes): - slice_param = [slice(dim_size) for dim_size in item_size] - gathered_result[idx] = gathered_result[idx][slice_param] + with torch.no_grad(): + pad_dims = [] + pad_by = (max_size - local_size).detach().cpu() + for val in reversed(pad_by): + pad_dims.append(0) + pad_dims.append(val.item()) + result_padded = F.pad(result, pad_dims) + gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] + torch.distributed.all_gather(gathered_result, result_padded, group) + for idx, item_size in enumerate(local_sizes): + slice_param = [slice(dim_size) for dim_size in item_size] + gathered_result[idx] = gathered_result[idx][slice_param] # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) if _TORCH_GREATER_EQUAL_2_1: gathered_result[torch.distributed.get_rank(group)] = result From 25ffff2755fdd7b7a91fb1d716571f4d7f1479c6 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Thu, 10 Oct 2024 23:59:36 -0400 Subject: [PATCH 10/20] retry with all tested torch versions --- src/torchmetrics/utilities/distributed.py | 8 ++++---- tests/unittests/bases/test_ddp.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 45b736a436c..98fffda0f7f 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -95,8 +95,8 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - if _TORCH_GREATER_EQUAL_2_1: - gathered_result[torch.distributed.get_rank(group)] = result + #if _TORCH_GREATER_EQUAL_2_1: + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result @@ -152,6 +152,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens slice_param = [slice(dim_size) for dim_size in item_size] gathered_result[idx] = gathered_result[idx][slice_param] # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - if _TORCH_GREATER_EQUAL_2_1: - gathered_result[torch.distributed.get_rank(group)] = result + #if _TORCH_GREATER_EQUAL_2_1: + gathered_result[torch.distributed.get_rank(group)] = result return gathered_result diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 09a69ecf4f4..e385309249d 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -162,7 +162,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR @pytest.mark.DDP() @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") -@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +#@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.parametrize( "process", [ From e82c70ebb23654d4f718e2ac2119dd48a25237da Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 03:59:59 +0000 Subject: [PATCH 11/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/utilities/distributed.py | 4 ++-- tests/unittests/bases/test_ddp.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 98fffda0f7f..4cb704b7c5e 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -95,7 +95,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - #if _TORCH_GREATER_EQUAL_2_1: + # if _TORCH_GREATER_EQUAL_2_1: gathered_result[torch.distributed.get_rank(group)] = result return gathered_result @@ -152,6 +152,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens slice_param = [slice(dim_size) for dim_size in item_size] gathered_result[idx] = gathered_result[idx][slice_param] # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - #if _TORCH_GREATER_EQUAL_2_1: + # if _TORCH_GREATER_EQUAL_2_1: gathered_result[torch.distributed.get_rank(group)] = result return gathered_result diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index e385309249d..bfca63ee2cf 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -162,7 +162,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR @pytest.mark.DDP() @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") -#@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +# @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.parametrize( "process", [ From b5f285db867220d94a06f2360a911803901c03d3 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Fri, 11 Oct 2024 00:03:05 -0400 Subject: [PATCH 12/20] incorporate trials --- tests/unittests/bases/test_ddp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index bfca63ee2cf..98cfdcc490f 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -170,7 +170,11 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR _test_ddp_gather_autograd_different_shape, ], ) -def test_ddp_autograd(process): +@pytest.mark.parametrize( + "index", + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], +) +def test_ddp_autograd(process, index): """Test ddp functions for autograd compatibility.""" pytest.pool.map(process, range(NUM_PROCESSES)) From 91cff5ef26149cd13984d3731e3f0c124a9cfe9b Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 14 Oct 2024 22:23:08 +0200 Subject: [PATCH 13/20] lint --- src/torchmetrics/utilities/distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index 4cb704b7c5e..cbd317eb2f4 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -18,8 +18,6 @@ from torch.nn import functional as F # noqa: N812 from typing_extensions import Literal -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor: """Reduces a given tensor by a given reduction method. From 4c13d6c7d424727721cf673fb72dc03a2dd86668 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Tue, 15 Oct 2024 09:32:42 -0400 Subject: [PATCH 14/20] try adding contiguous --- tests/unittests/bases/test_ddp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 98cfdcc490f..bea3f1f8f33 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -123,7 +123,8 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS scalar2 = 0 for idx in range(worldsize): if idx == rank: - scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + tensor_contig = tensor.contiguous() + scalar1 = scalar1 + torch.sum(tensor_contig * torch.ones_like(tensor_contig)) else: scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) @@ -150,7 +151,8 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR scalar2 = 0 for idx in range(worldsize): if idx == rank: - scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + tensor_contig = tensor.contiguous() + scalar1 = scalar1 + torch.sum(tensor_contig * torch.ones_like(tensor_contig)) else: scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) From 150251cf831e3abd81c38380df9a859b72d75c1d Mon Sep 17 00:00:00 2001 From: cw-tan Date: Fri, 18 Oct 2024 11:39:58 -0400 Subject: [PATCH 15/20] try using float64 --- tests/unittests/bases/test_ddp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index bea3f1f8f33..a02db0b614c 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -116,7 +116,7 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS Note that this test only works for torch>=2.0. """ - tensor = torch.ones(50, requires_grad=True) + tensor = torch.ones(50, dtype=torch.float64, requires_grad=True) result = gather_all_tensors(tensor) assert len(result) == worldsize scalar1 = 0 @@ -144,7 +144,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR Note that this test only works for torch>=2.0. """ - tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + tensor = torch.ones(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True) result = gather_all_tensors(tensor) assert len(result) == worldsize scalar1 = 0 From 9b17d6f7a749806266ea7b7a4f56fa07be0b07d7 Mon Sep 17 00:00:00 2001 From: cw-tan Date: Fri, 18 Oct 2024 20:06:32 -0400 Subject: [PATCH 16/20] try using random numbers --- tests/unittests/bases/test_ddp.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index a02db0b614c..fb44d6ee353 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -116,18 +116,19 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS Note that this test only works for torch>=2.0. """ - tensor = torch.ones(50, dtype=torch.float64, requires_grad=True) + tensor = torch.randn(50, dtype=torch.float64, requires_grad=True) result = gather_all_tensors(tensor) assert len(result) == worldsize scalar1 = 0 scalar2 = 0 for idx in range(worldsize): + W = torch.randn_like(result[idx], requires_grad=False) if idx == rank: - tensor_contig = tensor.contiguous() - scalar1 = scalar1 + torch.sum(tensor_contig * torch.ones_like(tensor_contig)) + assert torch.allclose(result[idx], tensor) + scalar1 = scalar1 + torch.sum(tensor * W) else: - scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) - scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar1 = scalar1 + torch.sum(result[idx] * W) + scalar2 = scalar2 + torch.sum(result[idx] * W) gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] gradient2 = torch.autograd.grad(scalar2, [tensor])[0] assert torch.allclose(gradient1, gradient2) @@ -144,18 +145,19 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR Note that this test only works for torch>=2.0. """ - tensor = torch.ones(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True) + tensor = torch.randn(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True) result = gather_all_tensors(tensor) assert len(result) == worldsize scalar1 = 0 scalar2 = 0 for idx in range(worldsize): + W = torch.randn_like(result[idx], requires_grad=False) if idx == rank: - tensor_contig = tensor.contiguous() - scalar1 = scalar1 + torch.sum(tensor_contig * torch.ones_like(tensor_contig)) + assert torch.allclose(result[idx], tensor) + scalar1 = scalar1 + torch.sum(tensor * W) else: - scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) - scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar1 = scalar1 + torch.sum(result[idx] * W) + scalar2 = scalar2 + torch.sum(result[idx] * W) gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] gradient2 = torch.autograd.grad(scalar2, [tensor])[0] assert torch.allclose(gradient1, gradient2) From 8b263aeaa4fb896ccbc2a734da8f3f480f675504 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 11:26:33 +0100 Subject: [PATCH 17/20] fix changelog --- CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89b08df13a0..e7adf17141f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725)) +- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754)) + + ### Changed - Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800)) @@ -70,8 +73,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `input_format` argument to segmentation metrics ([#2572](https://github.com/Lightning-AI/torchmetrics/pull/2572)) - Added `multi-output` support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) - Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776)) -- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122)) -- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754)) ### Changed From 8d2c27e8c4400bd164047825d6bc1498fe5c5a49 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 15:17:04 +0100 Subject: [PATCH 18/20] small changes to distributed --- src/torchmetrics/utilities/distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index cbd317eb2f4..90239b46af0 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -92,8 +92,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L with torch.no_grad(): gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) - # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - # if _TORCH_GREATER_EQUAL_2_1: + # to propagate autograd graph from local rank gathered_result[torch.distributed.get_rank(group)] = result return gathered_result @@ -149,7 +148,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens for idx, item_size in enumerate(local_sizes): slice_param = [slice(dim_size) for dim_size in item_size] gathered_result[idx] = gathered_result[idx][slice_param] - # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0) - # if _TORCH_GREATER_EQUAL_2_1: + # to propagate autograd graph from local rank gathered_result[torch.distributed.get_rank(group)] = result return gathered_result From 48e699b6dfaa7a5fe46d8a984722f155c510c4d5 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 15:17:51 +0100 Subject: [PATCH 19/20] tests --- tests/unittests/bases/test_ddp.py | 90 ++++++++++--------------------- tests/unittests/conftest.py | 15 +++++- 2 files changed, 41 insertions(+), 64 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index fb44d6ee353..07dee96f4da 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -27,6 +27,7 @@ from unittests import NUM_PROCESSES, USE_PYTEST_POOL from unittests._helpers import seed_all from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum +from unittests.conftest import setup_ddp seed_all(42) @@ -105,80 +106,43 @@ def test_ddp(process): pytest.pool.map(process, range(NUM_PROCESSES)) -def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: - """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. +def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.""" + setup_ddp(rank, worldsize) + x = (rank + 1) * torch.ones(10, requires_grad=True) - This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in - preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained - with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. - This test only considers tensors of the same shape across different ranks. + # random linear transformation, it should really not matter what we do here + a, b = torch.randn(1), torch.randn(1) + y = a * x + b # gradient of y w.r.t. x is a - Note that this test only works for torch>=2.0. - - """ - tensor = torch.randn(50, dtype=torch.float64, requires_grad=True) - result = gather_all_tensors(tensor) + result = gather_all_tensors(y) assert len(result) == worldsize - scalar1 = 0 - scalar2 = 0 - for idx in range(worldsize): - W = torch.randn_like(result[idx], requires_grad=False) - if idx == rank: - assert torch.allclose(result[idx], tensor) - scalar1 = scalar1 + torch.sum(tensor * W) - else: - scalar1 = scalar1 + torch.sum(result[idx] * W) - scalar2 = scalar2 + torch.sum(result[idx] * W) - gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] - gradient2 = torch.autograd.grad(scalar2, [tensor])[0] - assert torch.allclose(gradient1, gradient2) - - -def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: - """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. - - This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in - preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained - with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. - This test considers tensors of different shapes across different ranks. - - Note that this test only works for torch>=2.0. - - """ - tensor = torch.randn(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True) - result = gather_all_tensors(tensor) + grad = torch.autograd.grad(result[rank].sum(), x)[0] + assert torch.allclose(grad, a * torch.ones_like(x)) + + +def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.""" + setup_ddp(rank, worldsize) + x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True) + + # random linear transformation, it should really not matter what we do here + a, b = torch.randn(1), torch.randn(1) + y = a * x + b # gradient of y w.r.t. x is a + + result = gather_all_tensors(y) assert len(result) == worldsize - scalar1 = 0 - scalar2 = 0 - for idx in range(worldsize): - W = torch.randn_like(result[idx], requires_grad=False) - if idx == rank: - assert torch.allclose(result[idx], tensor) - scalar1 = scalar1 + torch.sum(tensor * W) - else: - scalar1 = scalar1 + torch.sum(result[idx] * W) - scalar2 = scalar2 + torch.sum(result[idx] * W) - gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] - gradient2 = torch.autograd.grad(scalar2, [tensor])[0] - assert torch.allclose(gradient1, gradient2) + grad = torch.autograd.grad(result[rank].sum(), x)[0] + assert torch.allclose(grad, a * torch.ones_like(x)) @pytest.mark.DDP() @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") -# @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") -@pytest.mark.parametrize( - "process", - [ - _test_ddp_gather_autograd_same_shape, - _test_ddp_gather_autograd_different_shape, - ], -) @pytest.mark.parametrize( - "index", - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape] ) -def test_ddp_autograd(process, index): +def test_ddp_autograd(process): """Test ddp functions for autograd compatibility.""" pytest.pool.map(process, range(NUM_PROCESSES)) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index f09f884adeb..58967ba2521 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -45,7 +45,17 @@ def use_deterministic_algorithms(): # noqa: PT004 def setup_ddp(rank, world_size): - """Initialize ddp environment.""" + """Initialize ddp environment. + + If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function + should be called inside the test to ensure that the processes are initialized in the same order they are used in + the tests. + + Args: + rank: the rank of the process + world_size: the number of processes + + """ global CURRENT_PORT os.environ["MASTER_ADDR"] = "localhost" @@ -55,6 +65,9 @@ def setup_ddp(rank, world_size): if CURRENT_PORT > MAX_PORT: CURRENT_PORT = START_PORT + if torch.distributed.group.WORLD is not None: # if already initialized, destroy the process group + torch.distributed.destroy_process_group() + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 5f29c4d1164bf4fa21ab4a88ba4f327baf4d72ef Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:29:46 +0100 Subject: [PATCH 20/20] caution --- docs/source/pages/overview.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index a5d10377c92..34d0dcbd6fc 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -492,9 +492,10 @@ In practice this means that: A functional metric is differentiable if its corresponding modular metric is differentiable. -For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph -propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as -loss functions in a DDP setting. +.. caution:: + For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph + propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as + loss functions in a DDP setting. *************************************** Metrics and hyperparameter optimization