Skip to content

Commit

Permalink
Add flush method (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdeziel authored Oct 12, 2023
1 parent 20e86f2 commit 82b83f2
Show file tree
Hide file tree
Showing 12 changed files with 770 additions and 361 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
working-directory: ${{ github.workspace }}/pyensign
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- name: Checkout Code
Expand Down
27 changes: 23 additions & 4 deletions pyensign/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ async def publish(self, topic, events, on_ack=None, on_nack=None):
done_callback=lambda: self.publishers.pop(topic_hash, None),
)

# Create a concurrent task to queue the events from the user
await self.pool.schedule(publisher.queue_events(events))
# Queue all the events to be published
# await self.pool.schedule(publisher.queue_events(events))
await publisher.queue_events(events)

@catch_rpc_error
async def subscribe(self, topics, query=None, consumer_group=None):
Expand Down Expand Up @@ -346,13 +347,31 @@ async def status(self, attempts=0, last_checked_at=None):
rep = await self.stub.Status(params)
return rep.status, rep.version, rep.uptime, rep.not_before, rep.not_after

async def flush(self, timeout=timedelta(seconds=2.0)):
    """
    Flush all events that have been queued by publishers but not yet published.

    Parameters
    ----------
    timeout : datetime.timedelta (optional) (default: 2 seconds)
        Timeout passed through to each individual publisher/subscriber flush.
    """

    # Snapshot both collections before flushing anything; publishers and
    # subscribers created after this point are intentionally not flushed.
    # Publishers are flushed first, then subscribers, one stream at a time.
    pending = list(self.publishers.values()) + list(self.subscribers.values())
    for stream in pending:
        await stream.flush(timeout=timeout)

async def close(self):
    """
    Close the connection to the server and all ongoing streams. This blocks until
    all requests have been flushed and all streams have been closed.
    """
    # NOTE(review): the scraped diff showed `await self._close_streams()` both
    # before and inside the `if` — that is old+new diff lines overlaid. Per the
    # added-line position, the channel is closed first and then the streams.
    if self.channel:
        await self.channel.close()
        await self._close_streams()
    # Drop the reference so the client cannot reuse a closed channel.
    self.channel = None

async def _close_streams(self):
Expand Down
32 changes: 31 additions & 1 deletion pyensign/ensign.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import inspect
from datetime import timedelta

from ulid import ULID

Expand Down Expand Up @@ -605,10 +606,37 @@ async def status(self):
status, version, uptime, _, _ = await self.client.status()
return ServerStatus(status, version, uptime)

async def flush(self, timeout=2.0):
    """
    Flush all pending events to and from the server. This method blocks until
    either all pending events have been published or subscribed or the timeout
    is reached. An exception is raised if the timeout is reached before all
    events have been flushed.

    Parameters
    ----------
    timeout: float (optional) (default: 2.0)
        Specify the timeout in seconds.

    Raises
    ------
    ValueError
        If the timeout is not a positive number of seconds.
    asyncio.TimeoutError
        If the timeout is reached before all events have been flushed.
    """

    # Reject non-positive timeouts up front rather than passing a meaningless
    # timedelta down to the client.
    if timeout <= 0:
        raise ValueError("timeout must be greater than 0")

    # The client API takes a timedelta rather than raw seconds.
    await self.client.flush(timeout=timedelta(seconds=timeout))

async def close(self):
    """
    Close the Ensign client. This method blocks until all pending events have
    been flushed and should be called before exiting the application. A closed
    client cannot be reused; call `flush()` instead when the client is still
    needed afterwards.
    """

    # Delegate shutdown entirely to the underlying client connection.
    client = self.client
    await client.close()

async def __aenter__(self):
Expand Down Expand Up @@ -695,6 +723,7 @@ async def wrapper(*args, **kwargs):
_client = None
raise e

await _client.close()
_client = None
return res

Expand All @@ -717,6 +746,7 @@ async def wrapper(*args, **kwargs):
_client = None
raise e

await _client.close()
_client = None

return wrapper
Expand Down
27 changes: 23 additions & 4 deletions pyensign/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pickle
import asyncio
from enum import Enum
from functools import total_ordering
from google.protobuf.timestamp_pb2 import Timestamp

from pyensign.ack import Ack
Expand Down Expand Up @@ -96,6 +97,13 @@ def proto(self):
created=self.created,
)

def published(self):
    """
    Returns True if the event has been published (sent to the server).
    """
    # Any state at or beyond PUBLISHED (e.g. SUBSCRIBED, ACKED, NACKED)
    # means the event has been sent to the server at least once.
    return EventState.PUBLISHED <= self._state

