diff --git a/jobrunner/lib/__init__.py b/jobrunner/lib/__init__.py index df5f6707..d90cad81 100644 --- a/jobrunner/lib/__init__.py +++ b/jobrunner/lib/__init__.py @@ -1,4 +1,6 @@ +import functools import secrets +import warnings from contextlib import contextmanager from datetime import datetime @@ -57,3 +59,31 @@ def datestr_to_ns_timestamp(datestr): ts += ns return ts + + +def warn_assertions(f): + """Helper decorator to catch assertion errors and emit as warnings. + + In dev, this will cause tests to fail, and log output in prod. + + Returns None, as that's the only thing it can reasonably do. As such, it + can only be used to decorate functions that also return None, and it emits + a warning for that too. + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + rvalue = f(*args, **kwargs) + if rvalue is not None: + raise AssertionError( + "warn_assertions can only be used on functions that return None:" + f" {f.__name__} returned {type(rvalue)}" + ) + except AssertionError as exc: + # convert exception to warning + warnings.warn(str(exc)) + + return None + + return wrapper diff --git a/jobrunner/run.py b/jobrunner/run.py index 5bfd35ba..00683e5f 100644 --- a/jobrunner/run.py +++ b/jobrunner/run.py @@ -235,13 +235,24 @@ def handle_job(job, api, mode=None, paused=None): else: EXECUTOR_RETRIES.pop(job.id, None) + # check if we've transitioned since we last checked and trace it. + if initial_status.state in STATE_MAP: + initial_code, initial_message = STATE_MAP[initial_status.state] + if initial_code != job.status_code: + set_code( + job, + initial_code, + initial_message, + timestamp_ns=initial_status.timestamp_ns, + ) + # handle the simple no change needed states. 
if initial_status.state in STABLE_STATES: if job.state == State.PENDING: log.warning( f"state error: got {initial_status.state} for a job we thought was PENDING" ) - # no action needed, simply update job message and timestamp + # no action needed, simply update job message and timestamp, which is likely a no-op code, message = STATE_MAP[initial_status.state] set_code(job, code, message) return @@ -334,7 +345,8 @@ def handle_job(job, api, mode=None, paused=None): code, message = STATE_MAP[new_status.state] # special case PENDING -> RUNNING transition - if new_status.state == ExecutorState.PREPARING: + # allow both states to do the transition to RUNNING, due to synchronous transitions + if new_status.state in [ExecutorState.PREPARING, ExecutorState.PREPARED]: set_state(job, State.RUNNING, code, message) else: if job.state != State.RUNNING: @@ -522,7 +534,7 @@ def set_state(job, state, code, message, error=None, results=None, **attrs): set_code(job, code, message, error=error, results=results, **attrs) -def set_code(job, code, message, error=None, results=None, **attrs): +def set_code(job, code, message, error=None, results=None, timestamp_ns=None, **attrs): """Set the granular status code state. We also trace this transition with OpenTelemetry traces. @@ -530,14 +542,26 @@ def set_code(job, code, message, error=None, results=None, **attrs): Note: timestamp precision in the db is to the nearest second, which made sense when we were tracking fewer high level states. But now we are tracking more granular states, subsecond precision is needed to avoid odd - collisions when states transition in <1s. - + collisions when states transition in <1s. Due to this, timestamp parameter + should be the output of time.time() i.e. a float representing seconds. 
""" - timestamp = time.time() - timestamp_s = int(timestamp) - timestamp_ns = int(timestamp * 1e9) + if timestamp_ns is None: + t = time.time() + timestamp_s = int(t) + timestamp_ns = int(t * 1e9) + else: + timestamp_s = int(timestamp_ns // 1e9) + + if job.status_code_updated_at > timestamp_ns: + # we somehow have a negative duration, which honeycomb does funny things with. + # This can happen in tests, where things are fast, but we've seen it in production too. + log.warning( + f"negative state duration, clamping to 1ms ({job.status_code_updated_at} > {timestamp_ns})" + ) + timestamp_ns = job.status_code_updated_at + 1e6 # set duration to 1ms + timestamp_s = int(timestamp_ns // 1e9) - # if code has changed then log + # if code has changed then trace it and update if job.status_code != code: # job trace: we finished the previous state diff --git a/jobrunner/tracing.py b/jobrunner/tracing.py index da013011..217cb15d 100644 --- a/jobrunner/tracing.py +++ b/jobrunner/tracing.py @@ -1,5 +1,6 @@ import logging import os +from datetime import datetime from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter @@ -9,7 +10,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from jobrunner import config -from jobrunner.lib import database +from jobrunner.lib import database, warn_assertions from jobrunner.models import Job, SavedJobRequest, State, StatusCode @@ -48,6 +49,7 @@ def setup_default_tracing(): add_exporter(ConsoleSpanExporter()) +@warn_assertions def initialise_trace(job): """Initialise the trace for this job by creating a root span. 
@@ -120,7 +122,9 @@ def record_final_state(job, timestamp_ns, error=None, results=None, **attrs): # final states have no duration, so make last for 1 sec, just act # as a marker end_time = int(timestamp_ns + 1e9) - record_job_span(job, name, start_time, end_time, error, results, **attrs) + record_job_span( + job, name, start_time, end_time, error, results, final=True, **attrs + ) complete_job(job, timestamp_ns, error, results, **attrs) except Exception: @@ -160,11 +164,30 @@ def load_trace_context(job): return propagation.set_span_in_context(trace.NonRecordingSpan(span_context), {}) +MINIMUM_NS_TIMESTAMP = int(datetime(2000, 1, 1, 0, 0, 0).timestamp() * 1e9) + + +@warn_assertions def record_job_span(job, name, start_time, end_time, error, results, **attrs): """Record a span for a job.""" if not _traceable(job): return + # Due to @warn_assertions, this will be emitted as warnings in test, but + # the calling code swallows any exceptions. + assert start_time is not None + assert end_time is not None + assert ( + start_time > MINIMUM_NS_TIMESTAMP + ), f"start_time not in nanoseconds: {start_time}" + assert end_time > MINIMUM_NS_TIMESTAMP, f"end_time not in nanoseconds: {end_time}" + # Note: windows timer precision is low, so we sometimes get the same + # value of ns for two separate measurements. This means they are not always + # increasing, but they should never decrease. At least in theory... 
+ assert ( + end_time >= start_time + ), f"end_time is before start_time, ({end_time} < {start_time})" + ctx = load_trace_context(job) tracer = trace.get_tracer("jobs") span = tracer.start_span(name, context=ctx, start_time=start_time) diff --git a/pyproject.toml b/pyproject.toml index 33771b25..3e4ac653 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,12 @@ lines_after_imports = 2 skip_glob = [".direnv", ".venv", "venv"] [tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore::DeprecationWarning:opentelemetry.*:", + "ignore::DeprecationWarning:pytest_freezegun:17", + "ignore::DeprecationWarning:pytest_responses:9", +] [tool.setuptools.packages.find] include = ["jobrunner*"] diff --git a/tests/factories.py b/tests/factories.py index a9db08df..fe9f5ca2 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -78,7 +78,7 @@ def job_factory(job_request=None, **kwargs): if "updated_at" not in kwargs: values["updated_at"] = int(timestamp) if "status_code_updated_at" not in kwargs: - values["status_code_updated_at"] = int(values["created_at"] * 1e9) + values["status_code_updated_at"] = int(timestamp * 1e9) values.update(kwargs) values["job_request_id"] = job_request.id @@ -94,10 +94,12 @@ def job_factory(job_request=None, **kwargs): return job -def job_results_factory(**kwargs): +def job_results_factory(timestamp_ns=None, **kwargs): + if timestamp_ns is None: + timestamp_ns = time.time_ns() values = deepcopy(JOB_RESULTS_DEFAULTS) values.update(kwargs) - return JobResults(**values) + return JobResults(timestamp_ns=timestamp_ns, **values) class StubExecutorAPI: @@ -131,6 +133,7 @@ def __init__(self): self.results = {} self.state = {} self.deleted = defaultdict(lambda: defaultdict(list)) + self.last_time = int(time.time()) def add_test_job( self, @@ -138,31 +141,38 @@ def add_test_job( job_state, status_code=StatusCode.CREATED, message="message", + timestamp=None, **kwargs, ): """Create and track a db job object.""" job = 
job_factory(state=job_state, status_code=status_code, **kwargs) if exec_state != ExecutorState.UNKNOWN: - self.state[job.id] = JobStatus(exec_state, message) + self.set_job_state(job, exec_state, message) return job - def set_job_state(self, definition, state, message="message"): + def set_job_state(self, definition, state, message="message", timestamp_ns=None): """Directly set a job state.""" # handle the synchronous state meaning the state has completed + if timestamp_ns is None: + timestamp_ns = time.time_ns() synchronous = getattr(self, "synchronous_transitions", []) if state in synchronous: if state == ExecutorState.PREPARING: state = ExecutorState.PREPARED if state == ExecutorState.FINALIZING: state = ExecutorState.FINALIZED - self.state[definition.id] = JobStatus(state, message) + self.state[definition.id] = JobStatus(state, message, timestamp_ns) - def set_job_transition(self, definition, state, message="executor message"): + def set_job_transition( + self, definition, state, message="executor message", timestamp_ns=None + ): """Set the next transition for this job when called""" - self.transitions[definition.id] = (state, message) + self.transitions[definition.id] = (state, message, timestamp_ns) - def set_job_result(self, definition, **kwargs): + def set_job_result(self, definition, timestamp_ns=None, **kwargs): + if timestamp_ns is None: + timestamp_ns = time.time_ns() defaults = { "outputs": {}, "unmatched_patterns": [], @@ -176,17 +186,18 @@ def set_job_result(self, definition, **kwargs): def do_transition(self, definition, expected, next_state): current = self.get_status(definition) + timestamp_ns = time.time_ns() if current.state != expected: state = current.state message = f"Invalid transition to {next_state}, currently state is {current.state}" elif definition.id in self.transitions: - state, message = self.transitions[definition.id] + state, message, timestamp_ns = self.transitions[definition.id] else: state = next_state message = "executor 
message" - self.set_job_state(definition, state) - return JobStatus(state, message) + self.set_job_state(definition, state, message, timestamp_ns) + return JobStatus(state, message, timestamp_ns) def prepare(self, definition): self.tracker["prepare"].add(definition.id) diff --git a/tests/test_integration.py b/tests/test_integration.py index 50c2f1f7..f78f1798 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -158,7 +158,7 @@ def test_integration_with_cohortextractor( # Check that the manifest contains what we expect. This is a subset of what used to be in the manifest, to support # nicer UX for osrelease. manifest_file = medium_privacy_workspace / "metadata" / "manifest.json" - manifest = json.load(manifest_file.open()) + manifest = json.loads(manifest_file.read_text()) assert manifest["workspace"] == "testing" assert manifest["repo"] == str(test_repo.path) @@ -200,8 +200,8 @@ def test_integration_with_cohortextractor( # labels. assert not any(s.attributes["action_created"] == "unknown" for s in executed_jobs) - job_spans = [s for s in get_trace("loop") if s.name == "LOOP"] - assert len(job_spans) > 1 + loop_spans = [s for s in get_trace("loop") if s.name == "LOOP"] + assert len(loop_spans) > 1 @pytest.mark.slow_test @@ -327,7 +327,7 @@ def test_integration_with_databuilder( # Check that the manifest contains what we expect. This is a subset of what used to be in the manifest, to support # nicer UX for osrelease. 
manifest_file = medium_privacy_workspace / "metadata" / "manifest.json" - manifest = json.load(manifest_file.open()) + manifest = json.loads(manifest_file.read_text()) assert manifest["workspace"] == "testing" assert manifest["repo"] == str(test_repo.path) diff --git a/tests/test_run.py b/tests/test_run.py index 31027e11..aa9daa56 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,3 +1,5 @@ +import time + import pytest from opentelemetry import trace @@ -9,6 +11,85 @@ from tests.fakes import RecordingExecutor +@pytest.mark.parametrize( + "synchronous_transitions", + [ + [], + [ExecutorState.PREPARING, ExecutorState.FINALIZING], + ], +) +def test_handle_job_full_execution(synchronous_transitions, db, freezer): + # move to a whole second boundary for easier timestamp maths + freezer.move_to("2022-01-01T12:34:56") + + api = StubExecutorAPI() + api.synchronous_transitions = synchronous_transitions + + start = int(time.time() * 1e9) + + job = api.add_test_job(ExecutorState.UNKNOWN, State.PENDING, StatusCode.CREATED) + + freezer.tick(1) + + run.handle_job(job, api) + assert job.state == State.RUNNING + assert job.status_code == StatusCode.PREPARING + + freezer.tick(1) + api.set_job_state(job, ExecutorState.PREPARED) + + freezer.tick(1) + run.handle_job(job, api) + assert job.state == State.RUNNING + assert job.status_code == StatusCode.EXECUTING + + freezer.tick(1) + api.set_job_state(job, ExecutorState.EXECUTED) + + freezer.tick(1) + run.handle_job(job, api) + assert job.state == State.RUNNING + assert job.status_code == StatusCode.FINALIZING + + freezer.tick(1) + api.set_job_state(job, ExecutorState.FINALIZED) + api.set_job_result(job) + + freezer.tick(1) + run.handle_job(job, api) + assert job.state == State.SUCCEEDED + assert job.status_code == StatusCode.SUCCEEDED + + spans = get_trace("jobs") + assert [s.name for s in spans] == [ + "CREATED", + "PREPARING", + "PREPARED", + "EXECUTING", + "EXECUTED", + "FINALIZING", + "FINALIZED", + "SUCCEEDED", + "JOB", + 
] + + span_times = [ + (s.name, (s.start_time - start) / 1e9, (s.end_time - start) / 1e9) + for s in spans[:-1] + if not s.name.startswith("ENTER") + ] + assert span_times == [ + ("CREATED", 0.0, 1.0), + ("PREPARING", 1.0, 2.0), + ("PREPARED", 2.0, 3.0), + ("EXECUTING", 3.0, 4.0), + ("EXECUTED", 4.0, 5.0), + ("FINALIZING", 5.0, 6.0), + ("FINALIZED", 6.0, 7.0), + ("SUCCEEDED", 7.0, 8.0), # this is always 1 second anyway! + ] + + def test_handle_pending_job_cancelled(db): api = StubExecutorAPI() job = api.add_test_job(ExecutorState.UNKNOWN, State.PENDING, cancelled=True) @@ -623,6 +704,7 @@ def test_handle_single_job_shortcuts_synchronous(db): assert [s.name for s in get_trace("jobs")] == [ "CREATED", "PREPARING", + "PREPARED", ] diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 0e7e364f..b1cb4dab 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -125,6 +125,7 @@ def test_initialise_trace(db): def test_finish_current_state(db): job = job_factory() + start_time = job.status_code_updated_at results = job_results_factory() ts = int(time.time() * 1e9) @@ -133,7 +134,7 @@ def test_finish_current_state(db): spans = get_trace("jobs") assert spans[-1].name == "CREATED" - assert spans[-1].start_time == int(job.created_at * 1e9) + assert spans[-1].start_time == start_time assert spans[-1].end_time == ts assert spans[-1].attributes["extra"] == "extra" assert spans[-1].attributes["job"] == job.id