Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 26, 2024
1 parent 94b27a7 commit 78bfdfc
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 183 deletions.
9 changes: 1 addition & 8 deletions benchmarks/accelerate_opt/benchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@
)
from milabench.pack import Package
from milabench.utils import select_nodes
from milabench.sizer import resolve_argv


def resolve_placeholder(pack, name):
    """Look up the argv placeholder ``name`` in the pack's config and resolve it."""
    return resolve_argv(pack, [pack.config["argv"][name]])


class AccelerateBenchmark(Package):
    """Benchmark package driving the accelerate_opt bench."""

    base_requirements = "requirements.in"

    def make_env(self):
        """Build the bench environment.

        Extends the inherited environment with ``OMP_NUM_THREADS`` taken
        from the resolved ``--cpus_per_gpu`` argv placeholder.
        """
        env = super().make_env()
        value = self.resolve_argument("--cpus_per_gpu")
        env["OMP_NUM_THREADS"] = str(value)
        return env

Expand Down
5 changes: 1 addition & 4 deletions benchmate/benchmate/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,7 @@ def pytorch(folder, batch_size, num_workers, distributed=False, epochs=60):
def synthetic(model, batch_size, fixed_batch):
    """Build a SyntheticData loader of 1000 random classification batches.

    NOTE(review): the image size is hard-coded to (3, 244, 244) — this is
    presumably meant to be 224x224; confirm with the model's expected input.
    """
    return SyntheticData(
        tensors=generate_tensor_classification(
            model, batch_size, (3, 244, 244), device=accelerator.fetch_device(0)
        ),
        n=1000,
        fixed_batch=fixed_batch,
    )
Expand Down
4 changes: 2 additions & 2 deletions benchmate/benchmate/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def monitor_fn():
}
for gpu in get_gpu_info()["gpus"].values()
}
mblog({"task": "main", "gpudata": data})
mblog({"task": "train", "gpudata": data})

monitor_fn()
monitor = Monitor(3, monitor_fn)
Expand Down Expand Up @@ -74,7 +74,7 @@ def monitor_fn():
}
for gpu in get_gpu_info()["gpus"].values()
}
return {"task": "main", "gpudata": data, "time": time.time(), "units": "s"}
return {"task": "train", "gpudata": data, "time": time.time(), "units": "s"}

monitor = CustomMonitor(0.5, monitor_fn)

