diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py index 55a3e3cc51..03c711daeb 100644 --- a/thunder/tests/distributed/helper.py +++ b/thunder/tests/distributed/helper.py @@ -2,7 +2,6 @@ from functools import partial from functools import wraps from typing import ClassVar, TYPE_CHECKING -import inspect import math import os import sys @@ -130,14 +129,10 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False): local_rank = self.rank % torch.cuda.device_count() torch.cuda.set_device(local_rank) os.environ["LOCAL_RANK"] = str(local_rank) - if "destroy_process_group" in inspect.signature(self.run_test).parameters: - run_test_kwargs = {"destroy_process_group": False} - else: - run_test_kwargs = {} torch.distributed.barrier() try: - self.run_test(test_name, pipe, **run_test_kwargs) + self.run_test(test_name, pipe) except Exception: raise finally: