Skip to content

Commit

Permalink
Enable GPU warden
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 28, 2024
1 parent 5fc051b commit b9ad2a9
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions benchmate/benchmate/warden.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass

Expand Down Expand Up @@ -60,8 +60,8 @@ def _cuda_parse_processes():

info = []
for line in output.split("\n"):
frags = line.splti(",")
info.append(ProcessInfo(**dict(*zip(metrics, frags))))
frags = line.split(",")
info.append(ProcessInfo(**dict(zip(metrics, frags))))
return info


Expand All @@ -83,21 +83,21 @@ class GPUProcessWarden:

def __init__(self, kill_on_start=True, kill_on_end=True):
self.gpus = get_gpu_info()
self.arch = self.gpus["arch"]
self.arch = self.gpus.get("arch", "cpu")
self.fetch_fun = backends.get(self.arch, _default)
self.kill_on_start = kill_on_start
self.kill_on_end = kill_on_end
self.dead_processes = []

def __enter__(self):
if self.kill_on_start:
self.ensure_free()
self.ensure_free(True)

return self

def __exit__(self, *args):
if self.kill_on_end:
self.ensure_free()
self.ensure_free(False)

return None

Expand All @@ -113,22 +113,31 @@ def kill(self, pid, signal):
return

try:
os.kill(pid, signal)
os.kill(int(pid), signal)
except PermissionError:
syslog("PermissionError: Could not kill process {0}", pid)
self.dead_processes.append(pid)
except ProcessLookupError:
self.dead_processes.append(pid)

def ensure_free(self):
processes = self.fetch_processes()
if len(processes) == 0:
return
def ensure_free(self, start):
try:
processes = self.fetch_processes()
if len(processes) == 0:
return

syslog("Found {0} still using devices after bench ended", len(processes))
if start:
syslog("Found {0} still using devices before bench started", len(processes))
else:
syslog("Found {0} still using devices after bench ended", len(processes))

for process in processes:
self.kill(process.pid, signal.SIGTERM)

for process in processes:
self.kill(process.pid, signal.SIGKILL)
for process in processes:
self.kill(process.pid, signal.SIGTERM)

for process in processes:
self.kill(process.pid, signal.SIGKILL)
except:
traceback.print_exc()


class Protected:
Expand Down Expand Up @@ -224,7 +233,6 @@ def kill(proc, signal):

# Wait a total amout of time, not per process
elapsed = 0

def wait(proc):
nonlocal elapsed

Expand All @@ -234,9 +242,26 @@ def wait(proc):

return ret is None

for proc in processes:
if wait(proc):
kill(proc, signal.SIGKILL)
k = 0
start = - time.time()
stats = defaultdict(int)
while processes and (start + time.time()) < timeout:
finished = []

for i, proc in enumerate(processes):
if wait(proc):
kill(proc, signal.SIGKILL)
stats["alive"] += 1
else:
stats["dead"] += 1
finished.append(i)

for i in reversed(finished):
del processes[i]

k += 1

syslog("{0} processes needed to be killed, retried {1} times, waited {2} s", stats["alive"], k, elapsed)


class SignalProtected(Protected):
Expand All @@ -258,4 +283,7 @@ def process_cleaner():
"""Delay signal handling until all the processes have been killed"""

with SignalProtected() as warden:
yield warden
# GPU warden will check PID using GPUs and kill them
# NOTE: the GPU warden is not aware of milabench and kill ANYTHING that use the GPU
with GPUProcessWarden():
yield warden

0 comments on commit b9ad2a9

Please sign in to comment.