[TEMP] Before moving to work from home
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Apr 22, 2024
1 parent f50da46 commit a2822d2
Showing 7 changed files with 173 additions and 105 deletions.
134 changes: 54 additions & 80 deletions milatools/cli/code.py
@@ -6,6 +6,7 @@
import shutil
import sys
from logging import getLogger as get_logger
from typing import Awaitable

from typing_extensions import deprecated

@@ -105,15 +106,6 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction):
code_parser.set_defaults(function=code)


async def _check_disk_quota_task(remote: RemoteV2) -> None:
try:
await check_disk_quota(remote)
except MilatoolsUserError:
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")


async def code(
path: str,
command: str,
@@ -148,11 +140,17 @@ async def code(
home = await login_node.get_output_async("echo $HOME", display=False, hide=True)
path = home if path == "." else f"{home}/{path}"

check_disk_quota_task = asyncio.create_task(_check_disk_quota_task(login_node))
# Raise errors that are meant to be shown to the user (disk quota is reached).
try:
await check_disk_quota(login_node)
except MilatoolsUserError:
# Raise errors meant to be shown to the user (disk quota exceeded).
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")

# NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
# unknown cluster?
sync_vscode_extensions_task: Awaitable | None = None
if no_internet_on_compute_nodes(cluster):
# Sync the VsCode extensions from the local machine over to the target cluster.
console.log(
@@ -163,18 +161,14 @@
# todo: use the mila or the local machine as the reference for vscode
# extensions?
# todo: could also perhaps make this function asynchronous instead of using a
# multiprocessing process for it.
sync_vscode_extensions_task = asyncio.create_task(
asyncio.to_thread(
sync_vscode_extensions,
LocalV2(),
[login_node],
),
name="sync_vscode_extensions",
# different thread to run it (seems to not exit as gracefully now).
sync_vscode_extensions_task = asyncio.to_thread(
sync_vscode_extensions,
LocalV2(),
[login_node],
)

compute_node_task: asyncio.Task[ComputeNode]

compute_node_task: Awaitable[ComputeNode]
if job or node:
if job and node:
logger.warning(
@@ -183,9 +177,7 @@
)
job_id_or_node = job or node
assert job_id_or_node is not None
compute_node_task = asyncio.create_task(
connect_to_running_job(job_id_or_node, login_node)
)
compute_node_task = connect_to_running_job(job_id_or_node, login_node)
else:
if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
logger.warning(
@@ -212,77 +204,40 @@
alloc = alloc + [f"--job-name={job_name}"]

if persist:
compute_node_task = asyncio.create_task(
sbatch(login_node, sbatch_flags=alloc, job_name=job_name)
compute_node_task = sbatch(
login_node, sbatch_flags=alloc, job_name=job_name
)
# compute_node = await sbatch(
# login_node, sbatch_flags=alloc, job_name=job_name
# )
else:
# NOTE: Here we actually need the job name to be known, so that we can
# scancel jobs if the call is interrupted.
compute_node_task = asyncio.create_task(
salloc(login_node, salloc_flags=alloc, job_name=job_name)
compute_node_task = salloc(
login_node, salloc_flags=alloc, job_name=job_name
)
# compute_node = await salloc(
# login_node, salloc_flags=alloc, job_name=job_name
# )
try:
_, _, compute_node = await asyncio.gather(
check_disk_quota_task,

if sync_vscode_extensions_task:
# Sync the extensions while waiting for the compute node to be allocated.
# Also waits for the extensions to be synced before launching VsCode.
_, compute_node = await asyncio.gather(
sync_vscode_extensions_task,
compute_node_task,
return_exceptions=True,
)
except:
# If any of the tasks failed, we want to raise the exception.
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if not task.done():
task.cancel()
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if exception := task.exception():
raise exception
raise

if isinstance(compute_node, BaseException):
raise compute_node
else:
# Wait for the compute node to be allocated.
compute_node = await compute_node_task

try:
while True:
code_command_to_run = (
code_command,
"--new-window",
"--wait",
"--remote",
f"ssh-remote+{compute_node.hostname}",
path,
)
console.log(
f"(local) {shlex.join(code_command_to_run)}", style="bold green"
)
await run_async(code_command_to_run)
# TODO: BUG: This now requires two Ctrl+C's instead of one!
console.print(
"The editor was closed. Reopen it with <Enter> or terminate the "
"process with <Ctrl+C> (maybe twice)."
)
if currently_in_a_test():
# TODO: This early exit kills the job when it is not persistent.
break
input()
await launch_vscode(
compute_node=compute_node,
code_command=code_command,
path=path,
)
except KeyboardInterrupt:
logger.info("Keyboard interrupt.")

if not persist:
# Cancel the job explicitly.
# TODO: Check what happens when calling `del(compute_node)`, perhaps that is
# enough?
await compute_node.close()
console.print(f"Ended session on '{compute_node.hostname}'")
return compute_node.job_id
@@ -300,6 +255,25 @@ async def code(
return compute_node


async def launch_vscode(compute_node: ComputeNode, code_command: str, path: str):
while True:
code_command_to_run = (
code_command,
"--new-window",
"--wait",
"--remote",
f"ssh-remote+{compute_node.hostname}",
path,
)
console.log(f"(local) {shlex.join(code_command_to_run)}", style="bold green")
await run_async(code_command_to_run)
console.print(
"The editor was closed. Reopen it with <Enter> or terminate the "
"process with <Ctrl+C> (maybe twice)."
)
input()


async def connect_to_running_job(
jobid_or_nodename: int | str,
login_node: RemoteV2,
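
For reference, a minimal standalone sketch of the awaitable-gathering pattern that `code()` now uses: the blocking extension sync is wrapped with `asyncio.to_thread` and awaited alongside the node-allocation coroutine via `asyncio.gather`. The helpers below are hypothetical stand-ins, not part of the diff.

import asyncio
import time


def sync_extensions() -> str:
    # Hypothetical stand-in for the blocking sync_vscode_extensions(...) call.
    time.sleep(1)
    return "extensions synced"


async def allocate_node() -> str:
    # Hypothetical stand-in for salloc()/sbatch() returning a compute node.
    await asyncio.sleep(1)
    return "cn-c001"


async def main() -> None:
    # to_thread returns an awaitable that runs the function in a worker thread,
    # so both steps overlap instead of running sequentially.
    sync_task = asyncio.to_thread(sync_extensions)
    _, hostname = await asyncio.gather(sync_task, allocate_node())
    print(f"Ready to launch the editor on {hostname}.")


asyncio.run(main())
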
10 changes: 7 additions & 3 deletions milatools/cli/commands.py
@@ -362,11 +362,15 @@ def mila():
args_dict = _convert_uppercase_keys_to_lowercase(args_dict)

if inspect.iscoroutinefunction(function):
# TODO: Need to let the function handle KeyboardInterrupt by itself, here it
# seems like it never gets there.
try:
return asyncio.run(function(**args_dict))
# NOTE: Not using `asyncio.run` here, because it doesn't exit as cleanly
# when interrupted (prints out some ignored exceptions in
# SubprocessTransport.__del__). Not sure why or what the difference is
# between them.
return asyncio.get_event_loop().run_until_complete(function(**args_dict))
except KeyboardInterrupt:
from milatools.cli import console

console.log("Terminated by user.")
exit()

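
For reference, a minimal sketch of the entry-point pattern in `mila()` above: the command coroutine is driven with `run_until_complete` and a `KeyboardInterrupt` becomes a clean exit. The `main` coroutine is a hypothetical stand-in for the dispatched command function.

import asyncio


async def main() -> int:
    # Hypothetical stand-in for a long-running command; interrupt with Ctrl+C.
    await asyncio.sleep(5)
    return 0


try:
    # run_until_complete instead of asyncio.run, mirroring the choice in the diff above.
    # (Recent Python versions may emit a DeprecationWarning for get_event_loop() here.)
    exit_code = asyncio.get_event_loop().run_until_complete(main())
except KeyboardInterrupt:
    print("Terminated by user.")
    exit_code = 0
raise SystemExit(exit_code)
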
4 changes: 1 addition & 3 deletions milatools/cli/common.py
@@ -81,8 +81,6 @@ def _parse_lfs_quota_output(


async def check_disk_quota(remote: RemoteV2) -> None:
cluster = remote.hostname

# NOTE: This is what the output of the command looks like on the Mila cluster:
#
# Disk quotas for usr normandf (uid 1471600598):
@@ -101,7 +99,7 @@ async def check_disk_quota(remote: RemoteV2) -> None:
home_disk_quota_output = await remote.get_output_async(
"lfs quota -u $USER $HOME", display=False, hide=True
)
_check_disk_quota_common_part(home_disk_quota_output, cluster)
_check_disk_quota_common_part(home_disk_quota_output, remote.hostname)


def check_disk_quota_v1(remote: RemoteV1 | RemoteV2) -> None:
39 changes: 25 additions & 14 deletions milatools/utils/compute_node.py
@@ -100,15 +100,21 @@ async def close(self):
"""Cancels the running job using `scancel`."""
logger.info(f"Stopping job {self.job_id}.")
if self.salloc_subprocess is not None:
# NOTE: This will exit cleanly because we don't have nested terminals or
# job steps.
if self.salloc_subprocess.stdin is not None:
# NOTE: This will exit cleanly because we don't have nested terminals or
# job steps.
await self.salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012
else:
# todo: Not sure what the best way to do this is..
self.salloc_subprocess.send_signal(signal=signal.SIGINT)
self.salloc_subprocess.send_signal(signal=signal.SIGKILL)
self.salloc_subprocess.kill()
# The scancel below is done even though it's redundant, just to be safe.
await self.login_node.run_async(
f"scancel {self.job_id}", display=True, hide=False
f"scancel {self.job_id}",
display=True,
hide=False,
warn=True,
)

def __repr__(self) -> str:
@@ -255,8 +261,13 @@ async def salloc(
# eventually be done at the same time as something else (while waiting for the
# job to start) using things like `asyncio.gather` and `asyncio.wait_for`.
logger.debug(f"(local) $ {shlex.join(command)}")
console.log(f"({login_node.hostname}) $ {salloc_command}", style="green")
console.log(
f"({login_node.hostname}) $ {salloc_command}", style="green", markup=False
)
async with cancel_new_jobs_on_interrupt(login_node, job_name):
# BUG: If stdin is not set (or set to PIPE?) and you then type `salloc`, the
# terminal is actually 'live' and affects the compute node! For instance, if
# you do `mila code .` and then type `salloc`, it spawns a second job!!
salloc_subprocess = await asyncio.subprocess.create_subprocess_exec(
*command,
shell=False,
Expand All @@ -281,13 +292,15 @@ async def salloc(

try:
console.log(f"Waiting for job {job_id} to start.", style="green")
_node, state = await wait_while_job_is_pending(login_node, job_id)
await wait_while_job_is_pending(login_node, job_id)
except (KeyboardInterrupt, asyncio.CancelledError):
if salloc_subprocess is not None:
logger.debug("Killing the salloc subprocess following a KeyboardInterrupt.")
salloc_subprocess.send_signal(signal.SIGINT)
salloc_subprocess.terminate()
login_node.run(f"scancel {job_id}", display=True, hide=False)
await salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012
# salloc_subprocess.send_signal(signal.SIGINT)
# salloc_subprocess.send_signal(signal.SIGKILL)
# salloc_subprocess.terminate()
await login_node.run_async(f"scancel {job_id}", display=True, hide=False)
raise

# todo: Are there any states between `PENDING` and `RUNNING`?
@@ -385,7 +398,7 @@ async def _wait_while_job_is_in_state(login_node: RemoteV2, job_id: int, state:
f"Job {job_id} was allocated node(s) {nodes!r} and is in state "
f"{current_state!r}."
)
return nodes, current_state
return current_state

waiting_until = f"Waiting {wait_time_seconds} seconds until job {job_id} "
condition: str | None = None
@@ -418,11 +431,9 @@ async def _wait_while_job_is_in_state(login_node: RemoteV2, job_id: int, state:
attempt += 1


async def wait_while_job_is_pending(
login_node: RemoteV2, job_id: int
) -> tuple[str, str]:
async def wait_while_job_is_pending(login_node: RemoteV2, job_id: int) -> str:
"""Waits until a job show up in `sacct` then waits until its state is not PENDING.
Returns the `Node` and `State` from `sacct` after the job is no longer pending.
Returns the `State` from `sacct` after the job is no longer pending.
"""
return await _wait_while_job_is_in_state(login_node, job_id, state="PENDING")
await _wait_while_job_is_in_state(login_node, job_id, state="PENDING")
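
For reference, a minimal polling sketch of the wait-while-pending idea, reflecting the new return type (just the state string, no node name). `get_job_state` is a hypothetical stand-in for querying `sacct` through the login node.

import asyncio

_FAKE_STATES = iter(["PENDING", "PENDING", "RUNNING"])


async def get_job_state(job_id: int) -> str:
    # Hypothetical stand-in for something like `sacct --jobs <job_id> --format=State --noheader`.
    return next(_FAKE_STATES, "RUNNING")


async def wait_while_job_is_pending(job_id: int, poll_seconds: float = 0.1) -> str:
    state = await get_job_state(job_id)
    while state == "PENDING":
        await asyncio.sleep(poll_seconds)
        state = await get_job_state(job_id)
    return state


print(asyncio.run(wait_while_job_is_pending(12345)))  # -> RUNNING
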
35 changes: 34 additions & 1 deletion tests/conftest.py
@@ -16,6 +16,7 @@
import questionary
from fabric.connection import Connection

from milatools.cli import console
from milatools.cli.init_command import setup_ssh_config
from milatools.utils.compute_node import get_queued_milatools_job_ids
from milatools.utils.remote_v1 import RemoteV1
Expand Down Expand Up @@ -212,10 +213,19 @@ def test_something(remote: Remote):

@pytest.fixture(scope="session")
def job_name(request: pytest.FixtureRequest) -> str | None:
# TODO: Make the job name different based on the runner that is launching tests, so
# that the `launches_job_fixture` doesn't scancel the test jobs launched from
# another runner (e.g. me on my dev machine or laptop) on a cluster
return getattr(request, "param", JOB_NAME)


@pytest_asyncio.fixture(scope="session")
# TODO: This `launches_jobs` fixture has issues.
# - It can cancel jobs from the self-hosted runner in addition to its own
# - Because it is session-scoped, if one test uses it, then all tests have their jobs
# cancelled at the end of their run. Therefore, might as well make it auto-used, no?


@pytest_asyncio.fixture(scope="session", autouse=True)
async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str):
jobs_before = await get_queued_milatools_job_ids(login_node_v2, job_name=job_name)
if jobs_before:
@@ -231,13 +241,36 @@ async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str):

new_jobs = jobs_after - jobs_before
if new_jobs:
console.log(f"Cancelling jobs {new_jobs} after running tests...")
login_node_v2.run(
"scancel " + " ".join(str(job_id) for job_id in new_jobs), display=True
)
else:
logger.debug("Test apparently didn't launch any new jobs.")


@pytest_asyncio.fixture(scope="session")
async def register_test_job(login_node_v2: RemoteV2):
"""Fixture that gives a function to register a job that should be cancelled after
tests are done running.
This is more targeted than `launches_jobs_fixture` that does an scancel based on the
job name.
"""
jobs: list[int] = []
register_job = jobs.append
try:
yield register_job
finally:
if jobs:
console.log(f"Cancelling jobs {jobs} after running tests...")
login_node_v2.run(
"scancel " + " ".join(str(job_id) for job_id in jobs), display=True
)
else:
logger.debug("Test apparently didn't launch any new jobs.")


launches_jobs = pytest.mark.usefixtures(launches_job_fixture.__name__)


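
For reference, a minimal self-contained sketch of the register-and-cleanup pattern behind the new `register_test_job` fixture (hypothetical and simplified: function scope and a print instead of a real `scancel` on the login node).

import pytest


@pytest.fixture()
def register_test_job():
    jobs: list[int] = []
    try:
        # The test receives this function and calls it with each job id it submits.
        yield jobs.append
    finally:
        if jobs:
            # Stand-in for running `scancel <job ids>` on the login node.
            print(f"Would cancel jobs {jobs} after the test.")


def test_registers_a_job(register_test_job):
    register_test_job(12345)
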
3 changes: 2 additions & 1 deletion tests/integration/test_code/test_code_mila__sbatch_.txt
@@ -1,6 +1,7 @@
Checking disk quota on $HOME...
Disk usage: X / LIMIT GiB and X / LIMIT files
(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d'
(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep
7d'
JOB_ID

(local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob
