Allow list of jobs from scheduler using username #187

Merged · 1 commit · Sep 25, 2024
8 changes: 8 additions & 0 deletions doc/source/user/projectconf.rst
@@ -103,6 +103,14 @@ configurations. Those can better be set with the :ref:`projectconf execconfig`.
If a single worker is defined it will be used as default in the submission
of new Flows.

.. warning::

By default, jobflow-remote fetches the status of the jobs from the scheduler by
passing the list of job ids. If the selected scheduler does not support this option
(e.g. SGE), it is also necessary to specify the username on the worker machine
through the ``scheduler_username`` option. Jobflow-remote will then filter the jobs
by that username instead of by the list of ids.

.. _projectconf jobstore:

JobStore
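The option added here is set per worker in the project configuration. Below is a minimal sketch of what a worker entry could look like, written as a plain Python dict mirroring the project file layout; only ``scheduler_type`` and ``scheduler_username`` appear in this diff, the remaining keys and values are illustrative assumptions.

```python
# Hypothetical worker entry for a scheduler without query-by-id support.
# Only scheduler_type and scheduler_username come from this PR; the other
# keys and values are assumed for illustration.
sge_worker = {
    "scheduler_type": "sge",
    "work_dir": "/home/jdoe/jobflow_runs",  # assumed path
    "scheduler_username": "jdoe",  # jobs are fetched by this username
}
```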
6 changes: 6 additions & 0 deletions src/jobflow_remote/config/base.py
@@ -177,6 +177,12 @@ class WorkerBase(BaseModel):
None,
description="Options for batch execution. If define the worker will be considered a batch worker",
)
scheduler_username: Optional[str] = Field(
None,
description="If defined, the list of jobs running on the worker will be fetched based "
"on the username instead of the list of job ids. This may be necessary for some "
"scheduler_type values (e.g. SGE).",
)
model_config = ConfigDict(extra="forbid")

@field_validator("scheduler_type")
19 changes: 14 additions & 5 deletions src/jobflow_remote/jobs/runner.py
@@ -844,12 +844,18 @@ def check_run_status(self, filter: dict | None = None) -> None:  # noqa: A002
if not ids_docs:
continue

worker = self.get_worker(worker_name)

qjobs_dict = {}
try:
ids_list = list(ids_docs)
queue = self.get_queue_manager(worker_name)
qjobs = queue.get_jobs_list(ids_list)
qjobs_dict = {qjob.job_id: qjob for qjob in qjobs}
qjobs = queue.get_jobs_list(
jobs=ids_list, user=worker.scheduler_username
)
qjobs_dict = {
qjob.job_id: qjob for qjob in qjobs if qjob.job_id in ids_list
}
except Exception:
logger.warning(
f"error trying to get jobs list for worker: {worker_name}",
@@ -874,7 +880,6 @@ def check_run_status(self, filter: dict | None = None) -> None:  # noqa: A002
f"remote job with id {remote_doc['process_id']} is running"
)
elif qstate in [None, QState.DONE, QState.FAILED]:
worker = self.get_worker(worker_name)
# if the worker is local go directly to DOWNLOADED, as files
# are not copied locally
if not worker.is_local:
@@ -995,8 +1000,12 @@ def update_batch_jobs(self) -> None:
processes = list(batch_processes_data)
queue_manager = self.get_queue_manager(worker_name)
if processes:
qjobs = queue_manager.get_jobs_list(processes)
running_processes = {qjob.job_id for qjob in qjobs}
qjobs = queue_manager.get_jobs_list(
jobs=processes, user=worker.scheduler_username
)
running_processes = {
qjob.job_id for qjob in qjobs if qjob.job_id in processes
}
stopped_processes = set(processes) - running_processes
for pid in stopped_processes:
self.job_controller.remove_batch_process(pid, worker_name)
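A username query returns every job owned by that user, not only the ones jobflow-remote submitted, which is why the result is filtered back to the tracked ids before building `qjobs_dict` and `running_processes`. A toy sketch of the idea, with made-up ids and states:

```python
# Toy illustration (ids and states are made up): when querying by username,
# unrelated jobs of the same user may be returned, so only the ids that
# jobflow-remote is tracking are kept.
tracked_ids = ["1001", "1003"]
scheduler_jobs = {"1001": "RUNNING", "1002": "RUNNING", "1003": "DONE"}

qjobs_dict = {jid: state for jid, state in scheduler_jobs.items() if jid in tracked_ids}
assert set(qjobs_dict) == {"1001", "1003"}
```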
4 changes: 4 additions & 0 deletions src/jobflow_remote/remote/queue.py
@@ -220,6 +220,10 @@ def get_jobs_list(
user: str | None = None,
timeout: int | None = None,
) -> list[QJob]:
# in order to avoid issues with schedulers that do not support querying by a
# list of job ids, ignore the job ids if a user is passed.
if user is not None:
jobs = None
job_cmd = self.scheduler_io.get_jobs_list_cmd(jobs, user)
stdout, stderr, returncode = self.execute_cmd(job_cmd, timeout=timeout)
return self.scheduler_io.parse_jobs_list_output(
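A hedged usage sketch of the updated `get_jobs_list`: when a username is given, the id list is ignored for the query itself, so callers (as the runner does above) filter the result back to the ids they track. The helper below is illustrative and not part of the codebase.

```python
from jobflow_remote.remote.queue import QueueManager


def tracked_jobs_on_scheduler(
    queue: QueueManager, ids_list: list[str], username: str | None
) -> set[str]:
    """Illustrative helper: return the tracked job ids still known to the scheduler.

    When ``username`` is set, ``get_jobs_list`` queries by user and ignores the
    id list, so the result is filtered back to ``ids_list`` afterwards.
    """
    qjobs = queue.get_jobs_list(jobs=ids_list, user=username)
    return {qjob.job_id for qjob in qjobs if qjob.job_id in ids_list}
```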
57 changes: 57 additions & 0 deletions tests/integration/test_slurm.py
@@ -305,3 +305,60 @@ def test_undefined_additional_stores(worker, job_controller) -> None:
job_controller.count_jobs(states=[JobState.COMPLETED, JobState.REMOTE_ERROR])
== 2
)


def test_submit_flow_with_scheduler_username(monkeypatch, job_controller) -> None:
from jobflow import Flow

from jobflow_remote import submit_flow
from jobflow_remote.jobs.runner import Runner
from jobflow_remote.jobs.state import FlowState, JobState
from jobflow_remote.remote.queue import QueueManager
from jobflow_remote.testing import add

remote_worker_name = "test_remote_worker"

job = add(1, 1)
flow = Flow([job])
submit_flow(flow, worker=remote_worker_name)

# modify the runner so that it uses a patched version of the worker
# where the scheduler_username is set
runner = Runner()
patched_worker = runner.get_worker(remote_worker_name).model_copy()
patched_worker.scheduler_username = "jobflow"

def patched_get_worker(self, worker_name):
if worker_name != remote_worker_name:
return runner.workers[worker_name]
return patched_worker

# Patch the get_jobs_list function to ensure that it is called with
# the correct parameters.
patch_called = False
user_arg = None
orig_get_jobs_list = QueueManager.get_jobs_list

def patched_get_jobs_list(self, jobs=None, user=None, timeout=None):
nonlocal patch_called
nonlocal user_arg
patch_called = True
user_arg = user
return orig_get_jobs_list(self=self, jobs=jobs, user=user, timeout=timeout)

with monkeypatch.context() as m:
m.setattr(Runner, "get_worker", patched_get_worker)
m.setattr(QueueManager, "get_jobs_list", patched_get_jobs_list)
runner.run_all_jobs(max_seconds=30)

assert patch_called, "The patched method was not called"
assert (
user_arg == "jobflow"
), f"The argument for user passed to QueueManager.get_jobs_list is '{user_arg}' instead of 'jobflow'"

assert (
job_controller.count_jobs(states=JobState.COMPLETED) == 1
), f"Jobs not marked as completed, full job info:\n{job_controller.get_jobs({})}"
assert (
job_controller.count_flows(states=FlowState.COMPLETED) == 1
), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}"