diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py index 03c711daeb..55a3e3cc51 100644 --- a/thunder/tests/distributed/helper.py +++ b/thunder/tests/distributed/helper.py @@ -2,6 +2,7 @@ from functools import partial from functools import wraps from typing import ClassVar, TYPE_CHECKING +import inspect import math import os import sys @@ -129,10 +130,14 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False): local_rank = self.rank % torch.cuda.device_count() torch.cuda.set_device(local_rank) os.environ["LOCAL_RANK"] = str(local_rank) + if "destroy_process_group" in inspect.signature(self.run_test).parameters: + run_test_kwargs = {"destroy_process_group": False} + else: + run_test_kwargs = {} torch.distributed.barrier() try: - self.run_test(test_name, pipe) + self.run_test(test_name, pipe, **run_test_kwargs) except Exception: raise finally: