Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Merge pull request #18 from larmoreg/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
larmoreg authored Dec 30, 2021
2 parents 0a2598a + 86bec41 commit 0ddd7ff
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install
run: poetry install -E kafka -E msgpack -E redis
- name: Pytest
run: poetry run pytest --cov-report=xml --cov=fastmicro tests/
run: poetry run pytest -x -s --cov-report=xml --cov=fastmicro tests/test_client.py tests/test_server.py
- name: Codecov
uses: codecov/[email protected]
with:
Expand Down
2 changes: 1 addition & 1 deletion fastmicro/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def _call(
for input_message in input_messages
]

await self.reply_topic.subscribe(self.broadcast_name)
await self.reply_topic.subscribe(self.broadcast_name, latest=True)

output_headers: List[HeaderABC[BT]] = list()
for i in range(0, len(input_headers), batch_size):
Expand Down
10 changes: 8 additions & 2 deletions fastmicro/messaging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import asyncio
from contextlib import asynccontextmanager
from types import TracebackType
from typing import (
AsyncIterator,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -54,7 +56,9 @@ async def deserialize(
data = await self.serializer_type.deserialize(serialized)
return self.header_type(schema_type)(**data)

async def subscribe(self, topic_name: str, group_name: str) -> None:
async def subscribe(
    self, topic_name: str, group_name: str, latest: bool = False
) -> None:
    """Subscription hook; the base implementation does nothing.

    Backends that need explicit setup before messages can be received
    override this method.

    Args:
        topic_name: Name of the topic to subscribe to.
        group_name: Name of the consumer group doing the subscribing.
        latest: Start-position flag handed through to the backend; it is
            ignored here.
    """
    return None

async def unsubscribe(self, topic_name: str, group_name: str) -> None:
Expand All @@ -72,6 +76,7 @@ async def nack(
) -> None:
raise NotImplementedError

@asynccontextmanager
@abc.abstractmethod
async def receive(
self,
Expand All @@ -81,8 +86,9 @@ async def receive(
schema_type: Type[T],
batch_size: int = BATCH_SIZE,
timeout: Optional[float] = MESSAGING_TIMEOUT,
) -> Sequence[HeaderABC[T]]:
) -> AsyncIterator[Sequence[HeaderABC[T]]]:
raise NotImplementedError
yield

@abc.abstractmethod
async def send(
Expand Down
78 changes: 54 additions & 24 deletions fastmicro/messaging/kafka.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import aiokafka
import asyncio
from contextlib import asynccontextmanager
import sys
from typing import (
Any,
AsyncIterator,
cast,
Dict,
Generic,
Expand All @@ -28,10 +31,7 @@ class Header(HeaderABC[T], Generic[T]):
offset: Optional[int] = None


class Messaging(MessagingABC):
def header_type(self, schema_type: Type[T]) -> Type[Header[T]]:
return Header[schema_type] # type: ignore

class Consumer(aiokafka.AIOKafkaConsumer): # type: ignore
class ConsumerRebalanceListener(aiokafka.ConsumerRebalanceListener): # type: ignore
def __init__(self, lock: asyncio.Lock):
self.lock = lock
Expand All @@ -46,6 +46,31 @@ async def on_partitions_assigned(
) -> None:
self.lock.release()

def __init__(
    self,
    *args: Any,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> None:
    """Create the consumer together with the lock used by its rebalance listener.

    Args:
        *args: Positional arguments forwarded to the base consumer.
        loop: Event loop the lock should be bound to. It is consumed here
            and not forwarded to the base constructor. ``None`` (the
            default) lets ``asyncio.Lock`` pick the loop itself, instead
            of freezing whichever loop existed when this module was
            imported.
        **kwargs: Keyword arguments forwarded to the base consumer.
    """
    super().__init__(*args, **kwargs)

    # asyncio.Lock stopped accepting ``loop`` in Python 3.10; only pass it
    # on older interpreters, and only when the caller supplied one.
    if loop is None or sys.version_info >= (3, 10):
        self.lock = asyncio.Lock()
    else:
        self.lock = asyncio.Lock(loop=loop)

def subscribe(
    self,
    *args: Any,
    listener: Optional[aiokafka.ConsumerRebalanceListener] = None,
    **kwargs: Any,
) -> None:
    """Subscribe through the base consumer, always supplying a listener.

    Args:
        *args: Positional arguments forwarded to the base ``subscribe``.
        listener: Rebalance listener to install. When omitted (or falsy),
            a ``ConsumerRebalanceListener`` built around this consumer's
            own lock is used.
        **kwargs: Keyword arguments forwarded to the base ``subscribe``.
    """
    effective_listener = listener or self.ConsumerRebalanceListener(self.lock)
    super().subscribe(*args, listener=effective_listener, **kwargs)


class Messaging(MessagingABC):
def header_type(self, schema_type: Type[T]) -> Type[Header[T]]:
return Header[schema_type] # type: ignore

def __init__(
self,
bootstrap_servers: str = KAFKA_BOOTSTRAP_SERVERS,
Expand All @@ -56,28 +81,26 @@ def __init__(
self.bootstrap_servers = bootstrap_servers
self.initialized = False

async def _get_consumer(
self, topic_name: str, group_name: str
) -> aiokafka.AIOKafkaConsumer:
async def _create_consumer(
self, topic_name: str, group_name: str, auto_offset_reset: str
) -> None:
key = topic_name, group_name
if key not in self.consumers:
consumer = aiokafka.AIOKafkaConsumer(
consumer = Consumer(
bootstrap_servers=self.bootstrap_servers,
loop=self.loop,
group_id=group_name,
auto_offset_reset="latest",
auto_offset_reset=auto_offset_reset,
enable_auto_commit=False,
isolation_level="read_committed",
)
consumer.subscribe([topic_name], listener=self.listener)
consumer.subscribe([topic_name])
await consumer.start()

self.consumers[key] = consumer
return self.consumers[key]

async def connect(self) -> None:
if not self.initialized:
self.lock = asyncio.Lock(loop=self.loop)
self.listener = self.ConsumerRebalanceListener(self.lock)
self.consumers: Dict[Tuple[str, str], aiokafka.AIOKafkaConsumer] = dict()

self.producer = aiokafka.AIOKafkaProducer(
Expand All @@ -95,8 +118,12 @@ async def cleanup(self) -> None:
await self.producer.stop()
self.initialized = False

async def subscribe(self, topic_name: str, group_name: str) -> None:
await self._get_consumer(topic_name, group_name)
async def subscribe(
    self, topic_name: str, group_name: str, latest: bool = False
) -> None:
    """Create the consumer for ``(topic_name, group_name)``.

    Args:
        topic_name: Topic the consumer subscribes to.
        group_name: Consumer group id.
        latest: Chooses the ``auto_offset_reset`` value handed to
            ``_create_consumer``: ``"latest"`` when true, ``"earliest"``
            otherwise.
    """
    offset_reset = "earliest"
    if latest:
        offset_reset = "latest"
    await self._create_consumer(topic_name, group_name, offset_reset)

async def ack(
self, topic_name: str, group_name: str, headers: Sequence[HeaderABC[T]]
Expand All @@ -113,7 +140,8 @@ async def ack(
for partition in partitions
}

consumer = await self._get_consumer(topic_name, group_name)
key = topic_name, group_name
consumer = self.consumers[key]
await consumer.commit(offsets)

async def nack(
Expand All @@ -131,6 +159,7 @@ async def _receive(
header.offset = raw_message.offset
return header

@asynccontextmanager
async def receive(
self,
topic_name: str,
Expand All @@ -139,24 +168,27 @@ async def receive(
schema_type: Type[T],
batch_size: int = BATCH_SIZE,
timeout: Optional[float] = MESSAGING_TIMEOUT,
) -> Sequence[Header[T]]:
consumer = await self._get_consumer(topic_name, group_name)
) -> AsyncIterator[Sequence[Header[T]]]:
key = topic_name, group_name
consumer = self.consumers[key]
temp = await consumer.getmany(
timeout_ms=int(timeout * 1000) if timeout is not None else sys.maxsize,
max_records=batch_size,
)
if not temp.items():
raise asyncio.TimeoutError(f"Timed out after {timeout} sec")

await self.lock.acquire()
await consumer.lock.acquire()

tasks = [
self._receive(message, schema_type)
for _, messages in temp.items()
for message in messages
]
headers = await asyncio.gather(*tasks)
self.lock.release()
return cast(Sequence[Header[T]], headers)
yield cast(Sequence[Header[T]], headers)

consumer.lock.release()

async def _send(
self,
Expand All @@ -166,9 +198,7 @@ async def _send(
serialized = await self.serialize(header)
await self.producer.send_and_wait(topic_name, serialized)

async def send(
self, topic_name: str, headers: Sequence[HeaderABC[T]]
) -> None:
async def send(self, topic_name: str, headers: Sequence[HeaderABC[T]]) -> None:
    """Send every header to ``topic_name`` concurrently.

    Args:
        topic_name: Destination topic.
        headers: Headers to send; each one is handed to ``_send`` and the
            resulting coroutines are awaited together.
    """
    # The cast only narrows the static type; nothing changes at runtime.
    kafka_headers = cast(Sequence[Header[T]], headers)
    pending = []
    for header in kafka_headers:
        pending.append(self._send(topic_name, header))
    await asyncio.gather(*pending)
9 changes: 6 additions & 3 deletions fastmicro/messaging/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
from contextlib import asynccontextmanager
from typing import (
AsyncIterator,
cast,
Dict,
Generic,
Expand All @@ -20,7 +22,7 @@


class Queue(Generic[QT]):
def __init__(self):
def __init__(self) -> None:
self.nacked: asyncio.Queue[Tuple[bytes, QT]] = asyncio.Queue()
self.queue: asyncio.Queue[Tuple[bytes, QT]] = asyncio.Queue()
self.pending: Dict[bytes, QT] = dict()
Expand Down Expand Up @@ -94,6 +96,7 @@ async def _receive(self, queue: Queue[bytes], schema_type: Type[T]) -> Header[T]
header.message_id = message_id
return header

@asynccontextmanager
async def receive(
self,
topic_name: str,
Expand All @@ -102,14 +105,14 @@ async def receive(
schema_type: Type[T],
batch_size: int = BATCH_SIZE,
timeout: Optional[float] = MESSAGING_TIMEOUT,
) -> Sequence[Header[T]]:
) -> AsyncIterator[Sequence[Header[T]]]:
queue = await self._get_queue(topic_name)
tasks = [self._receive(queue, schema_type) for i in range(batch_size)]
try:
headers = await asyncio.wait_for(asyncio.gather(*tasks), timeout=timeout)
except asyncio.TimeoutError:
raise asyncio.TimeoutError(f"Timed out after {timeout} sec")
return cast(Sequence[Header[T]], headers)
yield cast(Sequence[Header[T]], headers)

async def _send(self, queue: Queue[bytes], header: Header[T]) -> None:
serialized = await self.serialize(header)
Expand Down
19 changes: 11 additions & 8 deletions fastmicro/messaging/redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import aioredis
import asyncio
from contextlib import asynccontextmanager
import sys
from typing import (
AsyncIterator,
cast,
Dict,
Generic,
Expand Down Expand Up @@ -52,13 +54,11 @@ async def _group_exists(self, topic_name: str, group_name: str) -> bool:
return True
return False

async def _create_group(self, topic_name: str, group_name: str) -> None:
async def _create_group(self, topic_name: str, group_name: str, id: str) -> None:
if not await self._topic_exists(topic_name) or not await self._group_exists(
topic_name, group_name
):
await self.redis.xgroup_create(
topic_name, group_name, id="$", mkstream=True
)
await self.redis.xgroup_create(topic_name, group_name, id=id, mkstream=True)

async def connect(self) -> None:
if not self.initialized:
Expand All @@ -68,8 +68,10 @@ async def cleanup(self) -> None:
if self.initialized:
await self.redis.close()

async def subscribe(self, topic_name: str, group_name: str) -> None:
await self._create_group(topic_name, group_name)
async def subscribe(
    self, topic_name: str, group_name: str, latest: bool = False
) -> None:
    """Ensure the consumer group exists for the given topic.

    Args:
        topic_name: Stream (topic) name.
        group_name: Consumer group name.
        latest: Selects the start id handed to ``_create_group``: ``"$"``
            when true, ``"0"`` otherwise (see the Redis XGROUP CREATE
            documentation for what each id means).
    """
    if latest:
        start_id = "$"
    else:
        start_id = "0"
    await self._create_group(topic_name, group_name, start_id)

async def ack(
self, topic_name: str, group_name: str, headers: Sequence[HeaderABC[T]]
Expand All @@ -92,6 +94,7 @@ async def _receive(
header.message_id = message_id
return header

@asynccontextmanager
async def receive(
self,
topic_name: str,
Expand All @@ -100,7 +103,7 @@ async def receive(
schema_type: Type[T],
batch_size: int = BATCH_SIZE,
timeout: Optional[float] = MESSAGING_TIMEOUT,
) -> Sequence[Header[T]]:
) -> AsyncIterator[Sequence[Header[T]]]:
temp = await self.redis.xreadgroup(
group_name,
consumer_name,
Expand All @@ -120,7 +123,7 @@ async def receive(
for message_id, message in temp[0][1]
]
headers = await asyncio.gather(*tasks)
return cast(Sequence[Header[T]], headers)
yield cast(Sequence[Header[T]], headers)

async def _send(self, topic_name: str, header: Header[T]) -> None:
serialized = await self.serialize(header)
Expand Down
39 changes: 20 additions & 19 deletions fastmicro/messaging/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ async def connect(self) -> None:
async def cleanup(self) -> None:
    """Delegate cleanup to the underlying messaging backend."""
    backend = self.messaging
    await backend.cleanup()

async def subscribe(self, group_name: str) -> None:
await self.messaging.subscribe(self.name, group_name)
async def subscribe(self, group_name: str, latest: bool = False) -> None:
    """Subscribe ``group_name`` to this topic via the messaging backend.

    Args:
        group_name: Consumer group to subscribe.
        latest: Passed through unchanged to the backend's ``subscribe``.
    """
    backend = self.messaging
    topic_name = self.name
    await backend.subscribe(topic_name, group_name, latest)

async def unsubscribe(self, group_name: str) -> None:
    """Unsubscribe ``group_name`` from this topic via the messaging backend.

    Args:
        group_name: Consumer group to unsubscribe.
    """
    backend = self.messaging
    topic_name = self.name
    await backend.unsubscribe(topic_name, group_name)
Expand All @@ -55,35 +55,36 @@ async def receive(
consumer_name: str,
batch_size: int = BATCH_SIZE,
timeout: Optional[float] = MESSAGING_TIMEOUT,
latest: bool = False,
) -> AsyncIterator[Sequence[HeaderABC[T]]]:
await self.messaging.subscribe(self.name, group_name)
await self.subscribe(group_name, latest)

headers = await self.messaging.receive(
async with self.messaging.receive(
self.name,
group_name,
consumer_name,
self.schema_type,
batch_size,
timeout,
)
if logger.isEnabledFor(logging.DEBUG):
for header in headers:
logger.debug(f"Received {header}")

try:
yield headers

) as headers:
if logger.isEnabledFor(logging.DEBUG):
for header in headers:
logger.debug(f"Acking {header}")
await self.messaging.ack(self.name, group_name, headers)
except Exception as e:
if headers:
logger.debug(f"Received {header}")

try:
yield headers

if logger.isEnabledFor(logging.DEBUG):
for header in headers:
logger.debug(f"Nacking {header}")
await self.messaging.nack(self.name, group_name, headers)
raise e
logger.debug(f"Acking {header}")
await self.messaging.ack(self.name, group_name, headers)
except Exception as e:
if headers:
if logger.isEnabledFor(logging.DEBUG):
for header in headers:
logger.debug(f"Nacking {header}")
await self.messaging.nack(self.name, group_name, headers)
raise e

async def send(
self,
Expand Down
Loading

0 comments on commit 0ddd7ff

Please sign in to comment.