Skip to content

Commit

Permalink
Add get_time_to_job, skip test if wait is too long
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 15, 2024
1 parent c04dd71 commit 289de0e
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 3 deletions.
65 changes: 65 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import contextvars
import datetime
import functools
import itertools
import multiprocessing
Expand Down Expand Up @@ -390,3 +391,67 @@ def add_arguments(self, actions):
actions, key=lambda action: not isinstance(action, _HelpAction)
)
super().add_arguments(actions)


def td_format(td_object: datetime.timedelta) -> str:
    """Spell out the magnitude of a `datetime.timedelta` in plain English.

    The sign is discarded: a negative delta reads exactly like its positive
    counterpart. Fractions of a second are truncated. A "month" is counted as
    30 days and a "year" as 365 days. Units whose count is zero are omitted.

    >>> td_format(datetime.timedelta(days=1, hours=2, minutes=3, seconds=4))
    '1 day, 2 hours, 3 minutes and 4 seconds'
    >>> td_format(datetime.timedelta(seconds=1))
    '1 second'
    >>> td_format(datetime.timedelta(seconds=0))
    '0 seconds'
    >>> td_format(datetime.timedelta(seconds=-1, days=-1))
    '1 day and 1 second'

    Slightly modified from https://stackoverflow.com/a/13756038/6388696
    """
    remaining = int(abs(td_object).total_seconds())
    if not remaining:
        return "0 seconds"

    minute = 60
    hour = 60 * minute
    day = 24 * hour
    # Ordered from the largest unit to the smallest so that each `divmod`
    # peels off as much as possible before moving to the next unit.
    units = (
        ("year", 365 * day),
        ("month", 30 * day),
        ("day", day),
        ("hour", hour),
        ("minute", minute),
        ("second", 1),
    )

    parts: list[str] = []
    for unit_name, unit_length in units:
        count, remaining = divmod(remaining, unit_length)
        if count == 0:
            continue
        plural = "" if count == 1 else "s"
        parts.append(f"{count} {unit_name}{plural}")

    # `parts` is never empty here: at least one whole second remains, and the
    # last unit has a length of one second.
    *leading, last = parts
    if not leading:
        return last
    return f"{', '.join(leading)} and {last}"


def td_format_from_now(td_object: datetime.timedelta) -> str:
    """Represent a `datetime.timedelta` relative to now, in text.

    Positive deltas are rendered as "in <duration>", negative deltas as
    "<duration> ago", and deltas shorter than one second (in either direction)
    as "now". The duration itself is formatted by `td_format`.

    >>> td_format_from_now(datetime.timedelta(days=1, hours=2, minutes=3, seconds=4))
    'in 1 day, 2 hours, 3 minutes and 4 seconds'
    >>> td_format_from_now(datetime.timedelta(seconds=1))
    'in 1 second'
    >>> td_format_from_now(datetime.timedelta(seconds=0))
    'now'
    >>> td_format_from_now(datetime.timedelta(seconds=-1, days=-1))
    '1 day and 1 second ago'
    """
    # `int()` truncates toward zero, so any delta strictly between -1s and +1s
    # collapses to 0 and is reported as "now".
    seconds = int(td_object.total_seconds())
    if seconds == 0:
        return "now"
    # `td_format` ignores the sign; the direction is added here.
    delta_text = td_format(td_object)
    if seconds > 0:
        return f"in {delta_text}"
    return f"{delta_text} ago"
42 changes: 39 additions & 3 deletions milatools/utils/compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@ async def cancel_new_jobs_on_interrupt(login_node: RemoteV2, job_name: str):
raise


