From acec2f86940fb3974eca54e28329e20a16a6c08b Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Fri, 28 Jun 2024 13:40:52 -0400 Subject: [PATCH] Enable GPU warden --- benchmate/benchmate/warden.py | 108 +++++++++++++++++++++----------- milabench/commands/executors.py | 2 +- 2 files changed, 74 insertions(+), 36 deletions(-) diff --git a/benchmate/benchmate/warden.py b/benchmate/benchmate/warden.py index 5b551464a..2a5292b88 100644 --- a/benchmate/benchmate/warden.py +++ b/benchmate/benchmate/warden.py @@ -6,7 +6,7 @@ import time import traceback import warnings -from concurrent.futures import ThreadPoolExecutor +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -60,8 +60,9 @@ def _cuda_parse_processes(): info = [] for line in output.split("\n"): - frags = line.splti(",") - info.append(ProcessInfo(**dict(*zip(metrics, frags)))) + if line: + frags = line.split(",") + info.append(ProcessInfo(**dict(zip(metrics, frags)))) return info @@ -83,7 +84,7 @@ class GPUProcessWarden: def __init__(self, kill_on_start=True, kill_on_end=True): self.gpus = get_gpu_info() - self.arch = self.gpus["arch"] + self.arch = self.gpus.get("arch", "cpu") self.fetch_fun = backends.get(self.arch, _default) self.kill_on_start = kill_on_start self.kill_on_end = kill_on_end @@ -91,13 +92,13 @@ def __init__(self, kill_on_start=True, kill_on_end=True): def __enter__(self): if self.kill_on_start: - self.ensure_free() + self.terminate(True) return self def __exit__(self, *args): if self.kill_on_end: - self.ensure_free() + self.kill(False) return None @@ -108,27 +109,43 @@ def fetch_processes(self): traceback.print_exc() return [] - def kill(self, pid, signal): + def _kill(self, pid, signal): if pid in self.dead_processes: return try: - os.kill(pid, signal) + os.kill(int(pid), signal) + except PermissionError: + syslog("PermissionError: Could not kill process {0}", pid) + self.dead_processes.append(pid) except ProcessLookupError: self.dead_processes.append(pid) - def ensure_free(self): - processes = self.fetch_processes() - if len(processes) == 0: - return + def terminate(self, start=False): + self.ensure_free(signal.SIGTERM, start) - syslog("Found {0} still using devices after bench ended", len(processes)) + def kill(self, start=True): + self.ensure_free(signal.SIGKILL, start) - for process in processes: - self.kill(process.pid, signal.SIGTERM) - - for process in processes: - self.kill(process.pid, signal.SIGKILL) + def ensure_free(self, signal, start): + try: + processes = self.fetch_processes() + if len(processes) == 0: + return + + if start: + syslog("Found {0} still using devices before bench started", len(processes)) + else: + syslog("Found {0} still using devices after bench ended", len(processes)) + + # Those processes might not be known by milabench + # Depending on whose the parent the reaping might be happening later + # and we cannot wait for the process to die + # we could try something like os.waitpid but it might interfere with `SignalProtected` + for process in processes: + self._kill(process.pid, signal.SIGTERM) + except: + traceback.print_exc() class Protected: @@ -210,6 +227,8 @@ def __exit__(self, *args): def destroy(*processes, step=1, timeout=30): + processes = list(processes) + def kill(proc, signal): try: if getattr(proc, "did_setsid", False): @@ -224,7 +243,6 @@ def kill(proc, signal): # Wait a total amout of time, not per process elapsed = 0 - def wait(proc): nonlocal elapsed @@ -234,28 +252,48 @@ def wait(proc): return ret is None - for proc in processes: - if wait(proc): - kill(proc, signal.SIGKILL) - - -class SignalProtected(Protected): - """Delay event handling until all the processes are killed""" + k = 0 + start = - time.time() + stats = defaultdict(int) + while processes and (start + time.time()) < timeout: + finished = [] - def __init__(self): - super().__init__() - self.processes = [] + for i, proc in enumerate(processes): + if wait(proc): + kill(proc, signal.SIGKILL) + stats["alive"] += 1 + else: + stats["dead"] += 1 + finished.append(i) - def add_process(self, *processes): - self.processes.extend(processes) + for i in reversed(finished): + del processes[i] - def stop(self): - destroy(*self.processes) + k += 1 + + syslog("{0} processes needed to be killed, retried {1} times, waited {2} s", stats["alive"], k, elapsed) @contextmanager def process_cleaner(): """Delay signal handling until all the processes have been killed""" - with SignalProtected() as warden: - yield warden + with Protected(): + with GPUProcessWarden() as warden: # => SIGTERM all processes using GPUs + processes = [] + try: # NOTE: we have not waited much between both signals + + warden.kill() # => SIGKILL all processes using GPUs + + yield processes # => Run milabench, spawning processes for the benches + + finally: + warden.terminate() # => SIGTERM all processes using GPUs + + destroy(*processes) # => SIGTERM+SIGKILL milabench processes + + # destroy waited 30s + + # warden.__exit__ # => SIGKILL all processes still using GPUs + + diff --git a/milabench/commands/executors.py b/milabench/commands/executors.py index ce043b3b3..11d5ddb72 100644 --- a/milabench/commands/executors.py +++ b/milabench/commands/executors.py @@ -69,7 +69,7 @@ async def execute_command( fut = execute(pack, *argv, **{**_kwargs, **kwargs}) coro.append(fut) - warden.add_process(*pack.processes) + warden.extend(pack.processes) if timeout: delay = pack.config.get("max_duration", timeout_delay)