diff --git a/hyakvnc/__main__.py b/hyakvnc/__main__.py
index 11d7ad1..a1e2570 100644
--- a/hyakvnc/__main__.py
+++ b/hyakvnc/__main__.py
@@ -11,6 +11,8 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Optional, Union
+
+from hyakvnc.apptainer import ApptainerInstanceInfo
 from .vncsession import HyakVncSession
 from .config import HyakVncConfig
 from .slurmutil import (
@@ -19,6 +21,8 @@
     get_historical_job_infos,
     cancel_job,
     SbatchCommand,
+    get_job_status,
+    wait_for_job_running,
 )
 from .util import wait_for_file, repeat_until
 from .version import VERSION
@@ -110,12 +114,16 @@ def create_node_signal_handler(signal_number, frame):
     # Template to name the apptainer instance:
     apptainer_instance_name = f"{app_config.apptainer_instance_prefix}-$SLURM_JOB_ID-{container_name}"
     # Command to start the apptainer instance:
-    apptainer_cmd = f"apptainer instance start --writable-tmpfs --cleanenv {container_path} {apptainer_instance_name}"
+    apptainer_cmd = (
+        "apptainer instance start "
+        + str(container_path)
+        + " "
+        + str(apptainer_instance_name)
+        + " && while true; do sleep 2; done"
+    )
     # Command to start the apptainer instance and keep it running:
-    apptainer_cmd_with_rest = (
-        apptainer_env_vars_string + apptainer_cmd + " && while true; do sleep 10; done"
-    )
+    apptainer_cmd_with_rest = apptainer_env_vars_string + apptainer_cmd
 
     # The sbatch wrap functionality allows submitting commands without an sbatch script:
     sbatch_opts["wrap"] = apptainer_cmd_with_rest
 
@@ -137,14 +145,10 @@ def create_node_signal_handler(signal_number, frame):
     logger.info(
         f"Launched sbatch job {job_id} with account {app_config.account} on partition {app_config.partition}. Waiting for job to start running"
     )
-    try:
-        wait_for_job_status(
-            job_id,
-            states=["RUNNING"],
-            timeout=app_config.sbatch_post_timeout,
-            poll_interval=app_config.sbatch_post_poll_interval,
-        )
-    except TimeoutError:
+
+    if not wait_for_job_running(
+        job_id, timeout=app_config.sbatch_post_timeout, poll_interval=app_config.sbatch_post_poll_interval
+    ):
         logger.error(f"Job {job_id} did not start running within {app_config.sbatch_post_timeout} seconds")
         try:
             job = get_historical_job_infos(job_id=job_id)
@@ -171,37 +175,41 @@ def create_node_signal_handler(signal_number, frame):
         / f"{real_instance_name}.json"
     ).expanduser()
 
-    logger.info("Waiting for Apptainer instance to start running")
-    if wait_for_file(str(instance_file), timeout=app_config.sbatch_post_timeout):
-        logger.info("Apptainer instance started running. Waiting for VNC session to start")
-        time.sleep(5)
-
-        def get_session():
-            try:
-                sessions = HyakVncSession.find_running_sessions(app_config, job_id=job_id)
-                if sessions:
-                    my_sessions = [s for s in sessions if s.job_id == job_id]
-                    if my_sessions:
-                        return my_sessions[0]
-            except LookupError as e:
-                logger.debug(f"Could not get session info for job {job_id}: {e}")
-            return None
-
-        sesh = repeat_until(lambda: get_session(), lambda x: x is not None, timeout=app_config.sbatch_post_timeout)
-        if not sesh:
-            logger.warning(f"No running VNC sessions found for job {job_id}. Canceling and exiting.")
+    logger.info(f"Job is running on nodes {job.node_list}. Waiting for Apptainer instance to start running.")
+    if not wait_for_file(str(instance_file), timeout=app_config.sbatch_post_timeout):
+        logger.error(f"Could not find instance file at {instance_file} before timeout")
+        kill_self()
+    logger.info("Apptainer instance started running. Waiting for VNC session to start")
+    time.sleep(5)
+    try:
+        instance_info = ApptainerInstanceInfo.from_json(instance_file)
+        sesh = HyakVncSession(job_id, instance_info, app_config)
+    except (ValueError, FileNotFoundError, RuntimeError) as e:
+        logger.error(f"Could not parse instance file {instance_file}: {e}")
+        kill_self()
+    else:
+        time.sleep(1)
+        try:
+            sesh.parse_vnc_info()
+        except RuntimeError as e:
+            logger.error(f"Could not parse VNC info: {e}")
+            sesh.stop()
             kill_self()
-        else:
-            if sesh.wait_until_alive(timeout=app_config.sbatch_post_timeout):
-                print_connection_string(session=sesh)
-                exit(0)
-            else:
-                logger.error("VNC session for SLURM job {job_id} doesn't seem to be alive")
+        time.sleep(1)
+        if Path(sesh.vnc_log_file_path).expanduser().is_file():
+            if not Path(sesh.vnc_pid_file_path).expanduser().is_file():
+                logger.error(f"Could not find PID file for job at {sesh.vnc_pid_file_path}")
+                with open(sesh.vnc_log_file_path, "r") as f:
+                    log_contents = f.read()
+                logger.error(f"VNC session for SLURM job {job_id} failed to start. Log contents:\n{log_contents}")
                 sesh.stop()
-                exit(1)
-    else:
-        logger.info(f"Could not find instance file at {instance_file} before timeout")
-        kill_self()
+                kill_self()
+        if not sesh.wait_until_alive(timeout=app_config.sbatch_post_timeout):
+            logger.error(f"VNC session for SLURM job {job_id} doesn't seem to be alive")
+            sesh.stop()
+            kill_self()
+        print_connection_string(session=sesh)
+        exit(0)
 
 
 def cmd_stop(job_id: Optional[int] = None, stop_all: bool = False):
@@ -382,6 +390,16 @@ def main():
 
     # check_slurm_version()
     if args.command == "create":
+        if args.mem:
+            app_config.mem = args.mem
+        if args.cpus:
+            app_config.cpus = args.cpus
+        if args.time:
+            app_config.timelimit = f"{args.time}:00:00"
+        if args.gpus:
+            app_config.gpus = args.gpus
+        if args.timeout:
+            app_config.sbatch_post_timeout = float(args.timeout)
         try:
             cmd_create(args.container, dry_run=args.dry_run)
         except (TimeoutError, RuntimeError) as e:
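
Note (not part of the patch): `apptainer instance start` returns as soon as the instance is daemonized, and a job submitted through sbatch's `--wrap` ends when its wrapped command exits, so the appended keep-alive loop is what holds the allocation open. A sketch of the string the hunk above builds, using a hypothetical container path and instance name:

    # Sketch only; the container path and instance name are made-up examples.
    container_path = "/path/to/container.sif"
    apptainer_instance_name = "hyakvnc-12345-mycontainer"
    apptainer_cmd = (
        "apptainer instance start "
        + str(container_path)
        + " "
        + str(apptainer_instance_name)
        + " && while true; do sleep 2; done"
    )
    print(apptainer_cmd)
    # apptainer instance start /path/to/container.sif hyakvnc-12345-mycontainer && while true; do sleep 2; done
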
Waiting for VNC session to start") + time.sleep(5) + try: + instance_info = ApptainerInstanceInfo.from_json(instance_file) + sesh = HyakVncSession(job_id, instance_info, app_config) + except (ValueError, FileNotFoundError, RuntimeError) as e: + logger.error("Could not parse instance file: {instance_file}") + kill_self() + else: + time.sleep(1) + try: + sesh.parse_vnc_info() + except RuntimeError as e: + logger.error(f"Could not parse VNC info: {e}") + sesh.stop() kill_self() - else: - if sesh.wait_until_alive(timeout=app_config.sbatch_post_timeout): - print_connection_string(session=sesh) - exit(0) - else: - logger.error("VNC session for SLURM job {job_id} doesn't seem to be alive") + time.sleep(1) + if Path(sesh.vnc_log_file_path).expanduser().is_file(): + if not Path(sesh.vnc_pid_file_path).expanduser().is_file(): + logger.error(f"Could not find PID file for job at {sesh.vnc_pid_file_path}") + with open(sesh.vnc_log_file_path, "r") as f: + log_contents = f.read() + logger.error(f"VNC session for SLURM job {job_id} failed to start. Log contents:\n{log_contents}") sesh.stop() - exit(1) - else: - logger.info(f"Could not find instance file at {instance_file} before timeout") - kill_self() + kill_self() + if not sesh.wait_until_alive(timeout=app_config.sbatch_post_timeout): + logger.error(f"VNC session for SLURM job {job_id} doesn't seem to be alive") + sesh.stop() + kill_self() + print_connection_string(session=sesh) + exit(0) def cmd_stop(job_id: Optional[int] = None, stop_all: bool = False): @@ -382,6 +390,16 @@ def main(): # check_slurm_version() if args.command == "create": + if args.mem: + app_config.mem = args.mem + if args.cpus: + app_config.cpus = args.cpus + if args.time: + app_config.timelimit = f"{args.time}:00:00" + if args.gpus: + app_config.gpus = args.gpus + if args.timeout: + app_config.sbatch_post_timeout = float(args.timeout) try: cmd_create(args.container, dry_run=args.dry_run) except (TimeoutError, RuntimeError) as e: diff --git a/hyakvnc/config.py b/hyakvnc/config.py index 0fa3075..38bd44a 100644 --- a/hyakvnc/config.py +++ b/hyakvnc/config.py @@ -51,10 +51,10 @@ def __init__( ] = "klone.hyak.uw.edu", # intermediate host address between local machine and compute node account: Optional[str] = None, # account to use for sbatch jobs | -A, --account, SBATCH_ACCOUNT partition: Optional[str] = None, # partition to use for sbatch jobs | -p, --partition, SBATCH_PARTITION - cluster: Optional[str] = "klone", # cluster to use for sbatch jobs | --clusters, SBATCH_CLUSTERS + cluster: Optional[str] = None, # cluster to use for sbatch jobs | --clusters, SBATCH_CLUSTERS gpus: Optional[str] = None, # number of gpus to use for sbatch jobs | -G, --gpus, SBATCH_GPUS timelimit: Optional[str] = None, # time limit for sbatch jobs | --time, SBATCH_TIMELIMIT - mem: Optional[str] = "8G", # memory limit for sbatch jobs | --mem, SBATCH_MEM + mem: Optional[str] = None, # memory limit for sbatch jobs | --mem, SBATCH_MEM cpus: Optional[ int ] = 4, # number of cpus to use for sbatch jobs | -c, --cpus-per-task (not settable by env var) @@ -83,9 +83,9 @@ def __init__( get_default_partition(cluster=self.cluster, account=self.account), ) self.gpus = gpus or get_first_env(["HYAKVNC_SLURM_GPUS", "SBATCH_GPUS"], None) - self.timelimit = timelimit or get_first_env(["HYAKVNC_SLURM_TIMELIMIT", "SBATCH_TIMELIMIT"], None) - self.mem = mem or get_first_env(["HYAKVNC_SLURM_MEM", "SBATCH_MEM"], None) - self.cpus = int(cpus or get_first_env(["HYAKVNC_SLURM_CPUS", "SBATCH_CPUS_PER_TASK"])) + self.timelimit = timelimit 
diff --git a/hyakvnc/slurmutil.py b/hyakvnc/slurmutil.py
index f9913f4..84e95ac 100755
--- a/hyakvnc/slurmutil.py
+++ b/hyakvnc/slurmutil.py
@@ -299,6 +299,53 @@ def wait_for_job_status(
     raise TimeoutError(f"Timed out waiting for job {job_id} to be in one of the following states: {states}")
 
 
+slurm_states_active = {
+    "SIGNALING",
+    "CONFIGURING",
+    "STAGE_OUT",
+    "SUSPENDED",
+    "REQUEUE_HOLD",
+    "REQUEUE_FED",
+    "PENDING",
+    "RESV_DEL_HOLD",
+    "STOPPED",
+    "RUNNING",
+    "RESIZING",
+    "REQUEUED",
+}
+slurm_states_success = {"COMPLETED", "COMPLETING"}
+slurm_states_cancelled = {"CANCELLED", "REVOKED"}
+slurm_states_timeout = {"DEADLINE", "TIMEOUT"}
+slurm_states_failed = {"PREEMPTED", "OUT_OF_MEMORY", "FAILED", "NODE_FAIL", "BOOT_FAIL"}
+
+
+def wait_for_job_running(job_id: int, timeout: Optional[float] = None, poll_interval: float = 1.0) -> bool:
+    """
+    Waits for the specified job to reach the RUNNING state.
+    :param job_id: job id to wait for
+    :param timeout: seconds to wait for the job to start running, or None to wait indefinitely
+    :param poll_interval: seconds to sleep between status polls
+    :return: True if the job reached the RUNNING state, False if it left the active states or the timeout elapsed
+    """
+    begin_time = time.time()
+    assert isinstance(job_id, int), "Job id must be an integer"
+    assert (timeout is None) or (timeout > 0), "Timeout must be greater than zero"
+    assert poll_interval > 0, "Poll interval must be greater than zero"
+    end_time = None if timeout is None else begin_time + timeout
+    while (end_time is None) or (time.time() < end_time):
+        try:
+            res = get_job_status(job_id)
+        except (RuntimeError, LookupError):
+            return False
+        else:
+            if res == "RUNNING":
+                return True
+            elif res not in slurm_states_active:
+                return False
+        time.sleep(poll_interval)
+    return False
+
+
 def get_historical_job_infos(
     after: Optional[Union[datetime, timedelta]] = None,
     before: Optional[Union[datetime, timedelta]] = None,
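
Note (not part of the patch): unlike `wait_for_job_status`, which raises `TimeoutError`, the new helper reports failure through its return value, and it returns `False` as soon as the job leaves the active states instead of polling until the timeout. A usage sketch with an illustrative job id:

    from hyakvnc.slurmutil import wait_for_job_running

    job_id = 123456  # hypothetical job id
    if wait_for_job_running(job_id, timeout=120.0, poll_interval=2.0):
        print(f"Job {job_id} is running")
    else:
        # The job failed, was cancelled, completed, or the timeout elapsed
        # before it reached the RUNNING state.
        print(f"Job {job_id} did not reach RUNNING in time")
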
diff --git a/hyakvnc/vncsession.py b/hyakvnc/vncsession.py
index 6d57d1b..a960da5 100644
--- a/hyakvnc/vncsession.py
+++ b/hyakvnc/vncsession.py
@@ -1,5 +1,6 @@
 import pprint
 import re
+import time
 from pathlib import Path
 from typing import Optional, Union, List, Dict
 
@@ -29,8 +30,8 @@ def __init__(
         self.apptainer_instance_info = apptainer_instance_info
         self.app_config = app_config
         self.vnc_port = None
-        self.vnc_log_file_path = None
-        self.vnc_pid_file_path = None
+        self.vnc_log_file_path = ""
+        self.vnc_pid_file_path = ""
 
     def parse_vnc_info(self) -> None:
         logOutPath = self.apptainer_instance_info.logOutPath
@@ -59,8 +60,6 @@ def parse_vnc_info(self) -> None:
         if not self.vnc_log_file_path.is_file():
             logger.debug(f"Could not find vnc log file at {self.vnc_log_file_path}")
         self.vnc_pid_file_path = self.vnc_log_file_path.with_suffix(".pid")
-        if not self.vnc_pid_file_path.is_file():
-            logger.debug(f"Could not find vnc PID file at {self.vnc_pid_file_path}")
 
     def vnc_pid_file_exists(self) -> bool:
         if not self.vnc_pid_file_path:
@@ -190,7 +189,7 @@ def find_running_sessions(app_config: HyakVncConfig, job_id: Optional[int] = Non
             try:
                 sesh.parse_vnc_info()
             except RuntimeError as e:
-                logger.debug("Could not parse VNC info for session {sesh}: {e}")
+                logger.debug(f"Could not parse VNC info for session {sesh}: {e}")
             else:
                 if sesh.is_alive():
                     logger.debug(f"Session {sesh} is alive")
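
Note (not part of the patch): the last hunk adds a missing `f` prefix; without it, the placeholders are logged literally, as a standalone snippet makes clear:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    sesh, e = "session-1", RuntimeError("boom")  # stand-in values
    logger.debug("Could not parse VNC info for session {sesh}: {e}")   # braces logged verbatim
    logger.debug(f"Could not parse VNC info for session {sesh}: {e}")  # interpolated as intended
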