Skip to content

Commit

Permalink
Merge pull request #15 from taskbadger/sk/canvas
Browse files Browse the repository at this point in the history
safer execution and checking
  • Loading branch information
snopoke authored Sep 9, 2023
2 parents 8db3e65 + b3d7a85 commit e792600
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 47 deletions.
16 changes: 9 additions & 7 deletions taskbadger/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,26 +120,28 @@ def task_prerun_handler(sender=None, **kwargs):
@task_success.connect
def task_success_handler(sender=None, **kwargs):
_update_task(sender, StatusEnum.SUCCESS)
exit_session()
exit_session(sender)


@task_failure.connect
def task_failure_handler(sender=None, einfo=None, **kwargs):
_update_task(sender, StatusEnum.ERROR, einfo)
exit_session()
exit_session(sender)


@task_retry.connect
def task_retry_handler(sender=None, einfo=None, **kwargs):
_update_task(sender, StatusEnum.ERROR, einfo)
exit_session()
exit_session(sender)


def _update_task(signal_sender, status, einfo=None):
log.debug("celery_task_success %s", signal_sender)
log.debug("celery_task_update %s %s", signal_sender, status)
if not hasattr(signal_sender, "taskbadger_task"):
return

task = signal_sender.taskbadger_task
if not task:
if task is None:
return

if task.status in TERMINAL_STATES:
Expand All @@ -162,8 +164,8 @@ def enter_session():
session.__enter__()


def exit_session():
if not Badger.is_configured():
def exit_session(signal_sender):
if not hasattr(signal_sender, "taskbadger_task") or not Badger.is_configured():
return
session = Badger.current.session()
if session.client:
Expand Down
6 changes: 4 additions & 2 deletions taskbadger/mug.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import dataclasses
from contextlib import ContextDecorator
from contextvars import ContextVar
from typing import Union

from _contextvars import ContextVar

from taskbadger.internal import AuthenticatedClient

_local = ContextVar("taskbadger_client")
Expand Down Expand Up @@ -64,6 +63,9 @@ def __exit__(self, *args, **kwargs):
class MugMeta(type):
@property
def current(cls):
# Note that changes in the parent thread are not propagated to child threads
# i.e. if this is called in a child thread before configuration is set in the parent thread
# the config will not propagate to the child thread.
mug = _local.get(None)
if mug is None:
mug = Badger(GLOBAL_MUG)
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from taskbadger.mug import Badger, Settings


@pytest.fixture
def bind_settings():
Badger.current.bind(Settings("https://taskbadger.net", "token", "org", "proj"))
yield
Badger.current.bind(None)
83 changes: 45 additions & 38 deletions tests/test_celery.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
"""
Note
====
As part of the Celery fixture setup a 'ping' task is run which executes
before the `bind_settings` fixture is executed. This means that if any code
calls `Badger.is_configured()` (or similar), the `_local` ContextVar in the
Celery runner thread will not have the configuration set.
"""
import logging
from unittest import mock

import celery
import pytest

from taskbadger import StatusEnum
from taskbadger.celery import Task
from taskbadger.mug import Badger, Settings
from taskbadger.mug import Badger
from tests.utils import task_for_test


@pytest.fixture
def bind_settings():
Badger.current.bind(Settings("https://taskbadger.net", "token", "org", "proj"))
@pytest.fixture(autouse=True)
def check_log_errors(caplog):
yield
Badger.current.bind(None)
errors = [r.getMessage() for r in caplog.get_records("call") if r.levelno == logging.ERROR]
if errors:
pytest.fail(f"log errors during tests: {errors}")


def test_celery_task(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_normal(self, a, b):
assert self.request.get("taskbadger_task") is not None
assert self.taskbadger_task is not None
assert Badger.current.session().client is not None
assert self.request.get("taskbadger_task") is not None, "missing task in request"
assert self.taskbadger_task is not None, "missing task on self"
assert Badger.current.session().client is not None, "missing client"
return a + b

celery_session_worker.reload()
Expand Down Expand Up @@ -74,35 +85,6 @@ def add_with_task_args_in_decorator(self, a, b):
create.assert_called_once_with(mock.ANY, status=StatusEnum.PENDING, monitor_id="123", value_max=10)


def test_celery_task_error(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_error(self, a, b):
assert self.taskbadger_task is not None
assert Badger.current.session().client is not None
raise Exception("error")

celery_session_worker.reload()

with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
"taskbadger.celery.update_task_safe"
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
get_task.return_value = task_for_test()
result = add_error.delay(2, 2)
with pytest.raises(Exception):
result.get(timeout=10, propagate=True)

create.assert_called()
update.assert_has_calls(
[
mock.call(mock.ANY, status=StatusEnum.PROCESSING, data=mock.ANY),
mock.call(mock.ANY, status=StatusEnum.ERROR, data=mock.ANY),
]
)
data_kwarg = update.call_args_list[1][1]["data"]
assert "Traceback" in data_kwarg["exception"]
assert Badger.current.session().client is None


def test_celery_task_retry(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_retry(self, a, b, is_retry=False):
Expand Down Expand Up @@ -207,7 +189,7 @@ def task_signature(self, a):
with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
"taskbadger.celery.update_task_safe"
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
result = chain()
result = chain.delay()
assert result.get(timeout=10, propagate=True) == 16

assert create.call_count == 3
Expand All @@ -216,6 +198,31 @@ def task_signature(self, a):
assert Badger.current.session().client is None


def test_task_map(celery_session_worker, bind_settings):
"""Tasks executed in a map or starmap are not executed as tasks"""

@celery.shared_task(bind=True, base=Task)
def task_map(self, a):
assert self.taskbadger_task is None
assert Badger.current.session().client is None
return a * 2

celery_session_worker.reload()

task_map = task_map.map(list(range(5)))

with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
"taskbadger.celery.update_task_safe"
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
result = task_map.delay()
assert result.get(timeout=10, propagate=True) == [0, 2, 4, 6, 8]

assert create.call_count == 0
assert get_task.call_count == 0
assert update.call_count == 0
assert Badger.current.session().client is None


def test_celery_task_already_in_terminal_state(celery_session_worker, bind_settings):
@celery.shared_task(bind=True, base=Task)
def add_manual_update(self, a, b, is_retry=False):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_celery_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from unittest import mock

import pytest

from taskbadger import StatusEnum
from taskbadger.celery import Task
from taskbadger.mug import Badger
from tests.utils import task_for_test


def test_celery_task_error(celery_session_app, celery_session_worker, bind_settings):
@celery_session_app.task(bind=True, base=Task)
def add_error(self, a, b):
assert self.taskbadger_task is not None
assert Badger.current.session().client is not None
raise Exception("error")

celery_session_worker.reload()

with mock.patch("taskbadger.celery.create_task_safe") as create, mock.patch(
"taskbadger.celery.update_task_safe"
) as update, mock.patch("taskbadger.celery.get_task") as get_task:
get_task.return_value = task_for_test()
result = add_error.delay(2, 2)
with pytest.raises(Exception):
result.get(timeout=10, propagate=True)

create.assert_called()
update.assert_has_calls(
[
mock.call(mock.ANY, status=StatusEnum.PROCESSING, data=mock.ANY),
mock.call(mock.ANY, status=StatusEnum.ERROR, data=mock.ANY),
]
)
data_kwarg = update.call_args_list[1][1]["data"]
assert "Traceback" in data_kwarg["exception"]
assert Badger.current.session().client is None

0 comments on commit e792600

Please sign in to comment.