Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡ Waypoint start time #1137

Merged
merged 17 commits into from
Oct 25, 2024
6 changes: 6 additions & 0 deletions app/routes/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def get_sse_subscribe_event_with_field_and_state(
field_id: str,
desired_state: str,
group_id: Optional[str] = group_id_query,
look_back: Optional[int] = Query(
default=300, description="Number of seconds to look back for events"
),
auth: AcaPyAuthVerified = Depends(acapy_auth_verified),
) -> StreamingResponse:
"""
Expand Down Expand Up @@ -63,6 +66,8 @@ async def get_sse_subscribe_event_with_field_and_state(
The ID of the field subscribing to the events.
desired_state:
The desired state to be reached.
look_back:
Number of seconds to look back for events before subscribing.
"""
logger.bind(
body={
Expand All @@ -87,6 +92,7 @@ async def get_sse_subscribe_event_with_field_and_state(
field=field,
field_id=field_id,
desired_state=desired_state,
look_back=look_back,
),
media_type="text/event-stream",
)
5 changes: 4 additions & 1 deletion app/services/event_handling/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ async def sse_subscribe_event_with_field_and_state(
field: str,
field_id: str,
desired_state: str,
look_back: int = 300,
) -> AsyncGenerator[str, None]:
"""
Subscribe to server-side events for a specific wallet ID and topic.
Expand All @@ -56,8 +57,10 @@ async def sse_subscribe_event_with_field_and_state(
)

params = {}
if group_id: # Optional param
if group_id: # Optional params
params["group_id"] = group_id
if look_back:
params["look_back"] = look_back

try:
async with RichAsyncClient(timeout=event_timeout) as client:
Expand Down
2 changes: 2 additions & 0 deletions app/tests/routes/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def test_get_sse_subscribe_event_with_field_and_state(
desired_state=state,
group_id=group_id,
auth=mock_auth,
look_back=300,
)

assert response.media_type == "text/event-stream"
Expand All @@ -67,4 +68,5 @@ async def test_get_sse_subscribe_event_with_field_and_state(
field=field,
field_id=field_id,
desired_state=state,
look_back=300,
)
2 changes: 1 addition & 1 deletion app/tests/services/event_handling/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ async def test_sse_subscribe_event_with_field_and_state_success(
patch_yield_lines_with_disconnect_check, # pylint: disable=redefined-outer-name
group_id: Optional[str],
):
expected_params = {}
expected_params = {"look_back": 300}
if group_id: # Optional param
expected_params["group_id"] = group_id

Expand Down
2 changes: 1 addition & 1 deletion app/tests/util/sse_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def wait_for_event(
"""
Start listening for SSE events. When an event is received that matches the specified parameters.
"""
url = f"{waypoint_base_url}/{self.wallet_id}/{self.topic}/{field}/{field_id}/{desired_state}"
url = f"{waypoint_base_url}/{self.wallet_id}/{self.topic}/{field}/{field_id}/{desired_state}?look_back=5"

timeout = Timeout(timeout)
async with RichAsyncClient(timeout=timeout) as client:
Expand Down
7 changes: 7 additions & 0 deletions waypoint/routers/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def nats_event_stream_generator(
field_id: str,
desired_state: str,
group_id: Optional[str],
look_back: Optional[int],
nats_processor: NatsEventsProcessor,
) -> AsyncGenerator[str, None]:
"""
Expand All @@ -53,6 +54,7 @@ async def nats_event_stream_generator(
topic=topic,
stop_event=stop_event,
duration=SSE_TIMEOUT,
look_back=look_back,
) as event_generator:
background_tasks.add_task(check_disconnect, request, stop_event)

Expand Down Expand Up @@ -94,6 +96,10 @@ async def sse_wait_for_event_with_field_and_state(
group_id: Optional[str] = Query(
default=None, description="Group ID to which the wallet belongs"
),
look_back: Optional[int] = Query(
default=300,
description="Number of seconds to look back for events before subscribing",
),
nats_processor: NatsEventsProcessor = Depends(
Provide[Container.nats_events_processor]
),
Expand Down Expand Up @@ -121,6 +127,7 @@ async def sse_wait_for_event_with_field_and_state(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=look_back,
nats_processor=nats_processor,
)

Expand Down
50 changes: 33 additions & 17 deletions waypoint/services/nats_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone

import orjson
from nats.errors import BadSubscriptionError, Error, TimeoutError
from nats.js.api import ConsumerConfig, DeliverPolicy
from nats.js.client import JetStreamContext

from shared.constants import NATS_STREAM, NATS_SUBJECT
Expand All @@ -23,24 +25,35 @@ def __init__(self, jetstream: JetStreamContext):
self.js_context: JetStreamContext = jetstream

async def _subscribe(
self, group_id: str, wallet_id: str
self, group_id: str, wallet_id: str, look_back: int
) -> JetStreamContext.PullSubscription:
try:
logger.debug("Subscribing to JetStream...")
if group_id:

logger.trace("Tenant-admin call got group_id: {}", group_id)
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.{group_id}.{wallet_id}",
"stream": NATS_STREAM,
}
else:
logger.trace("Tenant call got no group_id")
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.*.{wallet_id}",
"stream": NATS_STREAM,
}
subscription = await self.js_context.pull_subscribe(**subscribe_kwargs)
logger.trace(
"Subscribing to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)
group_id = group_id or "*"
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.{group_id}.{wallet_id}",
"stream": NATS_STREAM,
}

