Skip to content

Commit

Permalink
Enable GPU warden
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 28, 2024
1 parent 5fc051b commit acec2f8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 36 deletions.
108 changes: 73 additions & 35 deletions benchmate/benchmate/warden.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass

Expand Down Expand Up @@ -60,8 +60,9 @@ def _cuda_parse_processes():

info = []
for line in output.split("\n"):
frags = line.splti(",")
info.append(ProcessInfo(**dict(*zip(metrics, frags))))
if line:
frags = line.split(",")
info.append(ProcessInfo(**dict(zip(metrics, frags))))
return info


Expand All @@ -83,21 +84,21 @@ class GPUProcessWarden:

def __init__(self, kill_on_start=True, kill_on_end=True):
self.gpus = get_gpu_info()
self.arch = self.gpus["arch"]
self.arch = self.gpus.get("arch", "cpu")
self.fetch_fun = backends.get(self.arch, _default)
self.kill_on_start = kill_on_start
self.kill_on_end = kill_on_end
self.dead_processes = []

def __enter__(self):
if self.kill_on_start:
self.ensure_free()
self.terminate(True)

return self

def __exit__(self, *args):
if self.kill_on_end:
self.ensure_free()
self.kill(False)

return None

Expand All @@ -108,27 +109,43 @@ def fetch_processes(self):
traceback.print_exc()
return []

def kill(self, pid, signal):
def _kill(self, pid, signal):
if pid in self.dead_processes:
return

try:
os.kill(pid, signal)
os.kill(int(pid), signal)
except PermissionError:
syslog("PermissionError: Could not kill process {0}", pid)
self.dead_processes.append(pid)
except ProcessLookupError:
self.dead_processes.append(pid)

def ensure_free(self):
processes = self.fetch_processes()
if len(processes) == 0:
return
def terminate(self, start=False):
self.ensure_free(signal.SIGTERM, start)

syslog("Found {0} still using devices after bench ended", len(processes))
def kill(self, start=True):
self.ensure_free(signal.SIGKILL, start)

for process in processes:
self.kill(process.pid, signal.SIGTERM)

for process in processes:
self.kill(process.pid, signal.SIGKILL)
def ensure_free(self, signal, start):
try:
processes = self.fetch_processes()
if len(processes) == 0:
return

if start:
syslog("Found {0} still using devices before bench started", len(processes))
else:
syslog("Found {0} still using devices after bench ended", len(processes))

# Those processes might not be known by milabench
# Depending on whose the parent the reaping might be happening later
# and we cannot wait for the process to die
# we could try something like os.waitpid but it might interfere with `SignalProtected`
for process in processes:
self._kill(process.pid, signal.SIGTERM)
except:
traceback.print_exc()


class Protected:
Expand Down Expand Up @@ -210,6 +227,8 @@ def __exit__(self, *args):


def destroy(*processes, step=1, timeout=30):
processes = list(processes)

def kill(proc, signal):
try:
if getattr(proc, "did_setsid", False):
Expand All @@ -224,7 +243,6 @@ def kill(proc, signal):

# Wait a total amout of time, not per process
elapsed = 0

def wait(proc):
nonlocal elapsed

Expand All @@ -234,28 +252,48 @@ def wait(proc):

return ret is None

for proc in processes:
if wait(proc):
kill(proc, signal.SIGKILL)


class SignalProtected(Protected):
"""Delay event handling until all the processes are killed"""
k = 0
start = - time.time()
stats = defaultdict(int)
while processes and (start + time.time()) < timeout:
finished = []

def __init__(self):
super().__init__()
self.processes = []
for i, proc in enumerate(processes):
if wait(proc):
kill(proc, signal.SIGKILL)
stats["alive"] += 1
else:
stats["dead"] += 1
finished.append(i)

def add_process(self, *processes):
self.processes.extend(processes)
for i in reversed(finished):
del processes[i]

def stop(self):
destroy(*self.processes)
k += 1

syslog("{0} processes needed to be killed, retried {1} times, waited {2} s", stats["alive"], k, elapsed)


@contextmanager
def process_cleaner():
"""Delay signal handling until all the processes have been killed"""

with SignalProtected() as warden:
yield warden
with Protected():
with GPUProcessWarden() as warden: # => SIGTERM all processes using GPUs
processes = []
try: # NOTE: we have not waited much between both signals

warden.kill() # => SIGKILL all processes using GPUs

yield processes # => Run milabench, spawning processes for the benches

finally:
warden.terminate() # => SIGTERM all processes using GPUs

destroy(*processes) # => SIGTERM+SIGKILL milabench processes

# destroy waited 30s

# warden.__exit__ # => SIGKILL all processes still using GPUs


2 changes: 1 addition & 1 deletion milabench/commands/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def execute_command(

fut = execute(pack, *argv, **{**_kwargs, **kwargs})
coro.append(fut)
warden.add_process(*pack.processes)
warden.extend(pack.processes)

if timeout:
delay = pack.config.get("max_duration", timeout_delay)
Expand Down

0 comments on commit acec2f8

Please sign in to comment.