def acked(self):
"""
Returns True if the event has been acked.
Expand Down Expand Up @@ -206,6 +214,9 @@ async def wait_for_ack(self):

return Ack(self.id, self.committed)

def mark_queued(self):
    # Transition the event to QUEUED: accepted by the client for publishing
    # but not yet sent to the server.
    self._state = EventState.QUEUED

def mark_published(self):
    # Transition the event to PUBLISHED: sent to the server but not yet
    # acked or nacked.
    self._state = EventState.PUBLISHED

Expand Down Expand Up @@ -326,17 +337,25 @@ def __str__(self):
return "{} v{}".format(self.name, self.semver())


@total_ordering
class EventState(Enum):
    """
    Lifecycle states for an event, ordered by progression. The ordering
    (via total_ordering) allows checks like `state >= EventState.PUBLISHED`.
    """

    # Created locally but not yet handed to a publisher.
    INITIALIZED = 0
    # Accepted by the client and waiting to be sent.
    QUEUED = 1
    # Sent to the server; no ack or nack received yet.
    PUBLISHED = 2
    # Delivered to a subscriber; awaiting the user's ack.
    SUBSCRIBED = 3
    # Acknowledged by a user or by the server.
    ACKED = 4
    # Negatively acknowledged by a user or by the server.
    NACKED = 5

    def __lt__(self, other):
        # Only compare against other EventState members; defer otherwise.
        if self.__class__ is not other.__class__:
            return NotImplemented
        return self.value < other.value


def from_object(obj, mimetype=None, encoder=None):
Expand Down
2 changes: 2 additions & 0 deletions pyensign/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def _handle_client_error(e):
raise EnsignAttributeError(
"error accessing field from Ensign response: {}".format(e)
) from e
elif isinstance(e, AuthenticationError):
raise e
elif isinstance(e, QueryNoRows):
raise e
elif isinstance(e, EnsignTopicNotFoundError):
Expand Down
119 changes: 7 additions & 112 deletions pyensign/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,6 @@

import grpc

from pyensign.exceptions import EnsignTypeError


class RequestIterator:
    """
    Asynchronous iterator over outgoing stream requests. The first item yielded
    is the initial handshaking request; subsequent items are read from the
    internal client queue until the queue returns None, at which point
    iteration ends (and with it the gRPC stream). Instances can be used
    directly as the request iterator for the publish or subscribe RPCs.
    """

    def __init__(self, queue, init_request):
        self.queue = queue
        self.init_request = init_request

    async def __aiter__(self):
        # Open the stream by sending the handshake request first.
        yield self.init_request

        # Relay queued requests until the queue signals closure with None;
        # finishing this generator closes the underlying gRPC stream.
        while (request := await self.queue.read_request()) is not None:
            yield request


class ResponseIterator:
"""
Expand All @@ -47,94 +21,15 @@ async def __aiter__(self):
while True:
try:
rep = await self.stream.read()
except grpc.aio.AioRpcError:
except grpc.aio.AioRpcError as e:
logging.warning(
f"gRPC error occurred while reading from the stream: {e}"
)
break
except StopAsyncIteration:
logging.debug("gRPC stream was closed")
break
# Handle unexpected end of stream
if rep is grpc.aio.EOF:
break
yield rep


class PublishResponseIterator(ResponseIterator):
    """
    PublishResponseIterator is an asynchronous iterator that reads responses from a
    gRPC publish stream and executes user-defined callbacks for acks and nacks.
    """

    def __init__(self, stream, pending, on_ack=None, on_nack=None):
        """
        Parameters
        ----------
        stream : grpc.aio.StreamStreamCall
            The gRPC stream object which has been opened and is ready to be read from.
        pending : dict
            A dictionary of pending local ULIDs to events that have been published but
            not yet acked or nacked.
        on_ack : async callable (optional)
            Awaited with the ack message for each acked event; exceptions it
            raises are logged and suppressed so the consumer loop keeps running.
        on_nack : async callable (optional)
            Awaited with the nack message for each nacked event; exceptions it
            raises are logged and suppressed so the consumer loop keeps running.
        """
        self.stream = stream
        self.pending = pending
        self.on_ack = on_ack
        self.on_nack = on_nack

    async def consume(self):
        """
        Consume all responses from the stream until closed.
        """
        async for rep in self:
            # Handle messages from the server
            rep_type = rep.WhichOneof("embed")
            if rep_type == "ack":
                # Resolve the pending event if it is still tracked locally;
                # the event is marked before the user callback runs.
                event = self.pending.pop(rep.ack.id, None)
                if event:
                    event.mark_acked(rep.ack)
                if self.on_ack:
                    try:
                        await self.on_ack(rep.ack)
                    except Exception as e:
                        # User callbacks must not be able to kill the consumer.
                        logging.warning(
                            f"unhandled exception while awaiting ack callback: {e}",
                            exc_info=True,
                        )
            elif rep_type == "nack":
                event = self.pending.pop(rep.nack.id, None)
                if event:
                    event.mark_nacked(rep.nack)
                if self.on_nack:
                    try:
                        await self.on_nack(rep.nack)
                    except Exception as e:
                        logging.warning(
                            f"unhandled exception while awaiting nack callback: {e}",
                            exc_info=True,
                        )
            elif rep_type == "close_stream":
                # Server-initiated close; stop consuming normally.
                break
            else:
                raise EnsignTypeError(f"unexpected response type: {rep_type}")


class SubscribeResponseIterator(ResponseIterator):
    """
    SubscribeResponseIterator is an asynchronous iterator that reads responses from a
    gRPC subscribe stream and writes events to a client queue.
    """

    def __init__(self, stream, queue):
        # stream: the open gRPC subscribe stream to read responses from.
        # queue: client-side queue that events (or errors) are written to.
        self.stream = stream
        self.queue = queue

    async def consume(self):
        """
        Consume all responses from the stream until closed.
        """
        async for rep in self:
            rep_type = rep.WhichOneof("embed")
            if rep_type == "event":
                # Forward the event to the client's queue for the subscriber.
                await self.queue.write_response(rep.event)
            elif rep_type == "close_stream":
                break
            else:
                # Surface the protocol violation to the consumer via the
                # queue rather than raising here, then stop reading.
                await self.queue.write_response(
                    EnsignTypeError(f"unexpected response type: {rep_type}")
                )
                break
Loading

0 comments on commit 82b83f2

Please sign in to comment.