Expand Down
8 changes: 7 additions & 1 deletion benchmate/benchmate/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ class BenchObserver:
"""

def __init__(
self, *args, backward_callback=None, step_callback=None, stdout=False, rank=None, **kwargs
self,
*args,
backward_callback=None,
step_callback=None,
stdout=False,
rank=None,
**kwargs,
):
self.wrapped = None
self.args = args
Expand Down
210 changes: 184 additions & 26 deletions benchmate/benchmate/warden.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,89 @@
from dataclasses import dataclass
import re
import logging
import os
import re
import signal
import subprocess
import time
import traceback
import signal
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass

from voir.instruments.gpu import get_gpu_info
from milabench.syslog import syslog
from voir.instruments.gpu import get_gpu_info

log = logging.getLogger(__name__)


@dataclass
class ProcessInfo:
    """A process currently holding an accelerator, as reported by a vendor SMI tool.

    Only ``gpu`` and ``pid`` are mandatory; every other field defaults to
    ``None`` because each backend (hl-smi, nvidia-smi, ...) reports a
    different subset of columns.  Values are kept as parsed (strings from
    the tools' output) — TODO confirm whether callers expect ints here.
    """

    gpu: int
    pid: int
    gpu_name: int = None
    type: str = None
    process_name: str = None
    memory: int = None
    used_memory: str = None
    gpu_bus_id: str = None
    gpu_serial: str = None
    gpu_uuid: str = None


def _hpu_parse_processes():
    """Parse ``hl-smi`` output into ProcessInfo entries (Habana/HPU backend).

    Returns an empty list when no per-process row matches.
    """
    output = subprocess.check_output(["hl-smi"], text=True)

    # Matches the per-process rows of the hl-smi table, e.g.
    # |   0     12345   C   python   1234MiB   |
    line_format = re.compile(
        r"\|(\s+)(?P<gpu_name>\d+)(\s+)(?P<pid>\d+)(\s+)(?P<type>\w+)(\s+)(?P<process_name>\w+)(\s+)(?P<memory>\d+)((?P<used_memory>\w+))(\s+)"
    )

    info = []
    for line in output.split("\n"):
        if match := line_format.match(line):
            # The first column is the device index; the regex captures it as
            # gpu_name, but ProcessInfo also requires ``gpu``, so pass it twice.
            info.append(ProcessInfo(gpu=match["gpu_name"], **match.groupdict()))

    return info


def _cuda_parse_processes():
    """List compute processes via ``nvidia-smi`` (CUDA backend).

    Returns one ProcessInfo per CSV row; values are kept as the strings
    nvidia-smi emits.
    """
    metrics = [
        "pid",
        "gpu_name",
        "gpu_bus_id",
        "gpu_serial",
        "gpu_uuid",
        "process_name",
        "used_memory",
    ]
    query = ",".join(metrics)
    cmd = ["nvidia-smi", f"--query-compute-apps={query}", "--format=csv,noheader"]
    output = subprocess.check_output(cmd, text=True)

    info = []
    for line in output.split("\n"):
        # The output ends with a newline, producing a trailing empty line.
        if not line.strip():
            continue
        # csv,noheader separates fields with ", " — strip the padding.
        frags = [frag.strip() for frag in line.split(",")]
        # The query has no device-index column, but ProcessInfo requires
        # ``gpu``; pass None explicitly.
        info.append(ProcessInfo(gpu=None, **dict(zip(metrics, frags))))
    return info


def _default():
return []


# Maps the detected accelerator architecture to its process-listing backend.
# Architectures without a dedicated parser fall back to _default (no processes).
backends = {
    "hpu": _hpu_parse_processes,
    "cuda": _cuda_parse_processes,
    # ROCM
    # XPU
    "cpu": _default,
}


class GPUProcessWarden:
"""Ensure all the process using the GPU are killed before & after the bench"""

def __init__(self, kill_on_start=True, kill_on_end=True):
self.gpus = get_gpu_info()
self.arch = self.gpus['arch']
self.arch = self.gpus["arch"]
self.fetch_fun = backends.get(self.arch, _default)
self.kill_on_start = kill_on_start
self.kill_on_end = kill_on_end
Expand All @@ -56,19 +92,19 @@ def __init__(self, kill_on_start=True, kill_on_end=True):
def __enter__(self):
if self.kill_on_start:
self.ensure_free()

return self

def __exit__(self, *args):
if self.kill_on_end:
self.ensure_free()

return None

def fetch_processes(self):
try:
return self.fetch_fun()
except :
except:
traceback.print_exc()
return []

Expand All @@ -77,26 +113,148 @@ def kill(self, pid, signal):
return

try:
os.kill(pid, signal):
os.kill(pid, signal)
except ProcessLookupError:
self.dead_processes.append(pid)

def ensure_free(self):
processes = self.fetch_processes()
if len(processes) == 0:
return

syslog("Found {0} still using devices after bench ended", len(processes))

# Keyboard interrupt
for process in processes:
self.kill(process.pid, signal.SIGINT)
syslog("Found {0} still using devices after bench ended", len(processes))

# Sig Term, please close now
for process in processes:
self.kill(process.pid, signal.SIGTERM)

# Sig Kill, just die

for process in processes:
self.kill(process.pid, signal.SIGKILL)


class Protected:
    """Prevent a signal to be raised during the execution of some code"""

    def __init__(self):
        # (sig, frame) tuple once a signal was intercepted; False otherwise
        self.signal_received = False
        # previous handlers keyed by signal number, restored on exit
        self.handlers = {}
        # timestamp of the first intercepted signal
        self.start = 0
        # how long (seconds) termination ended up being delayed
        self.delayed = 0
        # True only if our handlers were actually installed (main thread)
        self.signal_installed = False

    def __enter__(self):
        """Override the signal handlers with our delayed handler"""
        self.signal_received = False

        try:
            # signal.signal returns the previous handler; keep it for restore
            self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler)
            self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler)
            self.signal_installed = True

        except ValueError:  # ValueError: signal only works in main thread
            warnings.warn(
                "SIGINT/SIGTERM protection hooks could not be installed because "
                "Runner is executing inside a thread/subprocess, results could get lost "
                "on interruptions"
            )

        return self

    def stop(self):
        # Hook for subclasses: invoked from handler() after a signal was
        # intercepted, before the delayed signal is eventually replayed.
        pass

    def handler(self, sig, frame):
        """Register the received signal for later"""

        log.warning("Delaying signal %d to finish operations", sig)
        log.warning(
            "Press CTRL-C again to terminate the program now (You may lose results)"
        )

        self.start = time.time()
        self.signal_received = (sig, frame)

        # if CTRL-C is pressed again the original handlers will handle it
        # and make the program stop
        self.restore_handlers()

        self.stop()

    def restore_handlers(self):
        """Restore old signal handlers"""
        if not self.signal_installed:
            return

        signal.signal(signal.SIGINT, self.handlers[signal.SIGINT])
        signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM])

        self.signal_installed = False

    def maybe_stop(self):
        """Raise the delayed signal if any or restore the old signal handlers"""

        if not self.signal_received:
            self.restore_handlers()

        else:
            self.delayed = time.time() - self.start

            log.warning("Termination was delayed by %.4f s", self.delayed)
            # replay the intercepted signal through the original handler
            handler = self.handlers[self.signal_received[0]]

            # handlers may also be SIG_DFL/SIG_IGN (ints), hence the check
            if callable(handler):
                handler(*self.signal_received)

    def __exit__(self, *args):
        # Deliver the delayed signal (if any) now that the protected
        # section is done; does not suppress exceptions.
        self.maybe_stop()


def destroy(*processes, step=1, timeout=30):
    """Terminate ``processes`` gracefully, then forcefully.

    Sends SIGTERM to every process (or to its process group when the
    process was started with ``setsid`` — flagged by a ``did_setsid``
    attribute), waits up to ``timeout`` seconds *total* (polling every
    ``step`` seconds), then SIGKILLs whatever is still alive.
    """

    def kill(proc, sig):
        # NOTE: the parameter must not be named ``signal`` — the original
        # code shadowed the module and always sent SIGTERM, ignoring the
        # requested signal.
        try:
            if getattr(proc, "did_setsid", False):
                os.killpg(os.getpgid(proc.pid), sig)
            else:
                os.kill(proc.pid, sig)
        except ProcessLookupError:
            # Already gone: nothing to do.
            pass

    for proc in processes:
        kill(proc, signal.SIGTERM)

    # Wait a total amount of time, not per process
    elapsed = 0

    def wait(proc):
        nonlocal elapsed

        while (ret := proc.poll()) is None and elapsed < timeout:
            time.sleep(step)
            elapsed += step

        # True when the process survived the shared timeout.
        return ret is None

    for proc in processes:
        if wait(proc):
            kill(proc, signal.SIGKILL)


class SignalProtected(Protected):
    """Delay event handling until all the processes are killed"""

    def __init__(self):
        super().__init__()
        # Subprocesses to tear down before the delayed signal is replayed.
        self.processes = []

    def add_process(self, *processes):
        """Register one or more subprocesses for cleanup."""
        for proc in processes:
            self.processes.append(proc)

    def stop(self):
        """Invoked when a signal was intercepted: kill every tracked process."""
        destroy(*self.processes)


@contextmanager
def process_cleaner():
    """Delay signal handling until all the processes have been killed"""

    cleaner = SignalProtected()
    with cleaner:
        yield cleaner
6 changes: 3 additions & 3 deletions milabench/_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""This file is generated, do not modify"""

__tag__ = "v0.1.0-20-g7246295a"
__commit__ = "7246295a356186b55fa4b2b75480e3700c279b15"
__date__ = "2024-06-20 09:18:17 -0400"
__tag__ = "v0.1.0-30-g94b27a71"
__commit__ = "94b27a71145d3ba754a2713aeca60e5a28be4bc5"
__date__ = "2024-06-25 13:49:52 -0400"
Loading

0 comments on commit 78bfdfc

Please sign in to comment.