Use tasks for each subpart of mila code
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Apr 22, 2024
1 parent 7a64300 commit b689b64
Showing 8 changed files with 131 additions and 96 deletions.
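At a high level, this commit turns the sequential subparts of `mila code` (the disk-quota check, the VSCode extension sync, and the compute-node allocation) into concurrently running asyncio tasks. Below is a minimal sketch of that pattern, with stand-in coroutines rather than the actual milatools functions:

import asyncio


async def check_quota() -> None:
    ...  # stand-in for the disk-quota check


def sync_extensions() -> None:
    ...  # stand-in for the (blocking) VSCode extension sync


async def allocate_compute_node() -> str:
    ...  # stand-in for salloc/sbatch
    return "cn-a001"


async def main() -> None:
    # Start every subpart right away instead of running them one after the other.
    quota_task = asyncio.create_task(check_quota())
    # Blocking (non-async) work goes to a worker thread so the event loop stays free.
    extensions_task = asyncio.create_task(asyncio.to_thread(sync_extensions))
    node_task = asyncio.create_task(allocate_compute_node())
    # Wait for all three; only the compute node's result is used afterwards.
    _, _, compute_node = await asyncio.gather(quota_task, extensions_task, node_task)
    print(compute_node)


asyncio.run(main())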
97 changes: 69 additions & 28 deletions milatools/cli/code.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import asyncio
import shlex
import shutil
import sys
@@ -11,6 +12,7 @@
from milatools.cli import console
from milatools.cli.common import (
check_disk_quota,
check_disk_quota_v1,
find_allocation,
)
from milatools.cli.init_command import DRAC_CLUSTERS
@@ -103,6 +105,15 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction):
code_parser.set_defaults(function=code)


async def _check_disk_quota_task(remote: RemoteV2) -> None:
try:
await check_disk_quota(remote)
except MilatoolsUserError:
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")


async def code(
path: str,
command: str,
@@ -134,40 +145,35 @@ async def code(

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = login_node.get_output("echo $HOME", display=False, hide=True)
home = await login_node.get_output_async("echo $HOME", display=False, hide=True)
path = home if path == "." else f"{home}/{path}"

try:
check_disk_quota(login_node)
except MilatoolsUserError:
# Raise errors that are meant to be shown to the user (disk quota is reached).
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
check_disk_quota_task = asyncio.create_task(_check_disk_quota_task(login_node))
# Raise errors that are meant to be shown to the user (disk quota is reached).

# NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
# unknown cluster?
if no_internet_on_compute_nodes(cluster):
# Sync the VsCode extensions from the local machine over to the target cluster.
console.log(
f"Installing VSCode extensions that are on the local machine on "
f"{cluster} in the background.",
f"{cluster}.",
style="cyan",
)
# todo: use the mila or the local machine as the reference for vscode
# extensions?
# todo: could also perhaps make this function asynchronous instead of using a
# multiprocessing process for it.
copy_vscode_extensions_process = make_process(
sync_vscode_extensions,
LocalV2(),
[login_node],
sync_vscode_extensions_task = asyncio.create_task(
asyncio.to_thread(
sync_vscode_extensions,
LocalV2(),
[login_node],
),
name="sync_vscode_extensions",
)
copy_vscode_extensions_process.start()
# todo: could potentially do this at the same time as the blocks above and just wait
# for the result here, instead of running each block in sequence.
if currently_in_a_test():
copy_vscode_extensions_process.join()

compute_node_task: asyncio.Task[ComputeNode]

if job or node:
if job and node:
@@ -177,7 +183,9 @@
)
job_id_or_node = job or node
assert job_id_or_node is not None
compute_node = await connect_to_running_job(job_id_or_node, login_node)
compute_node_task = asyncio.create_task(
connect_to_running_job(job_id_or_node, login_node)
)
else:
if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
logger.warning(
@@ -195,24 +203,57 @@
# todo: Get the job name from the flags instead?
raise MilatoolsUserError(
"The job name flag (--job-name or -J) should be left unset for now "
"because we use the job name to gage how many people use `mila code` "
"on the various clusters. We also make use of the job name when the "
"call to `salloc` is interrupted before we have a chance to know the "
"job id."
"because we use the job name to measure how many people use `mila "
"code` on the various clusters. We also make use of the job name when "
"the call to `salloc` is interrupted before we have a chance to know "
"the job id."
)
job_name = "mila-code"
alloc = alloc + [f"--job-name={job_name}"]

if persist:
compute_node = await sbatch(
login_node, sbatch_flags=alloc, job_name=job_name
compute_node_task = asyncio.create_task(
sbatch(login_node, sbatch_flags=alloc, job_name=job_name)
)
# compute_node = await sbatch(
# login_node, sbatch_flags=alloc, job_name=job_name
# )
else:
# NOTE: Here we actually need the job name to be known, so that we can
# scancel jobs if the call is interrupted.
compute_node = await salloc(
login_node, salloc_flags=alloc, job_name=job_name
compute_node_task = asyncio.create_task(
salloc(login_node, salloc_flags=alloc, job_name=job_name)
)
# compute_node = await salloc(
# login_node, salloc_flags=alloc, job_name=job_name
# )
try:
_, _, compute_node = await asyncio.gather(
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
return_exceptions=True,
)
except:
# If any of the tasks failed, we want to raise the exception.
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if not task.done():
task.cancel()
for task in (
check_disk_quota_task,
sync_vscode_extensions_task,
compute_node_task,
):
if exception := task.exception():
raise exception
raise

if isinstance(compute_node, BaseException):
raise compute_node

try:
while True:
@@ -332,7 +373,7 @@ def code_v1(
command = get_code_command()

try:
check_disk_quota(remote)
check_disk_quota_v1(remote)
except MilatoolsUserError:
raise
except Exception as exc:
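The new error handling in `code()` above gathers the three tasks with `return_exceptions=True`, cancels anything still pending if the wait itself is interrupted, and re-raises the first stored exception. A condensed, self-contained sketch of that shape (illustrative coroutines, not the milatools code):

import asyncio


async def fails() -> None:
    raise RuntimeError("boom")


async def succeeds() -> str:
    await asyncio.sleep(0.1)
    return "ok"


async def main() -> None:
    tasks = [asyncio.create_task(fails()), asyncio.create_task(succeeds())]
    try:
        # With return_exceptions=True, gather() waits for every task and stores
        # exceptions in the result list instead of raising them immediately.
        results = await asyncio.gather(*tasks, return_exceptions=True)
    except BaseException:
        # Only reached if the wait itself is interrupted (e.g. KeyboardInterrupt):
        # cancel whatever is still running before re-raising.
        for task in tasks:
            if not task.done():
                task.cancel()
        raise
    # Surface the first stored exception, if any.
    for result in results:
        if isinstance(result, BaseException):
            raise result
    print(results)


try:
    asyncio.run(main())
except RuntimeError as err:
    print(f"surfaced: {err}")  # -> surfaced: boom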
27 changes: 23 additions & 4 deletions milatools/cli/common.py
@@ -80,7 +80,7 @@ def _parse_lfs_quota_output(
return (used_gb, max_gb), (used_files, max_files)


def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None:
async def check_disk_quota(remote: RemoteV2) -> None:
cluster = remote.hostname

# NOTE: This is what the output of the command looks like on the Mila cluster:
@@ -92,17 +92,36 @@
# uid 1471600598 is using default block quota setting
# uid 1471600598 is using default file quota setting

# Need to assert this, otherwise .get_output calls .run which would spawn a job!
assert not isinstance(remote, SlurmRemote)
if not remote.get_output("which lfs", display=False, hide=True):
if not (await remote.get_output_async("which lfs", display=False, hide=True)):
logger.debug("Cluster doesn't have the lfs command. Skipping check.")
return

console.log("Checking disk quota on $HOME...")

home_disk_quota_output = await remote.get_output_async(
"lfs quota -u $USER $HOME", display=False, hide=True
)
_check_disk_quota_common_part(home_disk_quota_output, cluster)


def check_disk_quota_v1(remote: RemoteV1 | RemoteV2) -> None:
cluster = remote.hostname
# Need to check for this, because SlurmRemote is a subclass of RemoteV1 and
# .get_output calls SlurmRemote.run which would spawn a job!
assert not isinstance(remote, SlurmRemote)

if not (remote.get_output("which lfs", display=False, hide=True)):
logger.debug("Cluster doesn't have the lfs command. Skipping check.")
return

console.log("Checking disk quota on $HOME...")
home_disk_quota_output = remote.get_output(
"lfs quota -u $USER $HOME", display=False, hide=True
)
_check_disk_quota_common_part(home_disk_quota_output, cluster)


def _check_disk_quota_common_part(home_disk_quota_output: str, cluster: str):
if "not on a mounted Lustre filesystem" in home_disk_quota_output:
logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.")
return
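For context, the shared `_check_disk_quota_common_part` parses the `lfs quota` output and then either warns (quota nearly reached) or raises a user-facing error (quota reached), which is also what the integration test further down expects. A self-contained sketch of that warn-vs-raise logic; the 90% threshold and the exact messages are assumptions, not milatools' actual values:

import logging

logger = logging.getLogger(__name__)


class MilatoolsUserError(Exception):
    """Stand-in for milatools' user-facing error type."""


def report_quota(
    used_gb: float, max_gb: float, used_files: int, max_files: int, cluster: str
) -> None:
    # Sketch only: thresholds and wording are illustrative.
    print(f"Disk usage: {used_gb:.1f} / {max_gb:.1f} GiB and {used_files} / {max_files} files")
    if used_gb >= max_gb or used_files >= max_files:
        # Quota reached: raise an error that is meant to be shown to the user.
        raise MilatoolsUserError(f"$HOME on {cluster} is over its disk quota.")
    if used_gb / max_gb > 0.9 or used_files / max_files > 0.9:
        # Quota nearly reached: warn, but let the command continue.
        logger.warning(f"$HOME on {cluster} is close to its disk quota.")


report_quota(used_gb=95.0, max_gb=100.0, used_files=500_000, max_files=1_000_000, cluster="mila")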
10 changes: 5 additions & 5 deletions milatools/utils/vscode_utils.py
@@ -61,7 +61,7 @@ def get_code_command() -> str:
return os.environ.get("MILATOOLS_CODE_COMMAND", "code")


def get_vscode_executable_path(code_command: str | None = None) -> str:
def get_local_vscode_executable_path(code_command: str | None = None) -> str:
if code_command is None:
code_command = get_code_command()

@@ -73,7 +73,7 @@

def vscode_installed() -> bool:
try:
_ = get_vscode_executable_path()
_ = get_local_vscode_executable_path()
except CommandNotFoundError:
return False
return True
@@ -227,7 +227,7 @@ def _update_progress(

if isinstance(remote, LocalV2):
assert dest_hostname == "localhost"
code_server_executable = get_vscode_executable_path()
code_server_executable = get_local_vscode_executable_path()
extensions_on_dest = get_local_vscode_extensions()
else:
dest_hostname = remote.hostname
@@ -332,7 +332,7 @@ def install_vscode_extension(
def get_local_vscode_extensions(code_command: str | None = None) -> dict[str, str]:
output = subprocess.run(
(
get_vscode_executable_path(code_command=code_command),
get_local_vscode_executable_path(code_command=code_command),
"--list-extensions",
"--show-versions",
),
@@ -398,7 +398,7 @@ def extensions_to_install(


def find_code_server_executable(
remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server"
remote: RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server"
) -> str | None:
"""Find the most recent `code-server` executable on the remote.
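The rename from `get_vscode_executable_path` to `get_local_vscode_executable_path` stresses that the helper resolves the `code` command on the local machine. A plausible shape for it; the signature, the `MILATOOLS_CODE_COMMAND` default, and the `CommandNotFoundError` behaviour come from the diff, while the use of `shutil.which` is an assumption:

from __future__ import annotations

import os
import shutil


class CommandNotFoundError(Exception):
    """Stand-in for milatools' CommandNotFoundError."""


def get_local_vscode_executable_path(code_command: str | None = None) -> str:
    # Resolve the local VSCode CLI (`code` by default, or $MILATOOLS_CODE_COMMAND).
    if code_command is None:
        code_command = os.environ.get("MILATOOLS_CODE_COMMAND", "code")
    path = shutil.which(code_command)
    if path is None:
        raise CommandNotFoundError(code_command)
    return path


def vscode_installed() -> bool:
    # Mirrors the check in the diff: VSCode counts as installed if the helper succeeds.
    try:
        get_local_vscode_executable_path()
    except CommandNotFoundError:
        return False
    return True


print(vscode_installed())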
57 changes: 22 additions & 35 deletions tests/integration/test_code.py
@@ -14,54 +14,41 @@
from pytest_regressions.file_regression import FileRegressionFixture

from milatools.cli.code import code
from milatools.cli.common import check_disk_quota
from milatools.cli.common import check_disk_quota, check_disk_quota_v1
from milatools.cli.utils import get_hostname_to_use_for_compute_node
from milatools.utils.compute_node import ComputeNode
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import RemoteV2

from ..cli.common import skip_param_if_on_github_cloud_ci
from ..conftest import job_name, launches_jobs
from .conftest import (
skip_if_not_already_logged_in,
skip_param_if_not_already_logged_in,
SLURM_CLUSTER,
)
from .test_slurm_remote import get_recent_jobs_info_dicts

logger = get_logger(__name__)


@pytest.mark.slow
@pytest.mark.parametrize(
"cluster",
[
skip_param_if_on_github_cloud_ci("mila"),
skip_param_if_not_already_logged_in("narval"),
skip_param_if_not_already_logged_in("beluga"),
skip_param_if_not_already_logged_in("cedar"),
pytest.param(
"graham",
marks=[
skip_if_not_already_logged_in("graham"),
pytest.mark.xfail(
raises=subprocess.CalledProcessError,
reason="Graham doesn't use a lustre filesystem for $HOME.",
strict=True,
),
],
),
skip_param_if_not_already_logged_in("niagara"),
],
indirect=True,
@pytest.mark.xfail(
SLURM_CLUSTER == "graham",
raises=subprocess.CalledProcessError,
reason="Graham doesn't use a lustre filesystem for $HOME.",
strict=True,
)
def test_check_disk_quota(
login_node: RemoteV1 | RemoteV2,
async def test_check_disk_quota(
login_node_v2: RemoteV2,
capsys: pytest.LogCaptureFixture,
caplog: pytest.LogCaptureFixture,
): # noqa: F811
):
if login_node_v2.hostname == "localhost":
pytest.skip(reason="Test doesn't work on localhost.")

with caplog.at_level(logging.DEBUG):
check_disk_quota(remote=login_node)
# TODO: Maybe figure out a way to actually test this, (not just by running it and
# expecting no errors).
await check_disk_quota(remote=login_node_v2)
check_disk_quota_v1(remote=login_node_v2)
# TODO: Maybe figure out a way to actually test this, (apart from just running it
# and expecting no errors).

# Check that it doesn't raise any errors.
# IF the quota is nearly met, then a warning is logged.
# IF the quota is met, then a `MilatoolsUserError` is logged.
@@ -95,9 +82,9 @@ async def get_job_info(
@pytest.mark.parametrize("persist", [True, False], ids=["sbatch", "salloc"])
@pytest.mark.parametrize(
job_name.__name__,
[
None,
],
# Don't set the `--job-name` in the `allocation_flags` fixture
# (this is necessary for `mila code` to work properly).
[None],
ids=[""],
indirect=True,
)
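The test changes above also replace the per-parameter xfail marks on the "graham" case with a single conditional `@pytest.mark.xfail` on the test itself. A tiny illustration of that pytest pattern; `SLURM_CLUSTER` is hard-coded here as a stand-in for the value imported from the integration conftest:

import subprocess

import pytest

SLURM_CLUSTER = "graham"  # stand-in; the real value comes from the integration conftest


@pytest.mark.xfail(
    SLURM_CLUSTER == "graham",
    raises=subprocess.CalledProcessError,
    reason="Graham doesn't use a lustre filesystem for $HOME.",
    strict=True,
)
def test_quota_check_runs():
    # On "graham" this raises, which the conditional xfail marks as expected.
    if SLURM_CLUSTER == "graham":
        raise subprocess.CalledProcessError(returncode=1, cmd="lfs quota -u $USER $HOME")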
14 changes: 0 additions & 14 deletions tests/integration/test_code/test_code_mila0_None_True_.txt

This file was deleted.

2 changes: 1 addition & 1 deletion tests/integration/test_code/test_code_mila__salloc_.txt
@@ -1,6 +1,6 @@
Checking disk quota on $HOME...
Disk usage: X / LIMIT GiB and X / LIMIT files
(mila) $ salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code
(mila) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code
Waiting for job JOB_ID to start.
(local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob
--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob
2 changes: 1 addition & 1 deletion tests/integration/test_code/test_code_mila__sbatch_.txt
@@ -1,6 +1,6 @@
Checking disk quota on $HOME...
Disk usage: X / LIMIT GiB and X / LIMIT files
(mila) $ sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d'
(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d'
JOB_ID

(local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob