Skip to content

Commit

Permalink
Merge pull request #550 from opensafely-core/trace-async-state-changes
Browse files Browse the repository at this point in the history
trace async state changes
  • Loading branch information
bloodearnest authored Jan 9, 2023
2 parents 0b16493 + a1ca674 commit 137d466
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 28 deletions.
30 changes: 30 additions & 0 deletions jobrunner/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import secrets
import warnings
from contextlib import contextmanager
from datetime import datetime

Expand Down Expand Up @@ -57,3 +59,31 @@ def datestr_to_ns_timestamp(datestr):
ts += ns

return ts


def warn_assertions(f):
"""Helper decorator to catch assertions errors and emit as warnings.
In dev, this will cause tests to fail, and log output in prod.
Returns None, as that's the only thing it can reasonably do. As such, it
can only be used to decorate functions that also return None, and it emits
a warning for that too.
"""

@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
rvalue = f(*args, **kwargs)
if rvalue is not None:
raise AssertionError(
"warn_assertions can only be used on functions that return None:"
"{f.__name__} return {type(rvalue)}"
)
except AssertionError as exc:
# convert exception to warning
warnings.warn(str(exc))

return None

return wrapper
42 changes: 33 additions & 9 deletions jobrunner/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,24 @@ def handle_job(job, api, mode=None, paused=None):
else:
EXECUTOR_RETRIES.pop(job.id, None)

# check if we've transitioned since we last checked and trace it.
if initial_status.state in STATE_MAP:
initial_code, initial_message = STATE_MAP[initial_status.state]
if initial_code != job.status_code:
set_code(
job,
initial_code,
initial_message,
timestamp_ns=initial_status.timestamp_ns,
)

# handle the simple no change needed states.
if initial_status.state in STABLE_STATES:
if job.state == State.PENDING:
log.warning(
f"state error: got {initial_status.state} for a job we thought was PENDING"
)
# no action needed, simply update job message and timestamp
# no action needed, simply update job message and timestamp, which is likely a no-op
code, message = STATE_MAP[initial_status.state]
set_code(job, code, message)
return
Expand Down Expand Up @@ -334,7 +345,8 @@ def handle_job(job, api, mode=None, paused=None):
code, message = STATE_MAP[new_status.state]

# special case PENDING -> RUNNING transition
if new_status.state == ExecutorState.PREPARING:
# allow both states to do the transition to RUNNING, due to synchronous transitions
if new_status.state in [ExecutorState.PREPARING, ExecutorState.PREPARED]:
set_state(job, State.RUNNING, code, message)
else:
if job.state != State.RUNNING:
Expand Down Expand Up @@ -522,22 +534,34 @@ def set_state(job, state, code, message, error=None, results=None, **attrs):
set_code(job, code, message, error=error, results=results, **attrs)


def set_code(job, code, message, error=None, results=None, **attrs):
def set_code(job, code, message, error=None, results=None, timestamp_ns=None, **attrs):
"""Set the granular status code state.
We also trace this transition with OpenTelemetry traces.
Note: timestamp precision in the db is to the nearest second, which made
sense when we were tracking fewer high level states. But now we are
tracking more granular states, subsecond precision is needed to avoid odd
collisions when states transition in <1s.
collisions when states transition in <1s. Due to this, timestamp parameter
should be the output of time.time() i.e. a float representing seconds.
"""
timestamp = time.time()
timestamp_s = int(timestamp)
timestamp_ns = int(timestamp * 1e9)
if timestamp_ns is None:
t = time.time()
timestamp_s = int(t)
timestamp_ns = int(t * 1e9)
else:
timestamp_s = int(timestamp_ns // 1e9)

if job.status_code_updated_at > timestamp_ns:
# we somehow have a negative duration, which honeycomb does funny things with.
# This can happen in tests, where things are fast, but we've seen it in production too.
log.warning(
f"negative state duration, clamping to 1ms ({job.status_code_updated_at} > {timestamp_ns})"
)
timestamp_ns = job.status_code_updated_at + 1e6 # set duration to 1ms
timestamp_s = int(timestamp_ns // 1e9)

# if code has changed then log
# if code has changed then trace it and update
if job.status_code != code:

# job trace: we finished the previous state
Expand Down
27 changes: 25 additions & 2 deletions jobrunner/tracing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from datetime import datetime

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
Expand All @@ -9,7 +10,7 @@
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from jobrunner import config
from jobrunner.lib import database
from jobrunner.lib import database, warn_assertions
from jobrunner.models import Job, SavedJobRequest, State, StatusCode


Expand Down Expand Up @@ -48,6 +49,7 @@ def setup_default_tracing():
add_exporter(ConsoleSpanExporter())


@warn_assertions
def initialise_trace(job):
"""Initialise the trace for this job by creating a root span.
Expand Down Expand Up @@ -120,7 +122,9 @@ def record_final_state(job, timestamp_ns, error=None, results=None, **attrs):
# final states have no duration, so make last for 1 sec, just act
# as a marker
end_time = int(timestamp_ns + 1e9)
record_job_span(job, name, start_time, end_time, error, results, **attrs)
record_job_span(
job, name, start_time, end_time, error, results, final=True, **attrs
)

complete_job(job, timestamp_ns, error, results, **attrs)
except Exception:
Expand Down Expand Up @@ -160,11 +164,30 @@ def load_trace_context(job):
return propagation.set_span_in_context(trace.NonRecordingSpan(span_context), {})


MINIMUM_NS_TIMESTAMP = int(datetime(2000, 1, 1, 0, 0, 0).timestamp() * 1e9)


@warn_assertions
def record_job_span(job, name, start_time, end_time, error, results, **attrs):
"""Record a span for a job."""
if not _traceable(job):
return

# Due to @warn_assertions, this will be emitted as warnings in test, but
# the calling code swallows any exceptions.
assert start_time is not None
assert end_time is not None
assert (
start_time > MINIMUM_NS_TIMESTAMP
), f"start_time not in nanoseconds: {start_time}"
assert end_time > MINIMUM_NS_TIMESTAMP, f"end_time not in nanoseconds: {end_time}"
# Note: windows timer precision is low, so we sometimes get the same
# value of ns for two separate measurments. This means they are not always
# increasing, but they should never decrease. At least in theory...
assert (
end_time >= start_time
), f"end_time is before start_time, ({end_time} < {start_time})"

ctx = load_trace_context(job)
tracer = trace.get_tracer("jobs")
span = tracer.start_span(name, context=ctx, start_time=start_time)
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ lines_after_imports = 2
skip_glob = [".direnv", ".venv", "venv"]

[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore::DeprecationWarning:opentelemetry.*:",
"ignore::DeprecationWarning:pytest_freezegun:17",
"ignore::DeprecationWarning:pytest_responses:9",
]

[tool.setuptools.packages.find]
include = ["jobrunner*"]
Expand Down
35 changes: 23 additions & 12 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def job_factory(job_request=None, **kwargs):
if "updated_at" not in kwargs:
values["updated_at"] = int(timestamp)
if "status_code_updated_at" not in kwargs:
values["status_code_updated_at"] = int(values["created_at"] * 1e9)
values["status_code_updated_at"] = int(timestamp * 1e9)
values.update(kwargs)

values["job_request_id"] = job_request.id
Expand All @@ -94,10 +94,12 @@ def job_factory(job_request=None, **kwargs):
return job


def job_results_factory(**kwargs):
def job_results_factory(timestamp_ns=None, **kwargs):
if timestamp_ns is None:
timestamp_ns = time.time_ns()
values = deepcopy(JOB_RESULTS_DEFAULTS)
values.update(kwargs)
return JobResults(**values)
return JobResults(timestamp_ns=timestamp_ns, **values)


class StubExecutorAPI:
Expand Down Expand Up @@ -131,38 +133,46 @@ def __init__(self):
self.results = {}
self.state = {}
self.deleted = defaultdict(lambda: defaultdict(list))
self.last_time = int(time.time())

def add_test_job(
self,
exec_state,
job_state,
status_code=StatusCode.CREATED,
message="message",
timestamp=None,
**kwargs,
):
"""Create and track a db job object."""

job = job_factory(state=job_state, status_code=status_code, **kwargs)
if exec_state != ExecutorState.UNKNOWN:
self.state[job.id] = JobStatus(exec_state, message)
self.set_job_state(job, exec_state, message)
return job

def set_job_state(self, definition, state, message="message"):
def set_job_state(self, definition, state, message="message", timestamp_ns=None):
"""Directly set a job state."""
# handle the synchronous state meaning the state has completed
if timestamp_ns is None:
timestamp_ns = time.time_ns()
synchronous = getattr(self, "synchronous_transitions", [])
if state in synchronous:
if state == ExecutorState.PREPARING:
state = ExecutorState.PREPARED
if state == ExecutorState.FINALIZING:
state = ExecutorState.FINALIZED
self.state[definition.id] = JobStatus(state, message)
self.state[definition.id] = JobStatus(state, message, timestamp_ns)

def set_job_transition(self, definition, state, message="executor message"):
def set_job_transition(
self, definition, state, message="executor message", timestamp_ns=None
):
"""Set the next transition for this job when called"""
self.transitions[definition.id] = (state, message)
self.transitions[definition.id] = (state, message, timestamp_ns)

def set_job_result(self, definition, **kwargs):
def set_job_result(self, definition, timestamp_ns=None, **kwargs):
if timestamp_ns is None:
timestamp_ns = time.time_ns()
defaults = {
"outputs": {},
"unmatched_patterns": [],
Expand All @@ -176,17 +186,18 @@ def set_job_result(self, definition, **kwargs):

def do_transition(self, definition, expected, next_state):
current = self.get_status(definition)
timestamp_ns = time.time_ns()
if current.state != expected:
state = current.state
message = f"Invalid transition to {next_state}, currently state is {current.state}"
elif definition.id in self.transitions:
state, message = self.transitions[definition.id]
state, message, timestamp_ns = self.transitions[definition.id]
else:
state = next_state
message = "executor message"

self.set_job_state(definition, state)
return JobStatus(state, message)
self.set_job_state(definition, state, message, timestamp_ns)
return JobStatus(state, message, timestamp_ns)

def prepare(self, definition):
self.tracker["prepare"].add(definition.id)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_integration_with_cohortextractor(
# Check that the manifest contains what we expect. This is a subset of what used to be in the manifest, to support
# nicer UX for osrelease.
manifest_file = medium_privacy_workspace / "metadata" / "manifest.json"
manifest = json.load(manifest_file.open())
manifest = json.loads(manifest_file.read_text())
assert manifest["workspace"] == "testing"
assert manifest["repo"] == str(test_repo.path)

Expand Down Expand Up @@ -200,8 +200,8 @@ def test_integration_with_cohortextractor(
# labels.
assert not any(s.attributes["action_created"] == "unknown" for s in executed_jobs)

job_spans = [s for s in get_trace("loop") if s.name == "LOOP"]
assert len(job_spans) > 1
loop_spans = [s for s in get_trace("loop") if s.name == "LOOP"]
assert len(loop_spans) > 1


@pytest.mark.slow_test
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_integration_with_databuilder(
# Check that the manifest contains what we expect. This is a subset of what used to be in the manifest, to support
# nicer UX for osrelease.
manifest_file = medium_privacy_workspace / "metadata" / "manifest.json"
manifest = json.load(manifest_file.open())
manifest = json.loads(manifest_file.read_text())
assert manifest["workspace"] == "testing"
assert manifest["repo"] == str(test_repo.path)

Expand Down
Loading

0 comments on commit 137d466

Please sign in to comment.