async def get_time_to_job(
    login_node_v2: RemoteV2, allocation_flags: list[str]
) -> datetime.timedelta:
    """Estimate how long a job with these flags would wait before starting.

    Runs `sbatch --test-only` with the given allocation flags on the login
    node, reads the predicted start time from the command's stderr, and returns
    the difference between that time and the current local time.

    Returns a timedelta of 0 seconds when the predicted start time is not in
    the future (i.e. the job could have started already).

    Raises a `ValueError` when no start time can be found in the stderr of the
    command.
    """
    dry_run_command = (
        f"sbatch --test-only {shlex.join(allocation_flags)} --wrap 'srun sleep 7d'"
    )
    if login_node_v2.hostname in DRAC_CLUSTERS:
        # Can't run `sbatch` from $HOME in these clusters.
        dry_run_command = f"cd $SCRATCH && {dry_run_command}"
    result = await login_node_v2.run_async(dry_run_command, display=False, hide=True)

    # Example stderr from the above command:
    # sbatch: Job 4600173 to start at 2024-04-15T10:27:57 using 1 processors on nodes cn-b004 in partition long
    start_time_match = re.search(
        r"Job [0-9]+ to start at ([0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2})",
        result.stderr,
    )
    if start_time_match is None:
        raise ValueError(
            f"Could not find the start time in the output: {result.stderr}"
        )

    # The captured timestamp looks like: 2024-04-15T10:27:57
    start_time = datetime.datetime.strptime(
        start_time_match.group(1), "%Y-%m-%dT%H:%M:%S"
    )
    # NOTE(review): both datetimes are naive. `now` comes from the local clock
    # while `start_time` comes from the cluster's output; presumably both are in
    # the same timezone - confirm.
    now = datetime.datetime.now()

    # Can also get a small negative delta if the job could have started
    # immediately, so clamp the result to zero.
    no_wait = datetime.timedelta(seconds=0)
    time_until_start = start_time - now
    return time_until_start if time_until_start > no_wait else no_wait


async def salloc(
login_node: RemoteV2, salloc_flags: list[str], job_name: str
) -> ComputeNode:
Expand Down Expand Up @@ -288,8 +321,9 @@ async def sbatch(
`sacct` command.
"""
# idea: Find the job length from the sbatch flags if possible so we can do
# --wrap='sleep {job_duration}' instead of 'sleep 7d' so the job doesn't look
# like it failed or was interrupted, just cleanly exits before the end time.
# --wrap='sleep {job_duration}' instead of 'sleep 7d'.
# todo: Should we use --ntasks=1 --overlap in the wrapped `srun`, so that only one
# task sleeps? Does that change anything?
sbatch_command = (
"sbatch --parsable " + shlex.join(sbatch_flags) + " --wrap 'srun sleep 7d'"
)
Expand All @@ -307,7 +341,9 @@ async def sbatch(
try:
await wait_while_job_is_pending(login_node, job_id)
except (KeyboardInterrupt, asyncio.CancelledError):
console.log(f"Received KeyboardInterrupt, cancelling job {job_id}")
console.log(
f"Received KeyboardInterrupt, cancelling job {job_id}", style="yellow"
)
login_node.run(f"scancel {job_id}", display=True, hide=False)
raise

Expand Down
18 changes: 18 additions & 0 deletions tests/utils/test_compute_node.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import asyncio
import datetime
import re
from logging import getLogger as get_logger

import pytest
import pytest_asyncio

from milatools.cli.utils import td_format
from milatools.utils.compute_node import (
ComputeNode,
get_queued_milatools_job_ids,
get_time_to_job,
salloc,
sbatch,
)
Expand Down Expand Up @@ -146,6 +149,21 @@ class TestComputeNode(RunnerTests):
async def runner(
self, login_node_v2: RemoteV2, persist: bool, allocation_flags: list[str]
):
# IDEA: Check how long it would take to get an allocation. If it takes too long,
# skip the tests.
# TODO: Add this to `mila code` and others.
time_to_job = await get_time_to_job(login_node_v2, allocation_flags)
if time_to_job > datetime.timedelta(minutes=5):
pytest.skip(
reason="It would take a long time to get the allocation to run tests."
)
elif time_to_job:
logger.info(
f"The job should start in approximately {td_format(time_to_job)}."
)
else:
logger.info("The job is expected to start as soon as requested.")

if persist:
runner = await sbatch(
login_node_v2, sbatch_flags=allocation_flags, job_name="mila-code"
Expand Down

0 comments on commit 289de0e

Please sign in to comment.