Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async Firehose Client: block on make message handler call, add on error callback #157

Merged
merged 11 commits into from
Oct 27, 2023
143 changes: 79 additions & 64 deletions atproto/firehose/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ def __init__(
self._reconnect_no = 0
self._max_reconnect_delay_sec = 64

self._on_message_callback: t.Optional[t.Union[OnMessageCallback, AsyncOnMessageCallback]] = None
self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None

def update_params(self, params: t.Union[ParamsModelBase, t.Dict[str, t.Any]]) -> None:
"""Update params.

Expand All @@ -110,6 +107,8 @@ def _get_client(self):
return connect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES)

def _get_async_client(self):
# NOTE: I've noticed that the close operation often takes the entire timeout for some reason
# By default this is 10 seconds, which is pretty long. Maybe shorten it?
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
return aconnect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES)

def _get_reconnection_delay(self) -> int:
Expand All @@ -118,6 +117,19 @@ def _get_reconnection_delay(self) -> int:

return min(base_sec, self._max_reconnect_delay_sec) + rand_sec


class _WebsocketClient(_WebsocketClientBase):
def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
super().__init__(method, base_uri, params)

# TODO: Not sure if this should be a Lock or not, the async is using an Event now
self._stop_lock = threading.Lock()
DXsmiley marked this conversation as resolved.
Show resolved Hide resolved

self._on_message_callback: t.Optional[OnMessageCallback] = None
self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None

def _process_raw_frame(self, data: bytes) -> None:
frame = Frame.from_bytes(data)
if isinstance(frame, ErrorFrame):
Expand All @@ -127,9 +139,22 @@ def _process_raw_frame(self, data: bytes) -> None:
else:
raise FirehoseDecodingError('Unknown frame type')

def _process_message_frame(self, frame: 'MessageFrame') -> None:
try:
if self._on_message_callback is not None:
self._on_message_callback(frame)
except Exception as e: # noqa: BLE001
if self._on_callback_error_callback:
try:
self._on_callback_error_callback(e)
except: # noqa
traceback.print_exc()
else:
traceback.print_exc()

def start(
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
self,
on_message_callback: t.Union[OnMessageCallback, AsyncOnMessageCallback],
on_message_callback: OnMessageCallback,
on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None,
) -> None:
"""Subscribe to Firehose and start client.
Expand All @@ -144,41 +169,6 @@ def start(
self._on_message_callback = on_message_callback
self._on_callback_error_callback = on_callback_error_callback

def stop(self):
"""Unsubscribe and stop the Firehose client.

Returns:
:obj:`None`
"""
raise NotImplementedError

def _process_message_frame(self, frame: 'MessageFrame') -> None:
raise NotImplementedError


class _WebsocketClient(_WebsocketClientBase):
def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
super().__init__(method, base_uri, params)

self._stop_lock = threading.Lock()

def _process_message_frame(self, frame: 'MessageFrame') -> None:
try:
self._on_message_callback(frame)
except Exception as e: # noqa: BLE001
if self._on_callback_error_callback:
try:
self._on_callback_error_callback(e)
except: # noqa
traceback.print_exc()
else:
traceback.print_exc()

def start(self, *args, **kwargs):
super().start(*args, **kwargs)

while not self._stop_lock.locked():
try:
if self._reconnect_no != 0:
Expand Down Expand Up @@ -207,7 +197,12 @@ def start(self, *args, **kwargs):
if self._stop_lock.locked():
self._stop_lock.release()

def stop(self):
def stop(self) -> None:
"""Unsubscribe and stop the Firehose client.

Returns:
:obj:`None`
"""
if not self._stop_lock.locked():
self._stop_lock.acquire()

Expand All @@ -217,50 +212,69 @@ def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
super().__init__(method, base_uri, params)
self._stop_event = asyncio.Event()
self._on_message_callback: t.Optional[AsyncOnMessageCallback] = None
self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None

self._loop = asyncio.get_event_loop()
self._on_message_tasks: t.Set[asyncio.Task] = set()

self._stop_lock = asyncio.Lock()

def _on_message_callback_done(self, task: asyncio.Task) -> None:
self._on_message_tasks.discard(task)
async def _process_raw_frame(self, data: bytes) -> None:
frame = Frame.from_bytes(data)
if isinstance(frame, ErrorFrame):
raise FirehoseError(XrpcError(frame.body.error, frame.body.message))
if isinstance(frame, MessageFrame):
await self._process_message_frame(frame)
else:
raise FirehoseDecodingError('Unknown frame type')

exception = task.exception()
if exception:
async def _process_message_frame(self, frame: 'MessageFrame') -> None:
try:
if self._on_message_callback is not None:
await self._on_message_callback(frame)
except Exception as exception: # noqa: BLE001
if not self._on_callback_error_callback:
_print_exception(exception)
return

try:
self._on_callback_error_callback(exception)
except: # noqa
traceback.print_exc()

def _process_message_frame(self, frame: 'MessageFrame') -> None:
task: asyncio.Task = self._loop.create_task(self._on_message_callback(frame))
self._on_message_tasks.add(task)
task.add_done_callback(self._on_message_callback_done)
async def start(
self,
on_message_callback: AsyncOnMessageCallback,
on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None,
) -> None:
"""Subscribe to Firehose and start client.

async def start(self, *args, **kwargs):
super().start(*args, **kwargs)
Args:
on_message_callback: Callback that will be called on the new Firehose message.
on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception.

while not self._stop_lock.locked():
Returns:
:obj:`None`
"""
self._on_message_callback = on_message_callback
self._on_callback_error_callback = on_callback_error_callback

self._stop_event = asyncio.Event()

while not self._stop_event.is_set():
try:
if self._reconnect_no != 0:
# TODO: This sleep can potentially get pretty long,
# allow it to be interrupted by stop()?
await asyncio.sleep(self._get_reconnection_delay())

async with self._get_async_client() as client:
self._reconnect_no = 0

while not self._stop_lock.locked():
while not self._stop_event.is_set():
raw_frame = await client.recv()
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(raw_frame, str):
# skip text frames (should not be occurred)
continue

try:
self._process_raw_frame(raw_frame)
await self._process_raw_frame(raw_frame)
except Exception as e: # noqa: BLE001
_handle_frame_decoding_error(e)
except Exception as e: # noqa: BLE001
Expand All @@ -270,12 +284,13 @@ async def start(self, *args, **kwargs):
if should_stop:
break

if self._stop_lock.locked():
self._stop_lock.release()
async def stop(self) -> None:
"""Unsubscribe and stop the Firehose client.

async def stop(self):
if not self._stop_lock.locked():
await self._stop_lock.acquire()
Returns:
:obj:`None`
"""
self._stop_event.set()


FirehoseClient = _WebsocketClient
Expand Down