From 8d1637f0d16908cc08e62c92ae5a92ca4ab59552 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 19 Nov 2024 11:14:52 +0100 Subject: [PATCH] fix distributed tests with pt main (#1452) --- thunder/tests/distributed/helper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py index 03c711daeb..55a3e3cc51 100644 --- a/thunder/tests/distributed/helper.py +++ b/thunder/tests/distributed/helper.py @@ -2,6 +2,7 @@ from functools import partial from functools import wraps from typing import ClassVar, TYPE_CHECKING +import inspect import math import os import sys @@ -129,10 +130,14 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False): local_rank = self.rank % torch.cuda.device_count() torch.cuda.set_device(local_rank) os.environ["LOCAL_RANK"] = str(local_rank) + if "destroy_process_group" in inspect.signature(self.run_test).parameters: + run_test_kwargs = {"destroy_process_group": False} + else: + run_test_kwargs = {} torch.distributed.barrier() try: - self.run_test(test_name, pipe) + self.run_test(test_name, pipe, **run_test_kwargs) except Exception: raise finally: