diff --git a/taskbadger/celery.py b/taskbadger/celery.py index 1461bba..6095f05 100644 --- a/taskbadger/celery.py +++ b/taskbadger/celery.py @@ -120,26 +120,28 @@ def task_prerun_handler(sender=None, **kwargs): @task_success.connect def task_success_handler(sender=None, **kwargs): _update_task(sender, StatusEnum.SUCCESS) - exit_session() + exit_session(sender) @task_failure.connect def task_failure_handler(sender=None, einfo=None, **kwargs): _update_task(sender, StatusEnum.ERROR, einfo) - exit_session() + exit_session(sender) @task_retry.connect def task_retry_handler(sender=None, einfo=None, **kwargs): _update_task(sender, StatusEnum.ERROR, einfo) - exit_session() + exit_session(sender) def _update_task(signal_sender, status, einfo=None): - log.debug("celery_task_success %s", signal_sender) + log.debug("celery_task_update %s %s", signal_sender, status) + if not hasattr(signal_sender, "taskbadger_task"): + return task = signal_sender.taskbadger_task - if not task: + if task is None: return if task.status in TERMINAL_STATES: @@ -162,8 +164,8 @@ def enter_session(): session.__enter__() -def exit_session(): - if not Badger.is_configured(): +def exit_session(signal_sender): + if not hasattr(signal_sender, "taskbadger_task") or not Badger.is_configured(): return session = Badger.current.session() if session.client: diff --git a/taskbadger/mug.py b/taskbadger/mug.py index 75178af..30e5128 100644 --- a/taskbadger/mug.py +++ b/taskbadger/mug.py @@ -1,9 +1,8 @@ import dataclasses from contextlib import ContextDecorator +from contextvars import ContextVar from typing import Union -from _contextvars import ContextVar - from taskbadger.internal import AuthenticatedClient _local = ContextVar("taskbadger_client") @@ -64,6 +63,9 @@ def __exit__(self, *args, **kwargs): class MugMeta(type): @property def current(cls): + # Note that changes in the parent thread are not propagated to child threads + # i.e. if this is called in a child thread before configuration is set in the parent thread + # the config will not propagate to the child thread. mug = _local.get(None) if mug is None: mug = Badger(GLOBAL_MUG) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7ec0a8d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from taskbadger.mug import Badger, Settings + + +@pytest.fixture +def bind_settings(): + Badger.current.bind(Settings("https://taskbadger.net", "token", "org", "proj")) + yield + Badger.current.bind(None) diff --git a/tests/test_celery.py b/tests/test_celery.py index 4b29afe..232dc6b 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -1,3 +1,13 @@ +""" +Note +==== + +As part of the Celery fixture setup a 'ping' task is run which executes +before the `bind_settings` fixture is executed. This means that if any code +calls `Badger.is_configured()` (or similar), the `_local` ContextVar in the +Celery runner thread will not have the configuration set. +""" +import logging from unittest import mock import celery @@ -5,23 +15,24 @@ from taskbadger import StatusEnum from taskbadger.celery import Task -from taskbadger.mug import Badger, Settings +from taskbadger.mug import Badger from tests.utils import task_for_test -@pytest.fixture -def bind_settings(): - Badger.current.bind(Settings("https://taskbadger.net", "token", "org", "proj")) +@pytest.fixture(autouse=True) +def check_log_errors(caplog): yield - Badger.current.bind(None) + errors = [r.getMessage() for r in caplog.get_records("call") if r.levelno == logging.ERROR] + if errors: + pytest.fail(f"log errors during tests: {errors}") def test_celery_task(celery_session_app, celery_session_worker, bind_settings): @celery_session_app.task(bind=True, base=Task) def add_normal(self, a, b): - assert self.request.get("taskbadger_task") is not None - assert self.taskbadger_task is not None - assert Badger.current.session().client is not None + assert self.request.get("taskbadger_task") is not None, "missing task in request" + assert self.taskbadger_task is not None, "missing task on self" + assert Badger.current.session().client is not None, "missing client" return a + b celery_session_worker.reload() @@ -74,35 +85,6 @@ def add_with_task_args_in_decorator(self, a, b): create.assert_called_once_with(mock.ANY, status=StatusEnum.PENDING, monitor_id="123", value_max=10) -def test_celery_task_error(celery_session_app, celery_session_worker, bind_settings): - @celery_session_app.task(bind=True, base=Task) - def add_error(self, a, b): - assert self.taskbadger_task is not None - assert Badger.current.session().client is not None - raise Exception("error") - - celery_session_worker.reload() - - with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( - "taskbadger.celery.update_task_safe" - ) as update, mock.patch("taskbadger.celery.get_task") as get_task: - get_task.return_value = task_for_test() - result = add_error.delay(2, 2) - with pytest.raises(Exception): - result.get(timeout=10, propagate=True) - - create.assert_called() - update.assert_has_calls( - [ - mock.call(mock.ANY, status=StatusEnum.PROCESSING, data=mock.ANY), - mock.call(mock.ANY, status=StatusEnum.ERROR, data=mock.ANY), - ] - ) - data_kwarg = update.call_args_list[1][1]["data"] - assert "Traceback" in data_kwarg["exception"] - assert Badger.current.session().client is None - - def test_celery_task_retry(celery_session_app, celery_session_worker, bind_settings): @celery_session_app.task(bind=True, base=Task) def add_retry(self, a, b, is_retry=False): @@ -207,7 +189,7 @@ def task_signature(self, a): with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( "taskbadger.celery.update_task_safe" ) as update, mock.patch("taskbadger.celery.get_task") as get_task: - result = chain() + result = chain.delay() assert result.get(timeout=10, propagate=True) == 16 assert create.call_count == 3 @@ -216,6 +198,31 @@ def task_signature(self, a): assert Badger.current.session().client is None +def test_task_map(celery_session_worker, bind_settings): + """Tasks executed in a map or starmap are not executed as tasks""" + + @celery.shared_task(bind=True, base=Task) + def task_map(self, a): + assert self.taskbadger_task is None + assert Badger.current.session().client is None + return a * 2 + + celery_session_worker.reload() + + task_map = task_map.map(list(range(5))) + + with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( + "taskbadger.celery.update_task_safe" + ) as update, mock.patch("taskbadger.celery.get_task") as get_task: + result = task_map.delay() + assert result.get(timeout=10, propagate=True) == [0, 2, 4, 6, 8] + + assert create.call_count == 0 + assert get_task.call_count == 0 + assert update.call_count == 0 + assert Badger.current.session().client is None + + def test_celery_task_already_in_terminal_state(celery_session_worker, bind_settings): @celery.shared_task(bind=True, base=Task) def add_manual_update(self, a, b, is_retry=False): diff --git a/tests/test_celery_error.py b/tests/test_celery_error.py new file mode 100644 index 0000000..ce72de7 --- /dev/null +++ b/tests/test_celery_error.py @@ -0,0 +1,37 @@ +from unittest import mock + +import pytest + +from taskbadger import StatusEnum +from taskbadger.celery import Task +from taskbadger.mug import Badger +from tests.utils import task_for_test + + +def test_celery_task_error(celery_session_app, celery_session_worker, bind_settings): + @celery_session_app.task(bind=True, base=Task) + def add_error(self, a, b): + assert self.taskbadger_task is not None + assert Badger.current.session().client is not None + raise Exception("error") + + celery_session_worker.reload() + + with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch( + "taskbadger.celery.update_task_safe" + ) as update, mock.patch("taskbadger.celery.get_task") as get_task: + get_task.return_value = task_for_test() + result = add_error.delay(2, 2) + with pytest.raises(Exception): + result.get(timeout=10, propagate=True) + + create.assert_called() + update.assert_has_calls( + [ + mock.call(mock.ANY, status=StatusEnum.PROCESSING, data=mock.ANY), + mock.call(mock.ANY, status=StatusEnum.ERROR, data=mock.ANY), + ] + ) + data_kwarg = update.call_args_list[1][1]["data"] + assert "Traceback" in data_kwarg["exception"] + assert Badger.current.session().client is None