Skip to content

Commit

Permalink
stop worker if channel is closed
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Nov 2, 2024
1 parent a38df9d commit 734d187
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
15 changes: 14 additions & 1 deletion pyzeebe/grpc_internals/zeebe_adapter_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from typing import NoReturn

import grpc
from typing_extensions import TypeAlias
from zeebe_grpc.gateway_pb2_grpc import GatewayStub

from pyzeebe.errors import (
Expand All @@ -14,6 +18,8 @@
from pyzeebe.errors.pyzeebe_errors import PyZeebeError
from pyzeebe.grpc_internals.grpc_utils import is_error_status

Callback: TypeAlias = Callable[[], None]

logger = logging.getLogger(__name__)


Expand All @@ -25,11 +31,15 @@ def __init__(self, grpc_channel: grpc.aio.Channel, max_connection_retries: int =
self.retrying_connection = False
self._max_connection_retries = max_connection_retries
self._current_connection_retries = 0
self._on_disconnect_callbacks: list[Callback] = []

@property
def connected(self) -> bool:
return self._connected

def add_disconnect_callback(self, callback: Callback) -> None:
self._on_disconnect_callbacks.append(callback)

def _should_retry(self) -> bool:
return self._max_connection_retries == -1 or self._current_connection_retries < self._max_connection_retries

Expand All @@ -45,10 +55,13 @@ async def _handle_grpc_error(self, grpc_error: grpc.aio.AioRpcError) -> NoReturn

async def _close(self) -> None:
try:
self._connected = False
await self._channel.close()
except Exception as exception:
logger.exception("Failed to close channel, %s exception was raised", type(exception).__name__)
finally:
self._connected = False
for callback in self._on_disconnect_callbacks:
callback()


def _create_pyzeebe_error_from_grpc_error(grpc_error: grpc.aio.AioRpcError) -> PyZeebeError:
Expand Down
12 changes: 7 additions & 5 deletions pyzeebe/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ def __init__(
tenant_ids (list[str]): A list of tenant IDs for which to activate jobs. New in Zeebe 8.3.
"""
super().__init__(before, after, exception_handler)
self._stop_event = anyio.Event()
self.zeebe_adapter = ZeebeAdapter(grpc_channel, max_connection_retries)
self.zeebe_adapter.add_disconnect_callback(self._stop_event.set)
self.name = name or socket.gethostname()
self.request_timeout = request_timeout
self.poll_retry_delay = poll_retry_delay
self.tenant_ids = tenant_ids
self._job_pollers: list[JobPoller] = []
self._job_executors: list[JobExecutor] = []
self._stop_event = anyio.Event()

def _init_tasks(self) -> None:
self._job_executors, self._job_pollers = [], []
Expand Down Expand Up @@ -110,11 +111,12 @@ async def stop(self) -> None:
"""
Stop the worker. This will emit a signal asking tasks to complete the current task and stop polling for new.
"""
for poller in self._job_pollers:
await poller.stop()
async with anyio.create_task_group() as tg:
for poller in self._job_pollers:
tg.start_soon(poller.stop)

for executor in self._job_executors:
await executor.stop()
for executor in self._job_executors:
tg.start_soon(executor.stop)

self._stop_event.set()

Expand Down
10 changes: 9 additions & 1 deletion tests/unit/grpc_internals/zeebe_adapter_base_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock

import grpc
import pytest
Expand Down Expand Up @@ -67,6 +67,9 @@ async def test_raises_unkown_grpc_status_code_on_unkown_status_code(
await zeebe_adapter._handle_grpc_error(error)

async def test_closes_after_retries_exceeded(self, zeebe_adapter: ZeebeAdapterBase):
on_disconnect_callback = Mock()
zeebe_adapter.add_disconnect_callback(on_disconnect_callback)

error = grpc.aio.AioRpcError(grpc.StatusCode.UNAVAILABLE, None, None)

zeebe_adapter._channel.close = AsyncMock()
Expand All @@ -76,8 +79,12 @@ async def test_closes_after_retries_exceeded(self, zeebe_adapter: ZeebeAdapterBa

assert zeebe_adapter.connected is False
zeebe_adapter._channel.close.assert_awaited_once()
on_disconnect_callback.assert_called_once()

async def test_closes_after_internal_error(self, zeebe_adapter: ZeebeAdapterBase):
on_disconnect_callback = Mock()
zeebe_adapter.add_disconnect_callback(on_disconnect_callback)

error = grpc.aio.AioRpcError(grpc.StatusCode.INTERNAL, None, None)

zeebe_adapter._channel.close = AsyncMock()
Expand All @@ -87,3 +94,4 @@ async def test_closes_after_internal_error(self, zeebe_adapter: ZeebeAdapterBase

assert zeebe_adapter.connected is False
zeebe_adapter._channel.close.assert_awaited_once()
on_disconnect_callback.assert_called_once()
22 changes: 19 additions & 3 deletions tests/unit/worker/worker_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, Mock
from uuid import uuid4

Expand Down Expand Up @@ -277,12 +278,12 @@ async def test_poller_failed(self, zeebe_worker: ZeebeWorker):
async def test_second_poller_should_cancel(self, zeebe_worker: ZeebeWorker):
zeebe_worker._init_tasks = Mock()

poller2_cancel_event = anyio.Event()
poller2_cancel_event = asyncio.Event()

async def poll2():
try:
await anyio.Event().wait()
except anyio.get_cancelled_exc_class():
await asyncio.Event().wait()
except asyncio.CancelledError:
poller2_cancel_event.set()

poller_mock = AsyncMock(spec_set=JobPoller, poll=AsyncMock(side_effect=[Exception("test_exception")]))
Expand All @@ -295,3 +296,18 @@ async def poll2():
poller_mock.poll.assert_awaited_once()
poller2_mock.poll.assert_awaited_once()
assert poller2_cancel_event.is_set()

async def test_stop_after_retries_exceeded(self, zeebe_worker: ZeebeWorker):
@zeebe_worker.task(str(uuid4()))
def dummy_function():
pass

zeebe_worker.zeebe_adapter._gateway_stub.ActivateJobs.side_effect = [
grpc.aio.AioRpcError(grpc.StatusCode.INTERNAL, None, None)
]
zeebe_worker.zeebe_adapter._max_connection_retries = 1

await zeebe_worker.work()

zeebe_worker.zeebe_adapter._gateway_stub.ActivateJobs.assert_called_once()
assert zeebe_worker._stop_event.is_set() is True

0 comments on commit 734d187

Please sign in to comment.