From 369376719c87f9f4c9bfef833a791f1a9a7bd6b7 Mon Sep 17 00:00:00 2001 From: Andrei Neagu <5694077+GitHK@users.noreply.github.com> Date: Fri, 26 Jan 2024 08:02:51 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20subscribe=20to=20socketio=20room=20?= =?UTF-8?q?based=20on=20the=20user=5Fid=20(#5270)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andrei Neagu --- .../api_schemas_webserver/socketio.py | 17 ++++++++ .../test_api_schemas_webserver_socketio.py | 27 +++++++++++++ .../services/notifier.py | 6 +-- .../tests/unit/test_services_notifier.py | 8 ++-- .../_rabbitmq_exclusive_queue_consumers.py | 5 ++- .../projects/projects_api.py | 12 +++--- .../socketio/_handlers.py | 6 ++- .../socketio/messages.py | 40 +++++++------------ 8 files changed, 78 insertions(+), 43 deletions(-) create mode 100644 packages/models-library/src/models_library/api_schemas_webserver/socketio.py create mode 100644 packages/models-library/tests/test_api_schemas_webserver_socketio.py diff --git a/packages/models-library/src/models_library/api_schemas_webserver/socketio.py b/packages/models-library/src/models_library/api_schemas_webserver/socketio.py new file mode 100644 index 00000000000..6221387aa11 --- /dev/null +++ b/packages/models-library/src/models_library/api_schemas_webserver/socketio.py @@ -0,0 +1,17 @@ +from ..users import GroupID, UserID + + +class SocketIORoom(str): + __slots__ = () + + @classmethod + def from_socket_id(cls, socket_id: str) -> "SocketIORoom": + return cls(socket_id) + + @classmethod + def from_group_id(cls, group_id: GroupID) -> "SocketIORoom": + return cls(f"group:{group_id}") + + @classmethod + def from_user_id(cls, user_id: UserID) -> "SocketIORoom": + return cls(f"user:{user_id}") diff --git a/packages/models-library/tests/test_api_schemas_webserver_socketio.py b/packages/models-library/tests/test_api_schemas_webserver_socketio.py new file mode 100644 index 00000000000..9136dfda8ae --- /dev/null +++ b/packages/models-library/tests/test_api_schemas_webserver_socketio.py @@ -0,0 +1,27 @@ +# pylint:disable=redefined-outer-name + +import pytest +from faker import Faker +from models_library.api_schemas_webserver.socketio import SocketIORoom +from models_library.users import GroupID, UserID + + +@pytest.fixture +def user_id(faker: Faker) -> UserID: + return UserID(faker.pyint()) + + +@pytest.fixture +def group_id(faker: Faker) -> GroupID: + return GroupID(faker.pyint()) + + +@pytest.fixture +def socket_id(faker: Faker) -> str: + return faker.pystr() + + +def test_socketio_room(user_id: UserID, group_id: GroupID, socket_id: str): + assert SocketIORoom.from_user_id(user_id) == f"user:{user_id}" + assert SocketIORoom.from_group_id(group_id) == f"group:{group_id}" + assert SocketIORoom.from_socket_id(socket_id) == socket_id diff --git a/services/payments/src/simcore_service_payments/services/notifier.py b/services/payments/src/simcore_service_payments/services/notifier.py index e4cd5bd8ef8..b1c6b64fe38 100644 --- a/services/payments/src/simcore_service_payments/services/notifier.py +++ b/services/payments/src/simcore_service_payments/services/notifier.py @@ -8,6 +8,7 @@ SOCKET_IO_PAYMENT_COMPLETED_EVENT, SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT, ) +from models_library.api_schemas_webserver.socketio import SocketIORoom from models_library.api_schemas_webserver.wallets import ( PaymentMethodTransaction, PaymentTransaction, @@ -48,7 +49,7 @@ async def notify_payment_completed( return await self._sio_manager.emit( SOCKET_IO_PAYMENT_COMPLETED_EVENT, data=jsonable_encoder(payment, by_alias=True), - room=f"{user_primary_group_id}", + room=SocketIORoom.from_group_id(user_primary_group_id), ) async def notify_payment_method_acked( @@ -61,7 +62,7 @@ async def notify_payment_method_acked( return await self._sio_manager.emit( SOCKET_IO_PAYMENT_METHOD_ACKED_EVENT, data=jsonable_encoder(payment_method, by_alias=True), - room=f"{user_primary_group_id}", + room=SocketIORoom.from_group_id(user_primary_group_id), ) @@ -77,7 +78,6 @@ async def _on_startup() -> None: assert Notifier.get_from_app_state(app) == notifier # nosec async def _on_shutdown() -> None: - with contextlib.suppress(AttributeError): Notifier.pop_from_app_state(app) diff --git a/services/payments/tests/unit/test_services_notifier.py b/services/payments/tests/unit/test_services_notifier.py index 8c20a8be078..57b4e5b9432 100644 --- a/services/payments/tests/unit/test_services_notifier.py +++ b/services/payments/tests/unit/test_services_notifier.py @@ -18,6 +18,7 @@ from models_library.api_schemas_payments.socketio import ( SOCKET_IO_PAYMENT_COMPLETED_EVENT, ) +from models_library.api_schemas_webserver.socketio import SocketIORoom from models_library.api_schemas_webserver.wallets import PaymentTransaction from models_library.users import GroupID, UserID from pydantic import parse_obj_as @@ -95,13 +96,12 @@ def socketio_server_events( mocker: MockerFixture, user_primary_group_id: GroupID, ) -> dict[str, AsyncMock]: - - user_room_name = f"{user_primary_group_id}" + room_name = SocketIORoom.from_group_id(user_primary_group_id) # handlers async def connect(sid: str, environ): print("connecting", sid) - await socketio_server.enter_room(sid, user_room_name) + await socketio_server.enter_room(sid, room_name) async def on_check(sid, data): print("check", sid, data) @@ -111,7 +111,7 @@ async def on_payment(sid, data): async def disconnect(sid: str): print("disconnecting", sid) - await socketio_server.leave_room(sid, user_room_name) + await socketio_server.leave_room(sid, room_name) # spies spy_connect = mocker.AsyncMock(wraps=connect) diff --git a/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py b/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py index a2b3280c347..d196f951148 100644 --- a/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py +++ b/services/web/server/src/simcore_service_webserver/notifications/_rabbitmq_exclusive_queue_consumers.py @@ -12,6 +12,7 @@ WalletCreditsMessage, ) from models_library.socketio import SocketMessageDict +from models_library.users import GroupID from pydantic import parse_raw_as from servicelib.logging_utils import log_catch, log_context from servicelib.rabbitmq import RabbitMQClient @@ -152,13 +153,13 @@ async def _osparc_credits_message_parser(app: web.Application, data: bytes) -> b wallet_groups = await wallets_api.list_wallet_groups_with_read_access_by_wallet( app, wallet_id=rabbit_message.wallet_id ) - rooms_to_notify = [f"{item.gid}" for item in wallet_groups] + rooms_to_notify: list[GroupID] = [item.gid for item in wallet_groups] for room in rooms_to_notify: await send_group_messages(app, room, socket_messages) return True -_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...,]] = ( +_EXCHANGE_TO_PARSER_CONFIG: Final[tuple[SubcribeArgumentsTuple, ...]] = ( SubcribeArgumentsTuple( LoggerRabbitMessage.get_channel_name(), _log_message_parser, diff --git a/services/web/server/src/simcore_service_webserver/projects/projects_api.py b/services/web/server/src/simcore_service_webserver/projects/projects_api.py index d0e5f9ee0d4..6f84ec83c4f 100644 --- a/services/web/server/src/simcore_service_webserver/projects/projects_api.py +++ b/services/web/server/src/simcore_service_webserver/projects/projects_api.py @@ -51,7 +51,7 @@ ServiceResourcesDictHelpers, ) from models_library.socketio import SocketMessageDict -from models_library.users import UserID +from models_library.users import GroupID, UserID from models_library.utils.fastapi_encoders import jsonable_encoder from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo from pydantic import ByteSize, parse_obj_as @@ -1458,10 +1458,8 @@ async def notify_project_state_update( if notify_only_user: await send_messages(app, user_id=f"{notify_only_user}", messages=messages) else: - rooms_to_notify = [ - f"{gid}" - for gid, rights in project["accessRights"].items() - if rights["read"] + rooms_to_notify: list[GroupID] = [ + gid for gid, rights in project["accessRights"].items() if rights["read"] ] for room in rooms_to_notify: await send_group_messages(app, room, messages) @@ -1476,8 +1474,8 @@ async def notify_project_node_update( if await is_project_hidden(app, ProjectID(project["uuid"])): return - rooms_to_notify = [ - f"{gid}" for gid, rights in project["accessRights"].items() if rights["read"] + rooms_to_notify: list[GroupID] = [ + gid for gid, rights in project["accessRights"].items() if rights["read"] ] messages: list[SocketMessageDict] = [ diff --git a/services/web/server/src/simcore_service_webserver/socketio/_handlers.py b/services/web/server/src/simcore_service_webserver/socketio/_handlers.py index ad1eacca7ec..cbf5536899a 100644 --- a/services/web/server/src/simcore_service_webserver/socketio/_handlers.py +++ b/services/web/server/src/simcore_service_webserver/socketio/_handlers.py @@ -8,6 +8,7 @@ from typing import Any from aiohttp import web +from models_library.api_schemas_webserver.socketio import SocketIORoom from models_library.socketio import SocketMessageDict from models_library.users import UserID from servicelib.aiohttp.observer import emit @@ -89,7 +90,10 @@ async def _set_user_in_group_rooms( sio = get_socket_server(app) for group in groups: - sio.enter_room(socket_id, f"{group['gid']}") + # NOTE socketio need to be upgraded that's why enter_room is not an awaitable + sio.enter_room(socket_id, SocketIORoom.from_group_id(group["gid"])) + + sio.enter_room(socket_id, SocketIORoom.from_user_id(user_id)) # diff --git a/services/web/server/src/simcore_service_webserver/socketio/messages.py b/services/web/server/src/simcore_service_webserver/socketio/messages.py index f76d4545b69..8412870814a 100644 --- a/services/web/server/src/simcore_service_webserver/socketio/messages.py +++ b/services/web/server/src/simcore_service_webserver/socketio/messages.py @@ -7,11 +7,11 @@ from typing import Final from aiohttp.web import Application +from models_library.api_schemas_webserver.socketio import SocketIORoom from models_library.socketio import SocketMessageDict -from models_library.users import UserID -from servicelib.aiohttp.application_keys import APP_FIRE_AND_FORGET_TASKS_KEY +from models_library.users import GroupID, UserID from servicelib.json_serialization import json_dumps -from servicelib.utils import fire_and_forget_task, logged_gather +from servicelib.utils import logged_gather from socketio import AsyncServer from ..resource_manager.user_sessions import managed_resource @@ -44,7 +44,11 @@ async def send_messages( await logged_gather( *( - sio.emit(message["event_type"], json_dumps(message["data"]), room=sid) + sio.emit( + message["event_type"], + json_dumps(message["data"]), + room=SocketIORoom.from_socket_id(sid), + ) for message in messages for sid in socket_ids ), @@ -54,32 +58,16 @@ async def send_messages( ) -async def post_messages( - app: Application, user_id: UserID, messages: Sequence[SocketMessageDict] -) -> None: - fire_and_forget_task( - send_messages(app, user_id, messages), - task_suffix_name=f"post_message_{user_id=}", - fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY], - ) - - -async def post_group_messages( - app: Application, room: str, messages: Sequence[SocketMessageDict] -) -> None: - fire_and_forget_task( - send_group_messages(app, room, messages), - task_suffix_name=f"post_group_messages_{room=}", - fire_and_forget_tasks_collection=app[APP_FIRE_AND_FORGET_TASKS_KEY], - ) - - async def send_group_messages( - app: Application, room: str, messages: Sequence[SocketMessageDict] + app: Application, group_id: GroupID, messages: Sequence[SocketMessageDict] ) -> None: sio: AsyncServer = get_socket_server(app) send_tasks = [ - sio.emit(message["event_type"], json_dumps(message["data"]), room=room) + sio.emit( + message["event_type"], + json_dumps(message["data"]), + room=SocketIORoom.from_group_id(group_id), + ) for message in messages ]