Skip to content

Commit

Permalink
Merge pull request #509 from opensafely-core/db-workers
Browse files: browse the repository at this point in the history
feat: add MAX_DB_WORKERS to control DB job parallelism
  • Loading branch information
bloodearnest authored Nov 4, 2022
2 parents 27d5e8d + cca0405 commit 1268a2c
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 7 deletions.
1 change: 1 addition & 0 deletions jobrunner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _is_valid_backend_name(name):


# Overall worker capacity: read from the MAX_WORKERS environment variable,
# falling back to cpu_count() - 1 (never less than 1) when it is unset or empty.
MAX_WORKERS = int(os.environ.get("MAX_WORKERS") or max(cpu_count() - 1, 1))
# Cap on concurrently running jobs that need database access: read from the
# MAX_DB_WORKERS environment variable, falling back to MAX_WORKERS when it is
# unset or empty.
MAX_DB_WORKERS = int(os.environ.get("MAX_DB_WORKERS") or MAX_WORKERS)

# This is a crude mechanism for preventing a single large JobRequest with lots
# of associated Jobs from hogging all the resources. We want this configurable
Expand Down
15 changes: 15 additions & 0 deletions jobrunner/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import datetime
import hashlib
import secrets
import shlex
from enum import Enum

from jobrunner.lib.commands import requires_db_access
from jobrunner.lib.database import databaseclass, migration
from jobrunner.lib.string_utils import slugify

Expand Down Expand Up @@ -46,6 +48,8 @@ class StatusCode(Enum):
WAITING_ON_DEPENDENCIES = "waiting_on_dependencies"
# waiting on available resources to run the job
WAITING_ON_WORKERS = "waiting_on_workers"
# waiting on available db resources to run the job
WAITING_ON_DB_WORKERS = "waiting_on_db_workers"
# reset for reboot
WAITING_ON_REBOOT = "waiting_on_reboot"

Expand Down Expand Up @@ -290,6 +294,17 @@ def output_files(self):
else:
return []

@property
def action_args(self):
    """Return the job's run command split into shell-style arguments.

    An unset or empty ``run_command`` yields an empty list. A new list is
    built on every access.
    """
    command = self.run_command
    return shlex.split(command) if command else []

@property
def requires_db(self):
    """Report whether this job's command needs database access.

    Always returns a strict ``bool``. Only a literal ``True`` from
    ``requires_db_access`` counts as needing the database, so callers that
    gate database credentials or database worker slots on this property fail
    closed if the helper ever returns anything else.
    """
    return requires_db_access(self.action_args) is True


def deterministic_id(seed):
digest = hashlib.sha1(seed.encode("utf-8")).digest()
Expand Down
24 changes: 17 additions & 7 deletions jobrunner/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import datetime
import logging
import random
import shlex
import sys
import time
from typing import Optional
Expand All @@ -20,7 +19,6 @@
Privacy,
Study,
)
from jobrunner.lib.commands import requires_db_access
from jobrunner.lib.database import find_where, select_values, update
from jobrunner.lib.log_utils import configure_logging, set_log_context
from jobrunner.models import FINAL_STATUS_CODES, Job, State, StatusCode
Expand Down Expand Up @@ -239,7 +237,8 @@ def handle_job(job, api, mode=None, paused=None):
# work
not_started_reason = get_reason_job_not_started(job)
if not_started_reason:
set_code(job, StatusCode.WAITING_ON_WORKERS, not_started_reason)
code, message = not_started_reason
set_code(job, code, message)
return

expected_state = ExecutorState.PREPARING
Expand Down Expand Up @@ -387,11 +386,10 @@ def get_obsolete_files(definition, outputs):

def job_to_job_definition(job):

action_args = shlex.split(job.run_command)
allow_database_access = False
env = {"OPENSAFELY_BACKEND": config.BACKEND}
# Check `is True` so we fail closed if we ever get anything else
if requires_db_access(action_args) is True:
if job.requires_db:
if not config.USING_DUMMY_DATA_BACKEND:
allow_database_access = True
env["DATABASE_URL"] = config.DATABASE_URLS[job.database_name]
Expand All @@ -403,6 +401,7 @@ def job_to_job_definition(job):
if config.EMIS_ORGANISATION_HASH:
env["EMIS_ORGANISATION_HASH"] = config.EMIS_ORGANISATION_HASH
# Prepend registry name
action_args = job.action_args
image = action_args.pop(0)
full_image = f"{config.DOCKER_REGISTRY}/{image}"
if image.startswith("stata-mp"):
Expand Down Expand Up @@ -557,9 +556,20 @@ def get_reason_job_not_started(job):
required_resources = get_job_resource_weight(job)
if used_resources + required_resources > config.MAX_WORKERS:
if required_resources > 1:
return "Waiting on available workers for resource intensive job"
return (
StatusCode.WAITING_ON_WORKERS,
"Waiting on available workers for resource intensive job",
)
else:
return "Waiting on available workers"
return StatusCode.WAITING_ON_WORKERS, "Waiting on available workers"

if job.requires_db:
running_db_jobs = len([j for j in running_jobs if j.requires_db])
if running_db_jobs >= config.MAX_DB_WORKERS:
return (
StatusCode.WAITING_ON_DB_WORKERS,
"Waiting on available database workers",
)


def list_outputs_from_action(workspace, action):
Expand Down
1 change: 1 addition & 0 deletions jobrunner/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def trace_attributes(job):
orgs=",".join(job._job_request.get("orgs", [])),
state=job.state.name,
message=job.status_message,
requires_db=job.requires_db,
)

# local_run jobs don't have a commit
Expand Down
24 changes: 24 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,30 @@ def test_handle_job_waiting_on_workers(monkeypatch, db):
assert spans[-1].name == "ENTER WAITING_ON_WORKERS"


def test_handle_job_waiting_on_db_workers(monkeypatch, db):
    """A pending job is held back when the database worker limit is reached."""
    # With the limit at zero, any count of running db jobs is >= the limit,
    # so a job that needs the database cannot be started.
    monkeypatch.setattr(config, "MAX_DB_WORKERS", 0)
    api = StubExecutorAPI()
    # NOTE(review): this command is presumably one that requires_db_access
    # reports as needing the database -- confirm against jobrunner.lib.commands.
    job = api.add_test_job(
        ExecutorState.UNKNOWN,
        State.PENDING,
        run_command="cohortextractor:latest generate_cohort",
    )

    run.handle_job(job, api)

    # executor doesn't even know about it
    assert job.id not in api.tracker["prepare"]

    # the job stays pending and records the db-specific waiting reason
    assert job.state == State.PENDING
    assert job.status_message == "Waiting on available database workers"
    assert job.status_code == StatusCode.WAITING_ON_DB_WORKERS

    # tracing
    spans = get_trace()
    assert spans[-2].name == "CREATED"
    assert spans[-1].name == "ENTER WAITING_ON_DB_WORKERS"


@pytest.mark.parametrize(
"exec_state,job_state,code,tracker",
[
Expand Down
2 changes: 2 additions & 0 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_trace_attributes(db):
state="PENDING",
message="message",
reusable_action="action_repo:commit",
requires_db=False,
)


Expand Down Expand Up @@ -77,6 +78,7 @@ def test_trace_attributes_missing(db):
orgs="org1,org2",
state="PENDING",
message="message",
requires_db=False,
)


Expand Down

0 comments on commit 1268a2c

Please sign in to comment.