diff --git a/doc/source/user/projectconf.rst b/doc/source/user/projectconf.rst index f9017fbe..9f23cb16 100644 --- a/doc/source/user/projectconf.rst +++ b/doc/source/user/projectconf.rst @@ -103,6 +103,14 @@ configurations. Those can better be set with the :ref:`projectconf execconfig`. If a single worker is defined it will be used as default in the submission of new Flows. +.. warning:: + + By default, jobflow-remote fetches the status of the jobs from the scheduler by + passing the list of ids. If the selected scheduler does not support this option + (e.g. SGE), it is also necessary to specify the username on the worker machine + through the ``scheduler_username`` option. Jobflow-remote will use that as a filter, + instead of the list of ids. + .. _projectconf jobstore: JobStore diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index 2f6a136e..f228621d 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -177,6 +177,12 @@ class WorkerBase(BaseModel): None, description="Options for batch execution. If define the worker will be considered a batch worker", ) + scheduler_username: Optional[str] = Field( + None, + description="If defined, the list of jobs running on the worker will be fetched based on the " + "username instead of the list of job ids. May be necessary for some " + "scheduler_type (e.g. 
SGE)", + ) model_config = ConfigDict(extra="forbid") @field_validator("scheduler_type") diff --git a/src/jobflow_remote/jobs/runner.py b/src/jobflow_remote/jobs/runner.py index 2855ba86..42cfe956 100644 --- a/src/jobflow_remote/jobs/runner.py +++ b/src/jobflow_remote/jobs/runner.py @@ -844,12 +844,18 @@ def check_run_status(self, filter: dict | None = None) -> None: # noqa: A002 if not ids_docs: continue + worker = self.get_worker(worker_name) + qjobs_dict = {} try: ids_list = list(ids_docs) queue = self.get_queue_manager(worker_name) - qjobs = queue.get_jobs_list(ids_list) - qjobs_dict = {qjob.job_id: qjob for qjob in qjobs} + qjobs = queue.get_jobs_list( + jobs=ids_list, user=worker.scheduler_username + ) + qjobs_dict = { + qjob.job_id: qjob for qjob in qjobs if qjob.job_id in ids_list + } except Exception: logger.warning( f"error trying to get jobs list for worker: {worker_name}", @@ -874,7 +880,6 @@ def check_run_status(self, filter: dict | None = None) -> None: # noqa: A002 f"remote job with id {remote_doc['process_id']} is running" ) elif qstate in [None, QState.DONE, QState.FAILED]: - worker = self.get_worker(worker_name) # if the worker is local go directly to DOWNLOADED, as files # are not copied locally if not worker.is_local: @@ -995,8 +1000,12 @@ def update_batch_jobs(self) -> None: processes = list(batch_processes_data) queue_manager = self.get_queue_manager(worker_name) if processes: - qjobs = queue_manager.get_jobs_list(processes) - running_processes = {qjob.job_id for qjob in qjobs} + qjobs = queue_manager.get_jobs_list( + jobs=processes, user=worker.scheduler_username + ) + running_processes = { + qjob.job_id for qjob in qjobs if qjob.job_id in processes + } stopped_processes = set(processes) - running_processes for pid in stopped_processes: self.job_controller.remove_batch_process(pid, worker_name) diff --git a/src/jobflow_remote/remote/queue.py b/src/jobflow_remote/remote/queue.py index 9ce7645b..ec3e8ba3 100644 --- 
a/src/jobflow_remote/remote/queue.py +++ b/src/jobflow_remote/remote/queue.py @@ -220,6 +220,10 @@ def get_jobs_list( user: str | None = None, timeout: int | None = None, ) -> list[QJob]: + # in order to avoid issues with schedulers that do not support query by + # list of job ids, if the user is passed ignore the job ids. + if user is not None: + jobs = None job_cmd = self.scheduler_io.get_jobs_list_cmd(jobs, user) stdout, stderr, returncode = self.execute_cmd(job_cmd, timeout=timeout) return self.scheduler_io.parse_jobs_list_output( diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index f7e2ce34..2d242fa4 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -305,3 +305,60 @@ def test_undefined_additional_stores(worker, job_controller) -> None: job_controller.count_jobs(states=[JobState.COMPLETED, JobState.REMOTE_ERROR]) == 2 ) + + +def test_submit_flow_with_scheduler_username(monkeypatch, job_controller) -> None: + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.remote.queue import QueueManager + from jobflow_remote.testing import add + + remote_worker_name = "test_remote_worker" + + job = add(1, 1) + flow = Flow([job]) + submit_flow(flow, worker=remote_worker_name) + + # modify the runner so that uses a patched version of the worker + # where the scheduler_username is set + runner = Runner() + patched_worker = runner.get_worker(remote_worker_name).model_copy() + patched_worker.scheduler_username = "jobflow" + + def patched_get_worker(self, worker_name): + if worker_name != remote_worker_name: + return runner.workers[worker_name] + return patched_worker + + # Patch the get_jobs_list function to ensure that it is called with + # the correct parameters. 
+ patch_called = False + user_arg = None + orig_get_jobs_list = QueueManager.get_jobs_list + + def patched_get_jobs_list(self, jobs=None, user=None, timeout=None): + nonlocal patch_called + nonlocal user_arg + patch_called = True + user_arg = user + return orig_get_jobs_list(self=self, jobs=jobs, user=user, timeout=timeout) + + with monkeypatch.context() as m: + m.setattr(Runner, "get_worker", patched_get_worker) + m.setattr(QueueManager, "get_jobs_list", patched_get_jobs_list) + runner.run_all_jobs(max_seconds=30) + + assert patch_called, "The patched method was not called" + assert ( + user_arg == "jobflow" + ), f"The argument for user passed to QueueManager.get_jobs_list is '{user_arg}' instead of 'jobflow'" + + assert ( + job_controller.count_jobs(states=JobState.COMPLETED) == 1 + ), f"Jobs not marked as completed, full job info:\n{job_controller.get_jobs({})}" + assert ( + job_controller.count_flows(states=FlowState.COMPLETED) == 1 + ), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}"