From 5b2df3f49cffc55b3d65bd82171109237936dd6d Mon Sep 17 00:00:00 2001
From: "Iris Zhang (PyTorch)"
Date: Sat, 9 Nov 2024 23:39:37 -0800
Subject: [PATCH] Remove composable API's fully_shard from torchtnt example and test

Differential Revision: D65702749
---
 tests/utils/test_prepare_module_gpu.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/utils/test_prepare_module_gpu.py b/tests/utils/test_prepare_module_gpu.py
index d583a7085d..8c1e436c37 100644
--- a/tests/utils/test_prepare_module_gpu.py
+++ b/tests/utils/test_prepare_module_gpu.py
@@ -9,8 +9,6 @@
 import unittest
 
 import torch
-
-from torch.distributed._composable import fully_shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -93,7 +91,7 @@ def _test_is_fsdp_module() -> None:
         model = FSDP(torch.nn.Linear(1, 1, device=device))
         assert _is_fsdp_module(model)
         model = torch.nn.Linear(1, 1, device=device)
-        fully_shard(model)
+        model = FSDP(model)
         assert _is_fsdp_module(model)
 
     @skip_if_not_distributed
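
Note: after this change the test exercises only class-based FSDP wrapping. The snippet below is a minimal standalone sketch of that pattern, not part of the patch; the single-process gloo setup, the CPU device, and the isinstance check standing in for torchtnt's _is_fsdp_module helper are all illustrative assumptions.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Illustrative single-process setup; the real test spawns multiple ranks
# and runs on GPU with the "nccl" backend.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

device = torch.device("cpu")
model = torch.nn.Linear(1, 1, device=device)
# Wrapping with the FSDP class replaces the removed composable fully_shard(model) call.
model = FSDP(model)
# Stand-in for the _is_fsdp_module check used in the test.
assert isinstance(model, FSDP)

dist.destroy_process_group()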