# Get the current time in UTC
current_time = datetime.now(timezone.utc)

# Subtract 30 seconds
look_back_time = current_time - timedelta(seconds=look_back)

# Format the time in the required format
start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
config = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time=start_time,
)
subscription = await self.js_context.pull_subscribe(
config=config, **subscribe_kwargs
)

return subscription

Expand All @@ -62,6 +75,7 @@ async def process_events(
topic: str,
stop_event: asyncio.Event,
duration: int = 10,
look_back: int = 300,
):
logger.debug(
"Processing events for group {} and wallet {} on topic {}",
Expand All @@ -70,7 +84,9 @@ async def process_events(
topic,
)

subscription = await self._subscribe(group_id=group_id, wallet_id=wallet_id)
subscription = await self._subscribe(
group_id=group_id, wallet_id=wallet_id, look_back=look_back
)

async def event_generator():
end_time = time.time() + duration
Expand Down
5 changes: 5 additions & 0 deletions waypoint/tests/routers/test_waypoint_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def mock_event_generator():
desired_state=desired_state,
group_id=group_id,
nats_processor=nats_processor_mock,
look_back=300,
):
events.append(event)

Expand Down Expand Up @@ -131,6 +132,7 @@ async def mock_event_generator():
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
):
pass
Expand Down Expand Up @@ -163,6 +165,7 @@ async def mock_event_generator():
field_id="some_field_id",
desired_state="some_state",
group_id="some_group",
look_back=300,
nats_processor=nats_processor_mock,
)

Expand Down Expand Up @@ -196,6 +199,7 @@ async def test_sse_event_stream(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
)

Expand All @@ -211,5 +215,6 @@ async def test_sse_event_stream(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
)
25 changes: 16 additions & 9 deletions waypoint/tests/services/test_nats_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nats.aio.client import Client as NATS
from nats.aio.errors import ErrConnectionClosed, ErrNoServers, ErrTimeout
from nats.errors import BadSubscriptionError, Error, TimeoutError
from nats.js.api import ConsumerConfig, DeliverPolicy
from nats.js.client import JetStreamContext

from shared.constants import NATS_STREAM, NATS_SUBJECT
Expand Down Expand Up @@ -53,14 +54,20 @@ async def test_nats_events_processor_subscribe(
mock_nats_client.pull_subscribe.return_value = AsyncMock(
spec=JetStreamContext.PullSubscription
)

subscription = await processor._subscribe( # pylint: disable=protected-access
"group_id", "wallet_id"
)
mock_nats_client.pull_subscribe.assert_called_once_with(
subject=f"{NATS_SUBJECT}.group_id.wallet_id", stream=NATS_STREAM
)
assert isinstance(subscription, JetStreamContext.PullSubscription)
with patch("waypoint.services.nats_service.ConsumerConfig") as mock_config:
mock_config.return_value = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time="2024-10-24T09:17:17.998149541Z",
)
subscription = await processor._subscribe( # pylint: disable=protected-access
"group_id", "wallet_id", 300
)
mock_nats_client.pull_subscribe.assert_called_once_with(
subject=f"{NATS_SUBJECT}.group_id.wallet_id",
stream=NATS_STREAM,
config=mock_config.return_value,
)
assert isinstance(subscription, JetStreamContext.PullSubscription)


@pytest.mark.anyio
Expand All @@ -72,7 +79,7 @@ async def test_nats_events_processor_subscribe_error(
mock_nats_client.pull_subscribe.side_effect = exception

with pytest.raises(exception):
await processor._subscribe("group_id", "wallet_id")
await processor._subscribe("group_id", "wallet_id", 300)
Dismissed Show dismissed Hide dismissed


@pytest.mark.anyio
Expand Down
Loading