diff --git a/benchmarks/accelerate_opt/main.py b/benchmarks/accelerate_opt/main.py index 9c003eda1..244050b4c 100644 --- a/benchmarks/accelerate_opt/main.py +++ b/benchmarks/accelerate_opt/main.py @@ -126,16 +126,6 @@ class CustomInitProcessGroupKwargs(InitProcessGroupKwargs): world_size=int(os.environ["WORLD_SIZE"]), ) - # Accelerator SUCK, it is impossible to make it use hccl - # We can bypass Accelerator logic by initializing the group ourselves - if acc.device_type == "hpu": - acc.init_process_group( - init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}", - timeout=timedelta(seconds=60), - rank=int(os.environ["RANK"]), - world_size=int(os.environ["WORLD_SIZE"]), - ) - accelerator = Accelerator(kwargs_handlers=[init_process_group_kwargs]) else: accelerator = Accelerator() diff --git a/milabench/_version.py b/milabench/_version.py index 2f6bd5d42..ea26a5bb4 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.1.0-32-ge9e52501" -__commit__ = "e9e52501ad92d2ee2dac97e66f601a0458404986" -__date__ = "2024-06-26 02:37:50 -0400" +__tag__ = "v0.1.0-16-g2c27dfa" +__commit__ = "2c27dfafaab94dacf92cc7bbcf610d8b4886692c" +__date__ = "2024-07-02 20:05:12 +0000" diff --git a/milabench/commands/executors.py b/milabench/commands/executors.py index 11d5ddb72..c43f17bc4 100644 --- a/milabench/commands/executors.py +++ b/milabench/commands/executors.py @@ -71,14 +71,16 @@ async def execute_command( coro.append(fut) warden.extend(pack.processes) - if timeout: - delay = pack.config.get("max_duration", timeout_delay) - timeout_task = asyncio.create_task(force_terminate(pack, delay)) - timeout_tasks.append(timeout_task) - - results = await asyncio.gather(*coro) - if timeout: - for task in timeout_tasks: - task.cancel() - return results + delay = pack.config.get("max_duration", timeout_delay) + + try: + async with asyncio.timeout(delay): + return await asyncio.gather(*coro) + + except TimeoutError: + await force_terminate(pack, delay) + return [-1 for _ in coro] + + return await asyncio.gather(*coro) +