From 2ee08aaae07ae0506ca356151868e3fd7582f270 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 19 Mar 2024 03:40:06 -0400 Subject: [PATCH] Disable FSDP use original parameters --- tests/gpu_tests/fsdp_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/gpu_tests/fsdp_test.py b/tests/gpu_tests/fsdp_test.py index 9f25e47..e4188c3 100644 --- a/tests/gpu_tests/fsdp_test.py +++ b/tests/gpu_tests/fsdp_test.py @@ -55,7 +55,8 @@ def setUpClass(cls) -> None: cls.model = cls.model.to(device=device) cls.model = DistributedDataParallel(cls.model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100) - cls.model = FSDP(cls.model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) + cls.model = FSDP(cls.model, use_orig_params=False, auto_wrap_policy=my_auto_wrap_policy) + print(cls.model) cls.analyzer = Analyzer( analysis_name="gpu_test",