Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jul 3, 2024
1 parent 9425a5b commit cc5fa04
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 30 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/tests_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,13 @@ jobs:

- name: dependencies
run: |
if [[ ! -d "~/.cargo/bin" ]]; then
wget --no-check-certificate --secure-protocol=TLSv1_2 -qO- https://sh.rustup.rs | sh -s -- -y
fi
export PATH="~/.cargo/bin:${PATH}"
python -m pip install -U pip
python -m pip install -U poetry
- name: install
run: |
pip install pytest
poetry lock --no-update
pip install -e .
poetry install --with dev
- name: tests
env:
Expand All @@ -74,4 +69,4 @@ jobs:
POSTGRES_DB: milabench
POSTGRES_HOST: localhost
POSTGRES_PORT: 5432
run: pytest tests/integration
run: poetry run pytest tests/integration
4 changes: 3 additions & 1 deletion .github/workflows/tests_unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ jobs:
- name: dependencies
run: |
pip install poetry
poetry env use python3.10
poetry install --with dev
- name: tests
run: |
poetry run pytest --ignore=tests/integration tests/
source $(poetry env info -p)/bin/activate
pytest --ignore=tests/integration tests/
26 changes: 16 additions & 10 deletions benchmate/benchmate/warden.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ def __exit__(self, *args):
def destroy(*processes, step=1, timeout=30):
processes = list(processes)

def kill(proc, signal):
def kill(proc, sig):
try:
if getattr(proc, "did_setsid", False):
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
os.killpg(os.getpgid(proc.pid), sig)
else:
os.kill(proc.pid, signal.SIGTERM)
os.kill(proc.pid, sig)
except ProcessLookupError:
pass

Expand All @@ -249,11 +249,9 @@ def kill(proc, signal):
elapsed = 0
def wait(proc):
nonlocal elapsed

while (ret := proc.poll()) is None and elapsed < timeout:
time.sleep(step)
elapsed += step

return ret is None

k = 0
Expand All @@ -280,25 +278,33 @@ def wait(proc):


@contextmanager
def process_cleaner():
def process_cleaner(timeout=30):
"""Delay signal handling until all the processes have been killed"""

with Protected():
def kill_everything(processes, warden):
def _():
warden.terminate()
destroy(*processes, timeout=timeout)
warden.kill()
return _

with Protected() as signalhandler:
with GPUProcessWarden() as warden: # => SIGTERM all processes using GPUs
processes = []
try: # NOTE: we have not waited much between both signals

signalhandler.stop = kill_everything(processes, warden)

try: # NOTE: we have not waited much between both signals
warden.kill() # => SIGKILL all processes using GPUs

yield processes # => Run milabench, spawning processes for the benches

finally:
warden.terminate() # => SIGTERM all processes using GPUs

destroy(*processes) # => SIGTERM+SIGKILL milabench processes
destroy(*processes, timeout=timeout) # => SIGTERM+SIGKILL milabench processes

# destroy waited up to `timeout` seconds (30 by default)

# warden.__exit__ # => SIGKILL all processes still using GPUs


6 changes: 4 additions & 2 deletions milabench/commands/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ async def execute_command(
for pack in command.packs():
pack.phase = phase

timeout_tasks = []
with process_cleaner() as warden:
for pack, argv, _kwargs in command.commands():
await pack.send(event="config", data=pack.config)
Expand All @@ -77,7 +76,10 @@ async def execute_command(
try:
return await asyncio.wait_for(asyncio.gather(*coro), timeout=delay)

except TimeoutError | asyncio.TimeoutError:
except TimeoutError:
await force_terminate(pack, delay)
return [-1 for _ in coro]
except asyncio.TimeoutError:
await force_terminate(pack, delay)
return [-1 for _ in coro]

Expand Down
39 changes: 29 additions & 10 deletions tests/benchmate/test_protected.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,30 @@ def _worker(delay):
print('done')


def fake_poll(self):
    """Build a ``subprocess.Popen``-style ``poll()`` for a process object.

    ``multiprocessing.Process`` has no ``poll()`` method, but the cleanup
    code under test calls ``proc.poll()`` in its wait loop.  The returned
    closure fills that gap.

    Args:
        self: the process object to wrap; it is expected to expose an
            ``exitcode`` attribute (``None`` while the process is running).

    Returns:
        A callable accepting (and ignoring) any arguments that returns the
        exit code once the process has finished, or ``None`` otherwise.
    """

    def poll(*args, **kwargs):
        # ``Process.exitcode`` is already a non-blocking poll, so there is no
        # need for ``os.waitpid`` here: a blocking wait would stall the
        # caller's timeout loop and reap the child behind multiprocessing's
        # back.
        try:
            return self.exitcode
        except (AttributeError, ValueError):
            # AttributeError: the object exposes no ``exitcode``.
            # ValueError: the Process object has been closed.
            # Anything else (notably the KeyboardInterrupt these tests raise
            # on purpose) must propagate instead of being swallowed.
            return None

    return poll

def spawn(delay, warden):
procs = []
for _ in range(10):
proc = multiprocessing.Process(target=_worker, args=(delay,))
proc.start()
proc.poll = fake_poll(proc)
procs.append(proc)
warden.add_process(proc)
warden.append(proc)

return procs

Expand Down Expand Up @@ -58,7 +74,7 @@ def test_process_cleaner_process():
proc = multiprocessing.Process(target=_protected_process, args=(60,))
proc.start()

time.sleep(1)
time.sleep(2)
os.kill(proc.pid, signal.SIGINT)

elapsed = time.time() - start
Expand All @@ -69,16 +85,19 @@ def test_keyboard_cleaner_process():
start = time.time()

with pytest.raises(KeyboardInterrupt):
with process_cleaner() as warden:
with process_cleaner(10) as warden:
procs = spawn(60, warden)

assert len(warden) != 0
print(warden[0])

time.sleep(1)
os.kill(os.getpid(), signal.SIGINT)

wait(procs)

elapsed = time.time() - start
assert elapsed < 30
assert elapsed < 12


def test_keyboard_cleaner_process_ended():
Expand Down Expand Up @@ -106,10 +125,10 @@ def ctor(*args, **kwargs):
return kwargs

with pytest.raises(KeyboardInterrupt):
with process_cleaner() as warden:
with process_cleaner(timeout=10) as warden:
mx = Multiplexer(timeout=0, constructor=ctor)
proc = mx.start(["sleep", "60"], info={}, env={}, **{})
warden.add_process(proc)
warden.append(proc)

time.sleep(2)
os.kill(os.getpid(), signal.SIGINT)
Expand All @@ -119,7 +138,7 @@ def ctor(*args, **kwargs):
print(entry)

elapsed = time.time() - start
assert elapsed < 30
assert elapsed < 12


def test_protected_multiplexer_ended():
Expand All @@ -128,10 +147,10 @@ def test_protected_multiplexer_ended():
start = time.time()

with pytest.raises(KeyboardInterrupt):
with process_cleaner() as warden:
with process_cleaner(timeout=10) as warden:
mx = Multiplexer(timeout=0, constructor=lambda **kwargs: kwargs)
proc = mx.start(["sleep", "1"], info={}, env={}, **{})
warden.add_process(proc)
warden.append(proc)

time.sleep(2)
os.kill(os.getpid(), signal.SIGINT)
Expand All @@ -141,4 +160,4 @@ def test_protected_multiplexer_ended():
print(entry)

elapsed = time.time() - start
assert elapsed < 30
assert elapsed < 10
Empty file added tests/test_mock_run.py
Empty file.

0 comments on commit cc5fa04

Please sign in to comment.