From 42e6fae24043bcc021b3e356a28e55e04aa6a752 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Wed, 13 Mar 2024 09:54:44 +0100 Subject: [PATCH] tests and updates --- .github/workflows/testing.yml | 6 + src/jobflow_remote/cli/admin.py | 67 ++- src/jobflow_remote/cli/formatting.py | 2 +- src/jobflow_remote/cli/job.py | 4 +- src/jobflow_remote/cli/runner.py | 6 +- src/jobflow_remote/cli/utils.py | 3 + src/jobflow_remote/config/base.py | 4 +- src/jobflow_remote/jobs/daemon.py | 37 +- src/jobflow_remote/jobs/jobcontroller.py | 44 +- src/jobflow_remote/jobs/runner.py | 183 +++++++- src/jobflow_remote/testing/__init__.py | 34 ++ src/jobflow_remote/testing/cli.py | 50 ++ src/jobflow_remote/utils/db.py | 8 +- tests/conftest.py | 43 ++ tests/db/cli/test_admin.py | 75 +++ tests/db/cli/test_flow.py | 51 +++ tests/db/cli/test_job.py | 262 +++++++++++ tests/db/cli/test_project.py | 120 +++++ tests/db/conftest.py | 218 +++++++++ tests/db/jobs/test_daemon.py | 138 ++++++ tests/db/jobs/test_jobcontroller.py | 503 +++++++++++++++++++++ tests/integration/conftest.py | 45 +- tests/integration/test_advanced_options.py | 76 ++++ tests/integration/test_slurm.py | 2 +- 24 files changed, 1917 insertions(+), 64 deletions(-) create mode 100644 src/jobflow_remote/testing/cli.py create mode 100644 tests/db/cli/test_admin.py create mode 100644 tests/db/cli/test_flow.py create mode 100644 tests/db/cli/test_job.py create mode 100644 tests/db/cli/test_project.py create mode 100644 tests/db/conftest.py create mode 100644 tests/db/jobs/test_daemon.py create mode 100644 tests/db/jobs/test_jobcontroller.py create mode 100644 tests/integration/test_advanced_options.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2f0757b6..fabd3cf6 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -35,6 +35,12 @@ jobs: run: pre-commit run --all-files --show-diff-on-failure test: + services: + local_mongodb: + image: mongo:4.4 + ports: + - 27017:27017 + runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/src/jobflow_remote/cli/admin.py b/src/jobflow_remote/cli/admin.py index b51fd6b6..df560e32 100644 --- a/src/jobflow_remote/cli/admin.py +++ b/src/jobflow_remote/cli/admin.py @@ -9,6 +9,8 @@ from jobflow_remote.cli.types import ( db_ids_opt, end_date_opt, + flow_ids_opt, + flow_state_opt, force_opt, job_ids_indexes_opt, job_state_opt, @@ -158,17 +160,76 @@ def unlock( if not confirmed: raise typer.Exit(0) + with loading_spinner(False) as progress: + progress.add_task(description="Unlocking jobs...", total=None) + + num_unlocked = jc.unlock_jobs( + job_ids=job_id, + db_ids=db_id, + state=state, + start_date=start_date, + end_date=end_date, + ) + + out_console.print(f"{num_unlocked} jobs were unlocked") + + +@app_admin.command() +def unlock_flow( + job_id: job_ids_indexes_opt = None, + db_id: db_ids_opt = None, + flow_id: flow_ids_opt = None, + state: flow_state_opt = None, + start_date: start_date_opt = None, + end_date: end_date_opt = None, + force: force_opt = False, +): + """ + Forcibly removes the lock from the documents of the selected jobs. + WARNING: can lead to inconsistencies if the processes is actually running + """ + + job_ids_indexes = get_job_ids_indexes(job_id) + + jc = get_job_controller() + + if not force: with loading_spinner(False) as progress: progress.add_task( description="Checking the number of locked documents...", total=None ) - num_unlocked = jc.remove_lock_job( - job_ids=job_id, + flows_info = jc.get_flows_info( + job_ids=job_ids_indexes, db_ids=db_id, + flow_ids=flow_id, state=state, start_date=start_date, + locked=True, end_date=end_date, ) - out_console.print(f"{num_unlocked} jobs were unlocked") + if not flows_info: + exit_with_error_msg("No data matching the request") + + text = Text.from_markup( + f"[red]This operation will [bold]remove the lock[/bold] for (roughly) [bold]{len(flows_info)} Flow(s)[/bold]. Proceed anyway?[/red]" + ) + confirmed = Confirm.ask(text, default=False) + + if not confirmed: + raise typer.Exit(0) + + with loading_spinner(False) as progress: + progress.add_task(description="Unlocking flows...", total=None) + + num_unlocked = jc.unlock_flows( + job_ids=job_id, + db_ids=db_id, + flow_ids=flow_id, + state=state, + start_date=start_date, + end_date=end_date, + ) + + out_console.print(f"{num_unlocked} flows were unlocked") diff --git a/src/jobflow_remote/cli/formatting.py b/src/jobflow_remote/cli/formatting.py index 842f8a59..880e42f8 100644 --- a/src/jobflow_remote/cli/formatting.py +++ b/src/jobflow_remote/cli/formatting.py @@ -167,7 +167,7 @@ def format_flow_info(flow_info: FlowInfo): table.title_style = "bold" table.add_column("DB id") table.add_column("Name") - table.add_column("State [Remote]") + table.add_column("State") table.add_column("Job id (Index)") table.add_column("Worker") diff --git a/src/jobflow_remote/cli/job.py b/src/jobflow_remote/cli/job.py index f0603fcd..b25e428e 100644 --- a/src/jobflow_remote/cli/job.py +++ b/src/jobflow_remote/cli/job.py @@ -228,7 +228,7 @@ def set_state( ) if not succeeded: - exit_with_error_msg("Could not reset the remote attempts") + exit_with_error_msg("Could not change the job state") print_success_msg() @@ -675,7 +675,7 @@ def exec_config( hours=hours, verbosity=verbosity, raise_on_error=raise_on_error, - exec_config_value=exec_config_value, + exec_config=exec_config_value, ) diff --git a/src/jobflow_remote/cli/runner.py b/src/jobflow_remote/cli/runner.py index 68178794..480df6b1 100644 --- a/src/jobflow_remote/cli/runner.py +++ b/src/jobflow_remote/cli/runner.py @@ -91,7 +91,10 @@ def run( if not (transfer or complete or queue or checkout): transfer = complete = queue = checkout = True - runner.run(transfer=transfer, complete=complete, queue=queue, checkout=checkout) + try: + runner.run(transfer=transfer, complete=complete, queue=queue, checkout=checkout) + finally: + runner.cleanup() @app_runner.command() @@ -261,6 +264,7 @@ def status(): DaemonStatus.STOPPING: "gold1", DaemonStatus.SHUT_DOWN: "red", DaemonStatus.PARTIALLY_RUNNING: "gold1", + DaemonStatus.STARTING: "gold1", DaemonStatus.RUNNING: "green", }[current_status] text = Text() diff --git a/src/jobflow_remote/cli/utils.py b/src/jobflow_remote/cli/utils.py index 469594a5..80a6bb33 100644 --- a/src/jobflow_remote/cli/utils.py +++ b/src/jobflow_remote/cli/utils.py @@ -69,6 +69,9 @@ def cleanup_job_controller(): global _shared_job_controller if _shared_job_controller is not None: _shared_job_controller.close() + # set to None again, in case it needs to be used again in the same + # execution (e.g., in tests) + _shared_job_controller = None def start_profiling(): diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index 34dde1ac..71d0435f 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -37,12 +37,12 @@ class RunnerOptions(BaseModel): delay_refresh_limited: int = Field( 600, description="Delay between subsequent refresh from the DB of the number of submitted " - "and running jobs (seconds). Only use if a worker with max_jobs is present", + "and running jobs (seconds). Only used if a worker with max_jobs is present", ) delay_update_batch: int = Field( 60, description="Delay between subsequent refresh from the DB of the number of submitted " - "and running jobs (seconds). Only use if a worker with max_jobs is present", + "and running jobs (seconds). Only used if a batch worker is present", ) lock_timeout: Optional[int] = Field( 86400, diff --git a/src/jobflow_remote/jobs/daemon.py b/src/jobflow_remote/jobs/daemon.py index bbd20616..9332d498 100644 --- a/src/jobflow_remote/jobs/daemon.py +++ b/src/jobflow_remote/jobs/daemon.py @@ -118,6 +118,7 @@ class DaemonStatus(Enum): STOPPED = "STOPPED" STOPPING = "STOPPING" PARTIALLY_RUNNING = "PARTIALLY_RUNNING" + STARTING = "STARTING" RUNNING = "RUNNING" @@ -224,7 +225,7 @@ def check_supervisord_process(self) -> bool: logger.warning( f"Process with pid {pid} is not running but daemon files are present. Cleaning them up." ) - self.clean_files() + self.clean_files() return running @@ -258,7 +259,10 @@ def check_status(self) -> DaemonStatus: ) if all(pi.get("state") in RUNNING_STATES for pi in proc_info): - return DaemonStatus.RUNNING + if any(pi.get("state") == ProcessStates.STARTING for pi in proc_info): + return DaemonStatus.STARTING + else: + return DaemonStatus.RUNNING if any(pi.get("state") in RUNNING_STATES for pi in proc_info): return DaemonStatus.PARTIALLY_RUNNING @@ -483,12 +487,31 @@ def _verify_call_result( return None def kill(self, raise_on_error: bool = False) -> bool: - status = self.check_status() - if status == DaemonStatus.SHUT_DOWN: - logger.info("supervisord is not running. No process is running") - return True + # If the daemon is shutting down supervisord may not be able to identify + # the state. Try proceeding in that case, since we really want to kill + # the process + status = None + try: + status = self.check_status() + if status == DaemonStatus.SHUT_DOWN: + logger.info("supervisord is not running. No process is running") + return True + if status == DaemonStatus.STOPPED: + logger.info("Processes are already stopped.") + return True + except DaemonError as e: + msg = ( + f"Error while determining the state of the runner: {getattr(e, 'message', str(e))}." + f"Proceeding with the kill command." + ) + logger.warning(msg) - if status in (DaemonStatus.RUNNING, DaemonStatus.STOPPING): + if status in ( + None, + DaemonStatus.RUNNING, + DaemonStatus.STOPPING, + DaemonStatus.PARTIALLY_RUNNING, + ): interface = self.get_interface() result = interface.supervisor.signalAllProcesses(9) error = self._verify_call_result(result, "kill", raise_on_error) diff --git a/src/jobflow_remote/jobs/jobcontroller.py b/src/jobflow_remote/jobs/jobcontroller.py index 9c72938a..4397d869 100644 --- a/src/jobflow_remote/jobs/jobcontroller.py +++ b/src/jobflow_remote/jobs/jobcontroller.py @@ -959,7 +959,7 @@ def _full_rerun( for dep_id, dep_index in descendants: if max(flow_doc.ids_mapping[dep_id]) > dep_index: raise ValueError( - f"Job {job_id} has a child job ({dep_id}) which is not the last index ({dep_index}. " + f"Job {job_id} has a child job ({dep_id}) which is not the last index ({dep_index}). " "Rerunning the Job will lead to inconsistencies and is not allowed." ) @@ -1078,6 +1078,7 @@ def _set_job_properties( wait: int | None = None, break_lock: bool = False, acceptable_states: list[JobState] | None = None, + use_pipeline: bool = False, ) -> list[int]: """ Helper to set multiple values in a JobDoc while locking the Job. @@ -1088,7 +1089,7 @@ def _set_job_properties( ---------- values Dictionary with the values to be set. Will be passed to a pymongo - `find_one_and_update` method. + `update_one` method. db_id The db_id of the Job. job_id @@ -1105,6 +1106,8 @@ def _set_job_properties( acceptable_states List of JobState for which the Job values can be changed. If None all states are acceptable. + use_pipeline + if True a pipeline will be used in the update of the document Returns ------- list @@ -1135,7 +1138,9 @@ def _set_job_properties( ) values = dict(values) # values["updated_on"] = datetime.utcnow() - lock.update_on_release = {"$set": values} + lock.update_on_release = ( + [{"$set": values}] if use_pipeline else {"$set": values} + ) return [doc["db_id"]] return [] @@ -1256,7 +1261,7 @@ def retry_jobs( """ return self._many_jobs_action( method=self.retry_job, - action_description="rerunning", + action_description="retrying", job_ids=job_ids, db_ids=db_ids, flow_ids=flow_ids, @@ -1876,17 +1881,31 @@ def set_job_run_properties( exec_config = exec_config.model_dump() if update and isinstance(exec_config, dict): - for k, v in exec_config.items(): - set_dict[f"exec_config.{k}"] = v + # if the content is a string replace even if it is an update, + # merging is meaningless + cond = { + "$cond": { + "if": {"$eq": [{"$type": "$exec_config"}, "string"]}, + "then": exec_config, + "else": {"$mergeObjects": ["$exec_config", exec_config]}, + } + } + print(cond) + set_dict["exec_config"] = cond + else: set_dict["exec_config"] = exec_config if resources: if isinstance(resources, QResources): resources = resources.as_dict() + # if passing a QResources it is pointless to update + # all the keywords will be overwritten and if the previous + # value was a generic dictionary the merged dictionary will fail + # almost surely lead to failures + update = False if update: - for k, v in resources.items(): - set_dict[f"resources.{k}"] = v + set_dict["resources"] = {"$mergeObjects": ["$resources", resources]} else: set_dict["resources"] = resources @@ -1904,6 +1923,7 @@ def set_job_run_properties( raise_on_error=raise_on_error, values=set_dict, acceptable_states=[JobState.READY, JobState.WAITING], + use_pipeline=update, ) def get_flow_job_aggreg( @@ -1968,6 +1988,7 @@ def get_flows_info( start_date: datetime | None = None, end_date: datetime | None = None, name: str | None = None, + locked: bool = False, sort: list[tuple] | None = None, limit: int = 0, full: bool = False, @@ -1994,6 +2015,8 @@ def get_flows_info( name Pattern matching the name of Flow. Default is an exact match, but all conventions from python fnmatch can be used (e.g. *test*) + locked + If True only locked Flows will be selected. sort A list of (key, direction) pairs specifying the sort order for this query. Follows pymongo conventions. @@ -2017,6 +2040,7 @@ def get_flows_info( start_date=start_date, end_date=end_date, name=name, + locked=locked, ) # Only use the full aggregation if more job details are needed. @@ -2105,7 +2129,7 @@ def delete_flow(self, flow_id: str, delete_output: bool = False): self.flows.delete_one({"uuid": flow_id}) return True - def remove_lock_job( + def unlock_jobs( self, job_ids: tuple[str, int] | list[tuple[str, int]] | None = None, db_ids: str | list[str] | None = None, @@ -2168,7 +2192,7 @@ def remove_lock_job( ) return result.modified_count - def remove_lock_flow( + def unlock_flows( self, job_ids: str | list[str] | None = None, db_ids: str | list[str] | None = None, diff --git a/src/jobflow_remote/jobs/runner.py b/src/jobflow_remote/jobs/runner.py index 045b54c2..01c96577 100644 --- a/src/jobflow_remote/jobs/runner.py +++ b/src/jobflow_remote/jobs/runner.py @@ -295,23 +295,170 @@ def run( self.update_batch_jobs ) - try: - ticks_remaining: int | bool = True + ticks_remaining: int | bool = True + if ticks is not None: + ticks_remaining = ticks + + while ticks_remaining: + if self.stop_signal: + logger.info("stopping due to sigterm") + break + scheduler.run_pending() + time.sleep(1) + if ticks is not None: - ticks_remaining = ticks + ticks_remaining -= 1 - while ticks_remaining: - if self.stop_signal: - logger.info("stopping due to sigterm") - break - scheduler.run_pending() - time.sleep(1) + def run_all_jobs( + self, + max_seconds: int | None = None, + ): + """ + Use the runner to run all the jobs in the DB. + Mainly used for testing + """ + states = [ + JobState.CHECKED_OUT.value, + JobState.TERMINATED.value, + JobState.DOWNLOADED.value, + JobState.UPLOADED.value, + ] + + scheduler = SafeScheduler(seconds_after_failure=120) + + t0 = time.time() + # run a first call for each case, since schedule will wait for the delay + # to make the first execution. + self.checkout() + scheduler.every(self.runner_options.delay_checkout).seconds.do(self.checkout) - if ticks is not None: - ticks_remaining -= 1 + self.advance_state(states) + scheduler.every(self.runner_options.delay_advance_status).seconds.do( + self.advance_state, states=states + ) - finally: - self.cleanup() + self.check_run_status() + scheduler.every(self.runner_options.delay_check_run_status).seconds.do( + self.check_run_status + ) + + # Limited workers will only affect the process interacting with the queue + # manager. When a job is submitted or terminated the count in the + # limited_workers can be directly updated, since by construction only one + # process will take care of the queue state. + # The refresh can be run on a relatively high delay since it should only + # account for actions from the user (e.g. rerun, cancel), that can alter + # the number of submitted/running jobs. + if self.limited_workers: + self.refresh_num_current_jobs() + scheduler.every(self.runner_options.delay_refresh_limited).seconds.do( + self.refresh_num_current_jobs + ) + if self.batch_workers: + self.update_batch_jobs() + scheduler.every(self.runner_options.delay_update_batch).seconds.do( + self.update_batch_jobs + ) + + running_states = [ + JobState.READY.value, + JobState.CHECKED_OUT.value, + JobState.TERMINATED.value, + JobState.DOWNLOADED.value, + JobState.UPLOADED.value, + JobState.SUBMITTED.value, + JobState.RUNNING.value, + JobState.BATCH_RUNNING.value, + JobState.BATCH_SUBMITTED.value, + ] + query = {"state": {"$in": running_states}} + jobs_available = True + while jobs_available: + scheduler.run_pending() + time.sleep(0.2) + jobs_available = self.job_controller.count_jobs(query=query) + if max_seconds and time.time() - t0 > max_seconds: + raise RuntimeError( + "Could execute all the jobs within the selected amount of time" + ) + + def run_one_job( + self, + db_id: str | None = None, + job_id: tuple[str, str] | None = None, + max_seconds: int | None = None, + raise_at_timeout: bool = True, + ) -> bool: + """ + Use the runner to run a single Job until it reaches a terminal state. + The job should be in the READY state and there should be no + Mainly used for testing + """ + + states = [ + JobState.CHECKED_OUT.value, + JobState.TERMINATED.value, + JobState.DOWNLOADED.value, + JobState.UPLOADED.value, + ] + + scheduler = SafeScheduler(seconds_after_failure=120) + + t0 = time.time() + query = {} + if db_id: + query["db_id"] = db_id + if job_id: + query["uuid"] = job_id[0] + query["index"] = job_id[1] + job_data = self.job_controller.checkout_job(query=query) + if not job_data: + if not db_id and not job_id: + return None + elif not db_id: + job_data = job_id + else: + j_info = self.job_controller.get_job_info(db_id=db_id) + job_data = [j_info.uuid, j_info.index] + + filter = {"uuid": job_data[0], "index": job_data[1]} + self.advance_state(states) + scheduler.every(self.runner_options.delay_advance_status).seconds.do( + self.advance_state, + states=states, + filter=filter, + ) + + self.check_run_status() + scheduler.every(self.runner_options.delay_check_run_status).seconds.do( + self.check_run_status, filter=filter + ) + + running_states = [ + JobState.READY.value, + JobState.CHECKED_OUT.value, + JobState.TERMINATED.value, + JobState.DOWNLOADED.value, + JobState.UPLOADED.value, + JobState.SUBMITTED.value, + JobState.RUNNING.value, + ] + + while True: + scheduler.run_pending() + time.sleep(0.2) + job_info = self.job_controller.get_job_info( + job_id=job_data[0], job_index=job_data[1] + ) + if job_info.state.value not in running_states: + return True + if max_seconds and time.time() - t0 > max_seconds: + if raise_at_timeout: + raise RuntimeError( + "Could execute the job within the selected amount of time" + ) + else: + return False def _get_limited_worker_query(self, states: list[str]) -> dict | None: """ @@ -349,7 +496,7 @@ def _get_limited_worker_query(self, states: list[str]) -> dict | None: return None - def advance_state(self, states: list[str]): + def advance_state(self, states: list[str], filter: dict | None = None): """ Acquire the lock and advance the state of a single job. @@ -373,6 +520,8 @@ def advance_state(self, states: list[str]): return else: query = {"state": {"$in": states}} + if filter: + query.update(filter) with self.job_controller.lock_job_for_update( query=query, @@ -651,7 +800,7 @@ def complete_job(self, lock): err_msg = "the parsed output does not contain the required information to complete the job" raise RemoteError(err_msg, True) - def check_run_status(self): + def check_run_status(self, filter: dict | None = None): """ Check the status of all the jobs submitted to a queue. @@ -661,12 +810,14 @@ def check_run_status(self): """ logger.debug("check_run_status") # check for jobs that could have changed state - workers_ids_docs = defaultdict(dict) + workers_ids_docs: dict = defaultdict(dict) db_filter = { "state": {"$in": [JobState.SUBMITTED.value, JobState.RUNNING.value]}, "lock_id": None, "remote.retry_time_limit": {"$not": {"$gt": datetime.utcnow()}}, } + if filter: + db_filter.update(filter) projection = [ "db_id", "uuid", diff --git a/src/jobflow_remote/testing/__init__.py b/src/jobflow_remote/testing/__init__.py index ecc98e2f..dc418f0f 100644 --- a/src/jobflow_remote/testing/__init__.py +++ b/src/jobflow_remote/testing/__init__.py @@ -62,6 +62,17 @@ def add_big_undefined_store(a: float, b: float): return Response({"data": [result] * 5_000, "result": result}) +@job +def add_sleep(a, b): + """ + Adds two numbers together and sleeps for "b" seconds + """ + import time + + time.sleep(b) + return a + b + + @job def create_detour(detour_job: Job): """ @@ -70,3 +81,26 @@ def create_detour(detour_job: Job): from jobflow import Flow return Response(detour=Flow(detour_job)) + + +@job +def self_replace(n: int): + """ + Create a replace Job with the same job n times. + """ + from jobflow import Flow + + if n > 0: + return Response(replace=self_replace(n - 1)) + + return n + + +@job +def ignore_input(a: int): + """ + Can receive an input, but ignores it. + + Allows to test flows with failed parents + """ + return 1 diff --git a/src/jobflow_remote/testing/cli.py b/src/jobflow_remote/testing/cli.py new file mode 100644 index 00000000..48d775df --- /dev/null +++ b/src/jobflow_remote/testing/cli.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import IO, Any + +from typer.testing import CliRunner, Result + +from jobflow_remote.cli.jf import app + + +def run_check_cli( + cli_args: str | Sequence[str] | None = None, + cli_input: bytes | str | IO[Any] | None = None, + cli_env: Mapping[str, str] | None = None, + catch_exceptions: bool = False, + required_out: str | Sequence[str] | None = None, + excluded_out: str | Sequence[str] | None = None, + error: bool = False, + terminal_width: int = 1000, +) -> Result: + if isinstance(required_out, str): + required_out = [required_out] + if isinstance(excluded_out, str): + excluded_out = [excluded_out] + + cli_runner = CliRunner() + + result = cli_runner.invoke( + app, + args=cli_args, + input=cli_input, + env=cli_env, + catch_exceptions=catch_exceptions, + terminal_width=terminal_width, + ) + + # note that stderr is not captured separately + assert error == ( + result.exit_code != 0 + ), f"cli should have {'not ' if not error else ''}failed. exit code: {result.exit_code}. stdout: {result.stdout}" + + if required_out: + for ro in required_out: + assert ro in result.stdout, f"{ro} missing from stdout: {result.stdout}" + + if excluded_out: + for eo in excluded_out: + assert eo not in result.stdout, f"{eo} present in stdout: {result.stdout}" + + return result diff --git a/src/jobflow_remote/utils/db.py b/src/jobflow_remote/utils/db.py index 222ad933..e09ecba6 100644 --- a/src/jobflow_remote/utils/db.py +++ b/src/jobflow_remote/utils/db.py @@ -141,11 +141,12 @@ def __init__( self.unavailable_document = None self.lock_id = lock_id or suuid() self.kwargs = kwargs - self.update_on_release: dict = {} + self.update_on_release: dict | list = {} self.sleep = sleep self.max_wait = max_wait self.projection = projection self.get_locked_doc = get_locked_doc + self.release_with_pipeline: bool = False @classmethod def get_lock_time(cls, d: dict): @@ -248,7 +249,10 @@ def release(self, exc_type, exc_val, exc_tb): update = {"$set": {self.LOCK_KEY: None, self.LOCK_TIME_KEY: None}} # TODO maybe set on release only if no exception was raised? if self.update_on_release: - update = deep_merge_dict(update, self.update_on_release) + if isinstance(self.update_on_release, list): + update = [update] + self.update_on_release + else: + update = deep_merge_dict(update, self.update_on_release) logger.debug(f"release lock with update: {update}") # TODO if failed to release the lock maybe retry before failing result = self.collection.update_one( diff --git a/tests/conftest.py b/tests/conftest.py index 4e1f51de..f6fac06b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,7 @@ +import random +import time +import warnings + import pytest @@ -61,3 +65,42 @@ def tmp_dir(): @pytest.fixture(scope="session") def debug_mode(): return False + + +def _get_random_name(length=6): + return "".join(random.choice("abcdef") for _ in range(length)) + + +@pytest.fixture(scope="session") +def random_project_name(): + return _get_random_name() + + +@pytest.fixture(scope="function") +def daemon_manager(random_project_name): + from jobflow_remote.jobs.daemon import DaemonError, DaemonManager, DaemonStatus + + dm = DaemonManager.from_project_name(random_project_name) + yield dm + # kill the processes and shut down the daemon (otherwise will remain in the STOPPED state) + dm.kill(raise_on_error=True) + time.sleep(0.5) + dm.shut_down(raise_on_error=True) + for i in range(10): + time.sleep(1) + try: + if dm.check_status() == DaemonStatus.SHUT_DOWN: + break + except DaemonError: + pass + else: + warnings.warn("daemon manager did not shut down within the expected time") + + +@pytest.fixture(scope="function") +def runner(): + from jobflow_remote.jobs.runner import Runner + + runner = Runner() + yield runner + runner.cleanup() diff --git a/tests/db/cli/test_admin.py b/tests/db/cli/test_admin.py new file mode 100644 index 00000000..9967effd --- /dev/null +++ b/tests/db/cli/test_admin.py @@ -0,0 +1,75 @@ +import pytest + + +def test_reset(job_controller, four_jobs): + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["admin", "reset", "-m", "1"], + required_out="The database was NOT reset", + cli_input="y", + ) + assert job_controller.count_jobs() == 4 + + run_check_cli(["admin", "reset", "-m", "10"], cli_input="n") + assert job_controller.count_jobs() == 4 + + run_check_cli( + ["admin", "reset", "-m", "10"], + required_out="The database was reset", + cli_input="y", + ) + assert job_controller.count_jobs() == 0 + + +def test_unlock(job_controller, one_job): + from jobflow_remote.testing.cli import run_check_cli + + j = one_job.jobs[0] + # catch the warning coming from MongoLock + with pytest.warns(UserWarning, match="Could not release lock for document"): + with job_controller.lock_job(filter={"uuid": j.uuid}): + run_check_cli( + ["admin", "unlock", "-did", "1"], + required_out="1 jobs were unlocked", + cli_input="y", + ) + + with job_controller.lock_job(filter={"uuid": j.uuid}): + run_check_cli( + ["admin", "unlock", "-did", "1"], + excluded_out="1 jobs were unlocked", + cli_input="n", + ) + + run_check_cli( + ["admin", "unlock", "-did", "10"], + required_out="No data matching the request", + error=True, + ) + + +def test_unlock_flow(job_controller, one_job): + from jobflow_remote.testing.cli import run_check_cli + + # catch the warning coming from MongoLock + with pytest.warns(UserWarning, match="Could not release lock for document"): + with job_controller.lock_flow(filter={"uuid": one_job.uuid}): + run_check_cli( + ["admin", "unlock-flow", "-fid", one_job.uuid], + required_out="1 flows were unlocked", + cli_input="y", + ) + + with job_controller.lock_flow(filter={"uuid": one_job.uuid}): + run_check_cli( + ["admin", "unlock-flow", "-fid", one_job.uuid], + excluded_out="1 flows were unlocked", + cli_input="n", + ) + + run_check_cli( + ["admin", "unlock-flow", "-fid", "xxxx"], + required_out="No data matching the request", + error=True, + ) diff --git a/tests/db/cli/test_flow.py b/tests/db/cli/test_flow.py new file mode 100644 index 00000000..58b1dcf0 --- /dev/null +++ b/tests/db/cli/test_flow.py @@ -0,0 +1,51 @@ +def test_flows_list(job_controller, four_jobs): + + from jobflow_remote.testing.cli import run_check_cli + + columns = ["DB id", "Name", "State", "Flow id", "Num Jobs", "Last updated"] + outputs = columns + [f"f{i}" for i in range(1, 3)] + ["READY"] + + run_check_cli(["flow", "list"], required_out=outputs) + + # the output table is squeezed. Hard to check the stdout. Just check that runs correctly + run_check_cli(["flow", "list", "-v"]) + + # trigger the additional information + outputs = ["The number of Flows printed is limited by the maximum selected"] + run_check_cli(["flow", "list", "-m", "1"], required_out=outputs) + + outputs = ["READY"] + run_check_cli(["flow", "list", "-fid", four_jobs[0].uuid], required_out=outputs) + + +def test_delete(job_controller, four_jobs): + + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["flow", "delete", "-fid", four_jobs[0].uuid], + required_out="Deleted Flow", + cli_input="y", + ) + assert job_controller.count_flows() == 1 + + # don't confirm + run_check_cli(["flow", "delete", "-fid", four_jobs[1].uuid], cli_input="n") + assert job_controller.count_flows() == 1 + + run_check_cli( + ["flow", "delete", "-fid", four_jobs[0].uuid], + required_out="No flows matching criteria", + ) + + +def test_flow_info(job_controller, four_jobs): + + from jobflow_remote.testing.cli import run_check_cli + + columns = ["DB id", "Name", "State", "Job id", "(Index)", "Worker"] + outputs = columns + [f"add{i}" for i in range(1, 3)] + ["READY", "WAITING"] + excluded = [f"add{i}" for i in range(3, 5)] + run_check_cli( + ["flow", "info", "-j", "1"], required_out=outputs, excluded_out=excluded + ) diff --git a/tests/db/cli/test_job.py b/tests/db/cli/test_job.py new file mode 100644 index 00000000..91717f0b --- /dev/null +++ b/tests/db/cli/test_job.py @@ -0,0 +1,262 @@ +def test_jobs_list(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + # split "job id" from "index", because it can be sent to a new line + columns = ["DB id", "Name", "State", "Job id", "(Index)", "Worker", "Last updated"] + outputs = columns + [f"add{i}" for i in range(1, 5)] + ["READY", "WAITING"] + + run_check_cli(["job", "list"], required_out=outputs) + + # the output table is squeezed. Hard to check the stdout. Just check that runs correctly + run_check_cli(["job", "list", "-v"]) + run_check_cli(["job", "list", "-vvv"]) + + outputs = ["add1", "READY"] + excluded = [f"add{i}" for i in range(2, 5)] + run_check_cli( + ["job", "list", "-did", "1"], required_out=outputs, excluded_out=excluded + ) + run_check_cli( + ["job", "list", "-q", '{"db_id": "1"}'], + required_out=outputs, + excluded_out=excluded, + ) + + # trigger the additional information + assert job_controller.set_job_state(JobState.REMOTE_ERROR, db_id="1") + run_check_cli( + ["job", "list"], required_out=["Get more information about the errors"] + ) + + +def test_job_info(job_controller, four_jobs): + + from jobflow_remote.testing.cli import run_check_cli + + outputs = ["name = 'add1'", "state = 'READY'"] + excluded_n = ["run_dir = None", "start_time = None"] + excluded = excluded_n + ["job = {"] + run_check_cli(["job", "info", "1"], required_out=outputs, excluded_out=excluded) + + outputs += excluded_n + run_check_cli( + ["job", "info", four_jobs[0].jobs[0].uuid, "-n"], required_out=outputs + ) + + outputs = ["state = 'READY'", "job = {"] + run_check_cli(["job", "info", "1", "-vvv"], required_out=outputs) + + run_check_cli( + ["job", "info", "10"], error=True, required_out="No data matching the request" + ) + + +def test_set_state(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "set-state", "UPLOADED", "1"], required_out="operation completed" + ) + assert job_controller.set_job_state(JobState.UPLOADED, db_id="1") + run_check_cli( + ["job", "set-state", "UPLOADED", "10"], + required_out="Could not change the job state", + error=True, + ) + + +def test_rerun(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + assert job_controller.set_job_state(JobState.COMPLETED, db_id="1") + + run_check_cli( + ["job", "rerun", "-did", "1", "-f"], + required_out="Operation completed: 1 jobs modified", + ) + assert job_controller.get_job_info(db_id="1").state == JobState.READY + # fails because already READY + run_check_cli(["job", "rerun", "-did", "1"], required_out="Error while rerunning") + + +def test_retry(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + assert job_controller.set_job_state(JobState.UPLOADED, db_id="1") + assert job_controller.get_job_info(db_id="1").state == JobState.UPLOADED + + run_check_cli( + ["job", "retry", "-did", "1", "-bl"], + required_out="Operation completed: 1 jobs modified", + ) + assert job_controller.get_job_info(db_id="1").state == JobState.UPLOADED + run_check_cli(["job", "retry", "-did", "2"], required_out="Error while retrying") + + +def test_play_pause(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "pause", "-did", "1"], + required_out="Operation completed: 1 jobs modified", + ) + run_check_cli(["job", "pause", "-did", "1"], required_out="Error while pausing") + assert job_controller.get_job_info(db_id="1").state == JobState.PAUSED + + run_check_cli( + ["job", "play", "-did", "1"], + required_out="Operation completed: 1 jobs modified", + ) + assert job_controller.get_job_info(db_id="1").state == JobState.READY + run_check_cli(["job", "play", "-did", "1"], required_out="Error while playing") + + +def test_stop(job_controller, four_jobs): + + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "stop", "-did", "1", "-bl"], + required_out="Operation completed: 1 jobs modified", + ) + run_check_cli(["job", "stop", "-did", "1"], required_out="Error while stopping") + assert job_controller.get_job_info(db_id="1").state == JobState.USER_STOPPED + + +def test_queue_out(job_controller, one_job): + + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "queue-out", "1"], + required_out=["The remote folder has not been created yet"], + ) + + runner = Runner() + runner.run_one_job(db_id="1") + + run_check_cli(["job", "queue-out", "1"], required_out=["Queue output from", "add"]) + + run_check_cli( + ["job", "queue-out", "10"], + required_out="No data matching the request", + error=True, + ) + + +def test_set_worker(job_controller, one_job): + + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "set", "worker", "-did", "1", "test_local_worker_2"], + required_out="Operation completed: 1 jobs modified", + ) + + assert job_controller.get_job_info(db_id="1").worker == "test_local_worker_2" + + +def test_set_exec_config(job_controller, one_job): + + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "set", "exec-config", "-did", "1", "test"], + required_out="Operation completed: 1 jobs modified", + ) + + assert job_controller.get_job_doc(db_id="1").exec_config == "test" + + +def test_set_resources(job_controller, one_job): + + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "set", "resources", "-did", "1", '{"ntasks": 1}'], + required_out="Operation completed: 1 jobs modified", + ) + + assert job_controller.get_job_doc(db_id="1").resources == {"ntasks": 1} + + +def test_job_dump(job_controller, one_job, tmp_dir): + + import os + + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli(["job", "dump", "-did", "1"]) + assert os.path.isfile("jobs_dump.json") + + +def test_output(job_controller, one_job): + + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli(["job", "output", "1"], required_out="has no output", error=True) + + runner = Runner() + runner.run_one_job(db_id="1") + + run_check_cli(["job", "output", "1"], required_out="6") + + +def test_files_list(job_controller, one_job): + + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "files", "ls", "1"], + required_out="The remote folder has not been created yet", + ) + + runner = Runner() + runner.run_one_job(db_id="1") + + run_check_cli(["job", "files", "ls", "1"], required_out=["queue.out", "queue.err"]) + + run_check_cli( + ["job", "files", "ls", "10"], + error=True, + required_out="No data matching the request", + ) + + +def test_files_get(job_controller, one_job, tmp_dir): + + import os + + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["job", "files", "get", "1", "queue.err"], + required_out="The remote folder has not been created yet", + ) + + runner = Runner() + runner.run_one_job(db_id="1") + + run_check_cli(["job", "files", "get", "1", "queue.out"]) + assert os.path.isfile("queue.out") + + run_check_cli( + ["job", "files", "get", "10", "queue.out"], + error=True, + required_out="No data matching the request", + ) diff --git a/tests/db/cli/test_project.py b/tests/db/cli/test_project.py new file mode 100644 index 00000000..0de19047 --- /dev/null +++ b/tests/db/cli/test_project.py @@ -0,0 +1,120 @@ +def test_list_projects(job_controller, random_project_name, monkeypatch, tmp_dir): + import os + + from monty.serialization import dumpfn + + from jobflow_remote import SETTINGS + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli(["project", "list"], required_out=random_project_name) + + # change project directory and test options there + with monkeypatch.context() as m: + print(tmp_dir, tmp_dir.__class__) + m.setattr(SETTINGS, "projects_folder", os.getcwd()) + run_check_cli(["project", "list"], required_out="No project available in") + + dumpfn({"name": "testtest", "xxx": 1}, "testest.yaml") + + output = [ + "The following project names exist in files in the project", + "testtest.", + ] + run_check_cli(["project", "list"], required_out=output) + + +def test_current_project(job_controller, random_project_name): + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli( + ["project"], required_out=f"The selected project is {random_project_name}" + ) + + +def test_generate(job_controller, random_project_name, monkeypatch, tmp_dir): + import os + + from jobflow_remote import SETTINGS + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli(["project", "list"], required_out=random_project_name) + + # change project directory and test options there + with monkeypatch.context() as m: + print(tmp_dir, tmp_dir.__class__) + m.setattr(SETTINGS, "projects_folder", os.getcwd()) + run_check_cli( + ["project", "generate", "test_proj_1"], + required_out="Configuration file for project test_proj_1 created in", + ) + run_check_cli( + ["project", "generate", "--full", "test_proj_2"], + required_out="Configuration file for project test_proj_2 created in", + ) + + run_check_cli( + ["project", "generate", "test_proj_1"], + required_out="Project with name test_proj_1 already exists", + error=True, + ) + + +def test_check(job_controller): + from jobflow_remote.testing.cli import run_check_cli + + output = [ + "✓ Worker test_local_worker", + "✓ Worker test_local_worker_2", + "✓ Jobstore", + "✓ Queue store", + ] + run_check_cli(["project", "check"], required_out=output) + + +def test_remove(job_controller, random_project_name, monkeypatch, tmp_dir): + import os + + from jobflow_remote import SETTINGS, ConfigManager + from jobflow_remote.testing.cli import run_check_cli + + run_check_cli(["project", "list"], required_out=random_project_name) + + # change project directory and test options there + with monkeypatch.context() as m: + print(tmp_dir, tmp_dir.__class__) + m.setattr(SETTINGS, "projects_folder", os.getcwd()) + cm = ConfigManager() + run_check_cli( + ["project", "generate", "test_proj_1"], + required_out="Configuration file for project test_proj_1 created in", + ) + cm = ConfigManager() + assert "test_proj_1" in cm.projects_data + run_check_cli( + ["project", "remove", "test_proj_1"], + required_out="This will delete also the folders", + cli_input="y", + ) + + cm = ConfigManager() + assert "test_proj_1" not in cm.projects_data + + run_check_cli( + ["project", "remove", "test_proj_1"], + required_out="Project test_proj_1 does not exist", + cli_input="y", + ) + + +def test_list_exec_config(job_controller): + from jobflow_remote.testing.cli import run_check_cli + + output = ["Name", "modules", "export", "pre_run", "post_run", "test"] + run_check_cli(["project", "exec_config", "list", "-v"], required_out=output) + + +def test_list_workers(job_controller): + from jobflow_remote.testing.cli import run_check_cli + + output = ["Name", "type", "info", "test_local_worker", "test_local_worker_2"] + run_check_cli(["project", "worker", "list", "-v"], required_out=output) diff --git a/tests/db/conftest.py b/tests/db/conftest.py new file mode 100644 index 00000000..cb1eb621 --- /dev/null +++ b/tests/db/conftest.py @@ -0,0 +1,218 @@ +import os +import random +import shutil +import tempfile +import warnings +from pathlib import Path + +import pytest + + +def _get_random_name(length=6): + return "".join(random.choice("abcdef") for _ in range(length)) + + +@pytest.fixture(scope="session") +def store_database_name(): + return "jfremote_db_tests__" + + +@pytest.fixture(scope="session") +def mongoclient(): + """ + Generate a MongoClient for a local database. + If a local DB is already available is that one (should be the one used in + the CI or by developers with an accessible local DB). Otherwise, generate + one with pymongo_inmemory, that should be installed. + """ + + import pymongo + + mc = pymongo.MongoClient(host="localhost", port=27017) + # try connecting to the DB with a short delay, since the DB is local it + # should not take long to reply + try: + with pymongo.timeout(1): + mc.server_info() + yield mc + except Exception as e: + warnings.warn( + f"Could not connect to a local DB {getattr(e, 'message', str(e))}. Trying with pymongo_inmemory" + ) + + try: + import pymongo_inmemory + except ImportError: + raise pytest.skip( + "No local DB and pymongo_inmemory. Either start a local mongodb or install pymongo_inmemory" + ) + + mc = pymongo_inmemory.MongoClient() + assert mc.server_info() + + yield mc + # stop the db started by pymongo_inmemory + mc.close() + + +@pytest.fixture(scope="session") +def mongo_jobstore(store_database_name): + from jobflow import JobStore + from maggma.stores import MongoStore + + store = JobStore(MongoStore(store_database_name, "outputs")) + store.connect() + return store + + +@pytest.fixture(scope="session", autouse=True) +def write_tmp_settings( + random_project_name, + store_database_name, + mongoclient, +): + """Collects the various sub-configs and writes them to a temporary file in a temporary directory.""" + + tmp_dir: Path = Path(tempfile.mkdtemp()) + + os.environ["JFREMOTE_PROJECTS_FOLDER"] = str(tmp_dir.resolve()) + workdir = tmp_dir / "jfr" + workdir.mkdir(exist_ok=True) + os.environ["JFREMOTE_PROJECT"] = random_project_name + # Set the config file to a random path so that we don't accidentally load the default + os.environ["JFREMOTE_CONFIG_FILE"] = _get_random_name(length=10) + ".json" + # This import must come after setting the env vars as jobflow loads the default config + # on import + from jobflow_remote.config import Project + + project = Project( + name=random_project_name, + jobstore={ + "docs_store": { + "type": "MongoStore", + "database": store_database_name, + "host": mongoclient.HOST, + "port": mongoclient.PORT, + "collection_name": "docs", + }, + "additional_stores": { + "big_data": { + "type": "GridFSStore", + "database": store_database_name, + "host": mongoclient.HOST, + "port": mongoclient.PORT, + "collection_name": "data", + }, + }, + }, + queue={ + "store": { + "type": "MongoStore", + "database": store_database_name, + "host": mongoclient.HOST, + "port": mongoclient.PORT, + "collection_name": "jobs", + }, + }, + log_level="debug", + workers={ + "test_local_worker": dict( + type="local", + scheduler_type="shell", + work_dir=str(workdir), + resources={}, + ), + "test_local_worker_2": dict( + type="local", + scheduler_type="shell", + work_dir=str(workdir), + resources={}, + ), + }, + exec_config={"test": {"export": {"TESTING_ENV_VAR": random_project_name}}}, + runner=dict( + delay_checkout=1, + delay_check_run_status=1, + delay_advance_status=1, + max_step_attempts=3, + delta_retry=(1, 1, 1), + ), + ) + project_json = project.model_dump_json(indent=2) + with open(tmp_dir / f"{random_project_name}.json", "w") as f: + f.write(project_json) + + # In some cases it seems that the SETTINGS have already been imported + # and thus not taking the new configurations into account. + # Regenerate the JobflowRemoteSettings after setting paths and project + import jobflow_remote + from jobflow_remote.config.settings import JobflowRemoteSettings + + jobflow_remote.SETTINGS = JobflowRemoteSettings() + + yield + shutil.rmtree(tmp_dir) + + +@pytest.fixture(scope="function") +def job_controller(random_project_name): + """Yields a jobcontroller instance for the test suite that also sets up the jobstore, + resetting it after every test. + """ + from jobflow_remote.jobs.jobcontroller import JobController + + jc = JobController.from_project_name(random_project_name) + assert jc.reset() + yield jc + + +@pytest.fixture(scope="function") +def one_job(random_project_name): + """ + Add one flow with one job to the DB + """ + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.testing import add + + j = add(1, 5) + flow = Flow([j]) + submit_flow(flow, worker="test_local_worker") + + yield flow + + +@pytest.fixture(scope="function") +def four_jobs(random_project_name): + """ + Add two flows with two jobs each to the DB + """ + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.testing import add + + add_first = add(1, 5) + add_first.name = "add1" + add_second = add(add_first.output, 5) + add_second.name = "add2" + + add_first.update_metadata({"test_meta": 1}) + + flow = Flow([add_first, add_second]) + flow.name = "f1" + submit_flow(flow, worker="test_local_worker") + + add_third = add(1, 5) + add_third.name = "add3" + add_fourth = add(add_third.output, 5) + add_fourth.name = "add4" + + flow2 = Flow([add_third, add_fourth]) + flow2.name = "f2" + submit_flow(flow2, worker="test_local_worker") + + flows = [flow, flow2] + + yield flows diff --git a/tests/db/jobs/test_daemon.py b/tests/db/jobs/test_daemon.py new file mode 100644 index 00000000..3d838fa7 --- /dev/null +++ b/tests/db/jobs/test_daemon.py @@ -0,0 +1,138 @@ +import os +import time + +import pytest + +# pytestmark = pytest.mark.skipif( +# not os.environ.get("CI"), +# reason="Only run integration tests in CI, unless forced with 'CI' env var", +# ) + + +def _wait_daemon_started(daemon_manager, max_wait: int = 10) -> bool: + from jobflow_remote.jobs.daemon import DaemonStatus + + for i in range(max_wait): + time.sleep(1) + state = daemon_manager.check_status() + assert state in (DaemonStatus.STARTING, DaemonStatus.RUNNING) + if state == DaemonStatus.RUNNING: + return True + raise RuntimeError( + f"The daemon did not start running within the expected time ({max_wait})" + ) + + +@pytest.mark.parametrize( + "single", + [True, False], +) +def test_start_stop(job_controller, single, daemon_manager): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.daemon import DaemonStatus + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add + + j = add(1, 5) + + flow = Flow([j]) + submit_flow(flow, worker="test_local_worker") + + assert job_controller.count_jobs(state=JobState.READY) == 1 + + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + assert daemon_manager.start(raise_on_error=True, single=single) + _wait_daemon_started(daemon_manager) + + finished_states = (JobState.REMOTE_ERROR, JobState.FAILED, JobState.COMPLETED) + + for i in range(20): + time.sleep(1) + jobs_info = job_controller.get_jobs_info() + if all(ji.state in finished_states for ji in jobs_info): + break + + assert job_controller.count_jobs(state=JobState.COMPLETED) == 1 + + processes_info = daemon_manager.get_processes_info() + expected_nprocs = 2 if single else 5 + assert len(processes_info) == expected_nprocs + + assert daemon_manager.stop(raise_on_error=True, wait=True) + assert daemon_manager.check_status() == DaemonStatus.STOPPED + + assert daemon_manager.start(raise_on_error=True, single=True) + time.sleep(0.5) + assert daemon_manager.shut_down(raise_on_error=True) + time.sleep(1) + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + + processes_info = daemon_manager.get_processes_info() + assert processes_info is None + + +def test_kill(job_controller, daemon_manager): + from jobflow_remote.jobs.daemon import DaemonStatus + + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + + # killing when shut down should not have an effect + assert daemon_manager.kill(raise_on_error=True) + + assert daemon_manager.start(raise_on_error=True, single=True) + _wait_daemon_started(daemon_manager) + + print(daemon_manager.get_processes_info()) + assert daemon_manager.kill(raise_on_error=True) + time.sleep(1) + print(daemon_manager.get_processes_info()) + assert daemon_manager.check_status() == DaemonStatus.STOPPED + + +def test_kill_supervisord(job_controller, daemon_manager, caplog): + import signal + import time + + from jobflow_remote.jobs.daemon import DaemonStatus + + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + assert daemon_manager.start(raise_on_error=True, single=True) + _wait_daemon_started(daemon_manager) + + processes_info = daemon_manager.get_processes_info() + supervisord_pid = processes_info["supervisord"]["pid"] + + # directly kill the supervisord process + os.kill(supervisord_pid, signal.SIGKILL) + # also kill all the processes + for process_dict in processes_info.values(): + os.kill(process_dict["pid"], signal.SIGKILL) + time.sleep(2) + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + # check that the warning message is present among the logged messages + # TODO if run alone the check below passes. If run among all the others + # the log message is not present. + # log_msg = caplog.messages + # assert len(log_msg) > 0 + # assert f"Process with pid {supervisord_pid} is not running but daemon files are present" in log_msg[-1] + + +def test_kill_one_process(job_controller, daemon_manager, caplog): + import signal + import time + + from jobflow_remote.jobs.daemon import DaemonStatus + + assert daemon_manager.check_status() == DaemonStatus.SHUT_DOWN + assert daemon_manager.start(raise_on_error=True, single=False) + _wait_daemon_started(daemon_manager) + + processes_info = daemon_manager.get_processes_info() + run_jobflow_queue_pid = processes_info["run_jobflow_queue"]["pid"] + + # directly kill the supervisord process + os.kill(run_jobflow_queue_pid, signal.SIGKILL) + time.sleep(1) + assert daemon_manager.check_status() == DaemonStatus.PARTIALLY_RUNNING diff --git a/tests/db/jobs/test_jobcontroller.py b/tests/db/jobs/test_jobcontroller.py new file mode 100644 index 00000000..d35907ac --- /dev/null +++ b/tests/db/jobs/test_jobcontroller.py @@ -0,0 +1,503 @@ +import pytest + + +def test_submit_flow(job_controller, runner): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import add + + add_first = add(1, 5) + add_second = add(add_first.output, 5) + + flow = Flow([add_first, add_second]) + submit_flow(flow, worker="test_local_worker") + + runner.run_all_jobs(max_seconds=10) + + assert len(job_controller.get_jobs({})) == 2 + job_1, job_2 = job_controller.get_jobs({}) + assert job_1["job"]["function_args"] == [1, 5] + assert job_1["job"]["name"] == "add" + + output_1 = job_controller.jobstore.get_output(uuid=job_1["uuid"]) + assert output_1 == 6 + output_2 = job_controller.jobstore.get_output(uuid=job_2["uuid"]) + assert output_2 == 11 + assert ( + job_controller.count_jobs(state=JobState.COMPLETED) == 2 + ), f"Jobs not marked as completed, full job info:\n{job_controller.get_jobs({})}" + assert ( + job_controller.count_flows(state=FlowState.COMPLETED) == 1 + ), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}" + + +def test_queries(job_controller, runner): + """Test different options to query Jobs and Flows""" + import datetime + + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import add + + add_first = add(1, 5) + add_first.name = "add1" + add_second = add(add_first.output, 5) + add_second.name = "add2" + + add_first.update_metadata({"test_meta": 1}) + + flow = Flow([add_first, add_second]) + flow.name = "f1" + submit_flow(flow, worker="test_local_worker") + + add_third = add(1, 5) + add_third.name = "add3" + add_fourth = add(add_third.output, 5) + add_fourth.name = "add4" + + flow2 = Flow([add_third, add_fourth]) + flow2.name = "f2" + submit_flow(flow2, worker="test_local_worker") + + date_create = datetime.datetime.now() + + assert runner.run_one_job(max_seconds=10) + + assert job_controller.count_jobs(state=JobState.COMPLETED) == 1 + assert job_controller.count_jobs(state=JobState.READY) == 2 + assert job_controller.count_jobs(db_ids="1") == 1 + assert job_controller.count_jobs(job_ids=(add_first.uuid, 1)) == 1 + jobs_start_date = job_controller.get_jobs_info(start_date=date_create) + assert len(jobs_start_date) == 1 + assert jobs_start_date[0].uuid == add_first.uuid + + jobs_end_date = job_controller.get_jobs_info(end_date=date_create) + assert jobs_end_date[0].uuid == add_second.uuid + + assert job_controller.count_jobs(metadata={"test_meta": 1}) == 1 + + assert job_controller.count_jobs(name="add") == 0 + assert job_controller.count_jobs(name="add1") == 1 + assert job_controller.count_jobs(name="add*") == 4 + + assert job_controller.count_jobs(flow_ids=flow.uuid) == 2 + + with job_controller.lock_job(filter={"uuid": add_second.uuid}): + assert job_controller.count_jobs(locked=True) == 1 + + assert ( + job_controller.count_jobs( + query={"uuid": {"$in": (add_first.uuid, add_second.uuid)}} + ) + == 2 + ) + + assert job_controller.count_flows(state=FlowState.READY) == 1 + assert job_controller.count_flows(state=FlowState.RUNNING) == 1 + assert job_controller.count_flows(job_ids=add_first.uuid) == 1 + assert job_controller.count_flows(db_ids="1") == 1 + assert job_controller.count_flows(flow_ids=flow.uuid) == 1 + assert job_controller.count_flows(start_date=date_create) == 1 + assert job_controller.count_flows(end_date=date_create) == 1 + assert job_controller.count_flows(name="f1") == 1 + assert job_controller.count_flows(name="f*") == 2 + assert ( + job_controller.count_flows(query={"uuid": {"$in": (flow.uuid, flow2.uuid)}}) + == 2 + ) + + +def test_rerun_completed(job_controller, runner): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add + from jobflow_remote.utils.db import JobLockedError + + j1 = add(1, 2) + j2 = add(j1.output, 2) + j3 = add(j2.output, 2) + flow = Flow([j1, j2, j3]) + + submit_flow(flow, worker="test_local_worker") + + assert runner.run_one_job(max_seconds=10, job_id=[j1.uuid, j1.index]) + j1_info = job_controller.get_job_info(job_id=j1.uuid, job_index=j1.index) + j2_info = job_controller.get_job_info(job_id=j2.uuid, job_index=j2.index) + j3_info = job_controller.get_job_info(job_id=j3.uuid, job_index=j3.index) + assert j1_info.state == JobState.COMPLETED + assert j2_info.state == JobState.READY + assert j3_info.state == JobState.WAITING + + # try rerunning the second job. Wrong state + with pytest.raises(ValueError, match="The Job is in the READY state"): + job_controller.rerun_job(job_id=j2.uuid, job_index=j2.index) + + # since the first job is completed, the force option is required + with pytest.raises(ValueError, match="Job in state COMPLETED cannot be rerun"): + job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index) + + # two jobs are modified, as j2 is switched to WAITING + # Use rerun_jobs instead of rerun_job to test that as well + assert set(job_controller.rerun_jobs(job_ids=(j1.uuid, j1.index), force=True)) == { + j1_info.db_id, + j2_info.db_id, + } + assert ( + job_controller.get_job_info(job_id=j1.uuid, job_index=j1.index).state + == JobState.READY + ) + assert ( + job_controller.get_job_info(job_id=j2.uuid, job_index=j2.index).state + == JobState.WAITING + ) + + # run all the jobs + runner.run_all_jobs(max_seconds=20) + assert job_controller.count_jobs(state=JobState.COMPLETED) == 3 + with pytest.raises(ValueError, match="Job in state COMPLETED cannot be rerun"): + job_controller.rerun_job(job_id=j3.uuid, job_index=j3.index) + + # The last job can be rerun, but still needs the "force" option + assert set( + job_controller.rerun_job(job_id=j3.uuid, job_index=j3.index, force=True) + ) == {j3_info.db_id} + + # The remaining tests are to verify that everything is correct with locked jobs as well + with job_controller.lock_job(filter={"uuid": j2.uuid}): + with pytest.raises( + JobLockedError, + match=f"Job with db_id {j2_info.db_id} is locked with lock_id", + ): + job_controller.rerun_job(job_id=j2.uuid, job_index=j2.index) + + # try waiting, but fails in this case as well + with pytest.raises( + JobLockedError, + match=f"Job with db_id {j2_info.db_id} is locked with lock_id", + ): + job_controller.rerun_job(job_id=j2.uuid, job_index=j2.index, wait=1) + + # The rerun fails even if a child is locked + with job_controller.lock_job(filter={"uuid": j3.uuid}): + with pytest.raises( + JobLockedError, + match=f"Job with db_id {j3_info.db_id} is locked with lock_id", + ): + job_controller.rerun_job(job_id=j2.uuid, job_index=j2.index, force=True) + + assert ( + job_controller.get_job_info(job_id=j2.uuid, job_index=j2.index).state + == JobState.COMPLETED + ) + + # can rerun if breaking the lock + # catch the warning coming from MongoLock + with pytest.warns(UserWarning, match="Could not release lock for document"): + with job_controller.lock_job(filter={"uuid": j2.uuid}): + assert set( + job_controller.rerun_job( + job_id=j2.uuid, job_index=j2.index, force=True, break_lock=True + ) + ) == {j2_info.db_id, j3_info.db_id} + + assert ( + job_controller.get_job_info(job_id=j2.uuid, job_index=j2.index).state + == JobState.READY + ) + + +def test_rerun_failed(job_controller, runner): + from jobflow import Flow, OnMissing + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add, always_fails, ignore_input, self_replace + + j1 = always_fails() + j2 = add(j1.output, 2) + j3 = ignore_input(j1.output) + j3.config.on_missing_references = OnMissing.NONE + j4 = self_replace(j3.output) + flow = Flow([j1, j2, j3, j4]) + + submit_flow(flow, worker="test_local_worker") + + assert runner.run_one_job(max_seconds=10, job_id=[j1.uuid, j1.index]) + + j1_info = job_controller.get_job_info(job_id=j1.uuid, job_index=j1.index) + j2_info = job_controller.get_job_info(job_id=j2.uuid, job_index=j2.index) + j3_info = job_controller.get_job_info(job_id=j3.uuid, job_index=j3.index) + j4_info = job_controller.get_job_info(job_id=j4.uuid, job_index=j4.index) + + assert j1_info.state == JobState.FAILED + assert j2_info.state == JobState.WAITING + assert j3_info.state == JobState.READY + assert j4_info.state == JobState.WAITING + + # rerun without "force". Since the job is FAILED and the children are + # WAITING or READY + assert set(job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index)) == { + j1_info.db_id, + j3_info.db_id, + } + + assert job_controller.count_jobs(state=JobState.READY) == 1 + + # run the first job again and the job with OnMissing.None + assert runner.run_one_job(max_seconds=10, job_id=[j1.uuid, j1.index]) + assert runner.run_one_job(max_seconds=10, job_id=[j3.uuid, j3.index]) + + # cannot rerun the first as one of the children is COMPLETED + with pytest.raises( + ValueError, + match="The child of Job.*has state COMPLETED which is not acceptable", + ): + job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index) + + # can be rerun with the "force" option + assert set( + job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index, force=True) + ) == {j1_info.db_id, j3_info.db_id, j4_info.db_id} + + assert job_controller.count_jobs(state=JobState.READY) == 1 + + # run again the jobs with j4. This generates a replace + assert runner.run_one_job(max_seconds=10, job_id=[j1.uuid, j1.index]) + assert runner.run_one_job(max_seconds=10, job_id=[j3.uuid, j3.index]) + assert runner.run_one_job(max_seconds=10, job_id=[j4.uuid, j4.index]) + + assert job_controller.count_jobs(job_ids=(j4.uuid, 2)) == 1 + + # At this point it is impossible to rerun, even with the "force" option + # because it will require rerunning j4, which is COMPLETED with a replacement + # already existing in the DB. + with pytest.raises( + ValueError, + match="The child of Job.*has state COMPLETED which is not acceptable", + ): + job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index) + + with pytest.raises( + ValueError, match="Job.*has a child job.*which is not the last index" + ): + job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index, force=True) + + assert job_controller.get_job_info(db_id=j1_info.db_id).state == JobState.FAILED + assert ( + job_controller.get_job_info(job_id=j4.uuid, job_index=2).state == JobState.READY + ) + + +def test_rerun_remote_error(job_controller, monkeypatch, runner): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add + + j1 = add(1, 2) + j2 = add(j1.output, 2) + flow = Flow([j1, j2]) + + submit_flow(flow, worker="test_local_worker") + + # patch the upload method of the runner to trigger a remote error + def upload_error(self, lock): + raise RuntimeError("FAKE ERROR") + + with monkeypatch.context() as m: + m.setattr(Runner, "upload", upload_error) + # patch this to 1 to avoid retrying multiple times + m.setattr(runner.runner_options, "max_step_attempts", 1) + with pytest.warns(match="FAKE ERROR"): + assert runner.run_one_job(max_seconds=10, job_id=(j1.uuid, j1.index)) + + j1_info = job_controller.get_job_info(job_id=j1.uuid, job_index=j1.index) + assert j1_info.state == JobState.REMOTE_ERROR + + # can rerun without "force" + assert job_controller.rerun_job(job_id=j1.uuid, job_index=j1.index, force=True) == [ + j1_info.db_id + ] + assert ( + job_controller.get_job_info(job_id=j1.uuid, job_index=j1.index).state + == JobState.READY + ) + + +def test_retry(job_controller, monkeypatch, runner): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add + + j = add(1, 5) + flow = Flow([j]) + submit_flow(flow, worker="test_local_worker") + + # cannot retry a READY job + with pytest.raises(ValueError, match="Job in state READY cannot be retried"): + job_controller.retry_job(job_id=j.uuid, job_index=j.index) + + # patch the upload method of the runner to trigger a remote error + def upload_error(self, lock): + raise RuntimeError("FAKE ERROR") + + # Run to get the Job the REMOTE_ERROR state + with monkeypatch.context() as m: + m.setattr(Runner, "upload", upload_error) + # patch this to 1 to avoid retrying multiple times + m.setattr(runner.runner_options, "max_step_attempts", 1) + with pytest.warns(match="FAKE ERROR"): + assert runner.run_one_job(max_seconds=10, job_id=(j.uuid, j.index)) + + j_info = job_controller.get_job_info(job_id=j.uuid, job_index=j.index) + assert j_info.state == JobState.REMOTE_ERROR + assert j_info.remote.retry_time_limit is not None + + assert job_controller.retry_job(job_id=j.uuid, job_index=j.index) == [j_info.db_id] + + j_info = job_controller.get_job_info(job_id=j.uuid, job_index=j.index) + assert j_info.state == JobState.CHECKED_OUT + assert j_info.remote.retry_time_limit is None + + # Run to fail only once + with monkeypatch.context() as m: + m.setattr(Runner, "upload", upload_error) + # patch to make the runner fail only once + m.setattr(runner.runner_options, "delta_retry", (30, 300, 1200)) + with pytest.warns(match="FAKE ERROR"): + assert not runner.run_one_job( + max_seconds=2, job_id=(j.uuid, j.index), raise_at_timeout=False + ) + + j_info = job_controller.get_job_info(job_id=j.uuid, job_index=j.index) + assert j_info.state == JobState.CHECKED_OUT + assert j_info.remote.retry_time_limit is not None + + +def test_pause_play(job_controller): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import add + + j = add(1, 5) + flow = Flow([j]) + submit_flow(flow, worker="test_local_worker") + + assert job_controller.pause_jobs(job_ids=(j.uuid, 1)) == ["1"] + assert job_controller.get_job_info(job_id=j.uuid).state == JobState.PAUSED + assert job_controller.get_flows_info(job_ids=j.uuid)[0].state == FlowState.PAUSED + + assert job_controller.play_jobs(job_ids=(j.uuid, 1)) == ["1"] + assert job_controller.get_job_info(job_id=j.uuid).state == JobState.READY + assert job_controller.get_flows_info(job_ids=j.uuid)[0].state == FlowState.READY + + +def test_stop(job_controller, one_job): + from jobflow_remote.jobs.state import FlowState, JobState + + j = one_job.jobs[0] + assert job_controller.stop_jobs(job_ids=(j.uuid, 1)) == ["1"] + assert job_controller.get_job_info(job_id=j.uuid).state == JobState.USER_STOPPED + assert job_controller.get_flows_info(job_ids=j.uuid)[0].state == FlowState.STOPPED + + +def test_unlock_jobs(job_controller, one_job): + j = one_job.jobs[0] + # catch the warning coming from MongoLock + with pytest.warns(UserWarning, match="Could not release lock for document"): + with job_controller.lock_job(filter={"uuid": j.uuid}): + assert job_controller.unlock_jobs(job_ids=(j.uuid, 1)) == 1 + + +def test_unlock_flows(job_controller, one_job): + # catch the warning coming from MongoLock + with pytest.warns(UserWarning, match="Could not release lock for document"): + with job_controller.lock_flow(filter={"uuid": one_job.uuid}): + assert job_controller.unlock_flows(flow_ids=one_job.uuid) == 1 + + +def test_set_job_run_properties(job_controller, one_job): + from qtoolkit import QResources + from qtoolkit.core.data_objects import ProcessPlacement + + from jobflow_remote.config.base import ExecutionConfig + + # test setting worker + with pytest.raises( + ValueError, match="worker missing_worker is not present in the project" + ): + job_controller.set_job_run_properties(worker="missing_worker") + + assert job_controller.set_job_run_properties(worker="test_local_worker_2") + assert ( + job_controller.get_job_info(job_id=one_job[0].uuid).worker + == "test_local_worker_2" + ) + + # test setting exec config + with pytest.raises( + ValueError, + match="exec_config missing_exec_config is not present in the project", + ): + job_controller.set_job_run_properties(exec_config="missing_exec_config") + + assert job_controller.set_job_run_properties(exec_config="test") + assert job_controller.get_job_doc(job_id=one_job[0].uuid).exec_config == "test" + + ec1 = ExecutionConfig(modules=["some_module"]) + ec2 = ExecutionConfig(pre_run="command") + ec3 = ExecutionConfig(modules=["some_module"], pre_run="command") + assert job_controller.set_job_run_properties(exec_config=ec1) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).exec_config == ec1 + + assert job_controller.set_job_run_properties(exec_config=ec2, update=False) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).exec_config == ec2 + + assert job_controller.set_job_run_properties( + exec_config={"modules": ["some_module"]}, update=True + ) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).exec_config == ec3 + + # test setting resources + qr = QResources( + queue_name="test", process_placement=ProcessPlacement.NO_CONSTRAINTS + ) + + assert job_controller.set_job_run_properties(resources={"ntasks": 10}) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).resources == { + "ntasks": 10 + } + assert job_controller.set_job_run_properties(resources={"nodes": 10}) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).resources == { + "ntasks": 10, + "nodes": 10, + } + assert job_controller.set_job_run_properties(resources={"ntasks": 20}, update=False) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).resources == { + "ntasks": 20 + } + + assert job_controller.set_job_run_properties(resources=qr) + assert job_controller.get_job_doc(job_id=one_job[0].uuid).resources == qr + + +def test_reset(job_controller, four_jobs): + assert job_controller.count_jobs() == 4 + + assert not job_controller.reset(max_limit=1) + assert job_controller.reset(max_limit=10, reset_output=True) + + assert job_controller.count_jobs() == 0 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 112d7589..53b5e5cf 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -22,7 +22,7 @@ def mock_fabric_run(monkeypatch): ) -def _get_free_port(upper_bound=50_000): +def _get_free_port(upper_bound=90_000): """Returns a random free port, with an upper bound. The upper bound is required as Docker does not have @@ -156,11 +156,6 @@ def mongo_container(docker_client, db_port): ) -@pytest.fixture(scope="session") -def random_project_name(): - return _get_random_name() - - @pytest.fixture(scope="session") def store_database_name(): return _get_random_name() @@ -237,6 +232,31 @@ def write_tmp_settings( resources={"partition": "debug", "ntasks": 1, "time": "00:01:00"}, connect_kwargs={"allow_agent": False, "look_for_keys": False}, ), + "test_batch_remote_worker": dict( + type="remote", + host="localhost", + port=slurm_ssh_port, + scheduler_type="slurm", + work_dir="/home/jobflow/jfr", + user="jobflow", + password="jobflow", + pre_run="source /home/jobflow/.venv/bin/activate", + resources={"partition": "debug", "ntasks": 1, "time": "00:01:00"}, + connect_kwargs={"allow_agent": False, "look_for_keys": False}, + batch={ + "jobs_handle_dir": "/home/jobflow/jfr/batch_handle", + "work_dir": "/home/jobflow/jfr/batch_work", + "max_wait": 10, + }, + max_jobs=1, + ), + "test_max_jobs_worker": dict( + type="local", + scheduler_type="shell", + work_dir=str(workdir), + resources={}, + max_jobs=2, + ), }, exec_config={"test": {"export": {"TESTING_ENV_VAR": random_project_name}}}, runner=dict( @@ -265,16 +285,3 @@ def job_controller(random_project_name): jc = JobController.from_project_name(random_project_name) assert jc.reset() yield jc - - -@pytest.fixture(scope="session") -def daemon_manager(random_project_name): - from jobflow_remote.jobs.daemon import DaemonManager - - yield DaemonManager.from_project_name(random_project_name) - - -@pytest.fixture(scope="session") -def runner(daemon_manager): - yield daemon_manager.start(raise_on_error=True) - daemon_manager.stop(raise_on_error=True) diff --git a/tests/integration/test_advanced_options.py b/tests/integration/test_advanced_options.py new file mode 100644 index 00000000..741e15fb --- /dev/null +++ b/tests/integration/test_advanced_options.py @@ -0,0 +1,76 @@ +import os + +import pytest + +pytestmark = pytest.mark.skipif( + not os.environ.get("CI"), + reason="Only run integration tests in CI, unless forced with 'CI' env var", +) + + +def test_run_batch(job_controller, monkeypatch): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add_sleep + + job_ids = [] + for i in range(3): + add_first = add_sleep(2, 1) + add_second = add_sleep(add_first.output, 1) + + flow = Flow([add_first, add_second]) + submit_flow(flow, worker="test_batch_remote_worker") + job_ids.append([add_first.uuid, add_second.uuid]) + + runner = Runner() + + # set this so it will be called + monkeypatch.setattr(runner.runner_options, "delay_update_batch", 5) + + runner.run_all_jobs(max_seconds=120) + + assert job_controller.count_jobs(state=JobState.COMPLETED) == 6 + + +def test_max_jobs_worker(job_controller, daemon_manager): + import time + + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add_sleep + + # run the daemon in background to check what happens to the + # jobs during the execution + daemon_manager.start(raise_on_error=True) + + job_ids = [] + for i in range(4): + j = add_sleep(2, 5) + job_ids.append((j.uuid, 1)) + flow = Flow([j]) + submit_flow(flow, worker="test_max_jobs_worker") + + finished_states = (JobState.REMOTE_ERROR, JobState.FAILED, JobState.COMPLETED) + running_states = (JobState.RUNNING, JobState.SUBMITTED) + + max_running_jobs = 0 + for i in range(20): + time.sleep(1) + jobs_info = job_controller.get_jobs_info(job_ids=job_ids) + if all(ji.state in finished_states for ji in jobs_info): + break + current_running = sum(ji.state in running_states for ji in jobs_info) + max_running_jobs = max(max_running_jobs, current_running) + + jobs_info = job_controller.get_jobs_info(job_ids=job_ids) + assert all(ji.state == JobState.COMPLETED for ji in jobs_info) + + # the max running jobs should be two, meaning that it was reached and cannot + # be larger. The check could be <= 2, but if it does not reach two it will + # not be testing some parts of the code and the test is not complete. + assert max_running_jobs == 2 diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index d396af04..1a2803da 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -15,7 +15,7 @@ def test_project_init(random_project_name): assert len(cm.projects) == 1 assert cm.projects[random_project_name] project = cm.get_project() - assert len(project.workers) == 2 + assert len(project.workers) == 4 def test_paramiko_ssh_connection(job_controller, slurm_ssh_port):