diff --git a/ror_server_bot/ror_bot/packet_handler.py b/ror_server_bot/ror_bot/packet_handler.py deleted file mode 100644 index 5b5aba8..0000000 --- a/ror_server_bot/ror_bot/packet_handler.py +++ /dev/null @@ -1,80 +0,0 @@ -import inspect - -from pyee.asyncio import AsyncIOEventEmitter - -from .enums import MessageType -from .models import Packet - - -class PacketHandler: - def __init__(self, event_emitter: AsyncIOEventEmitter) -> None: - """Create a new PacketHandler. This class is used to register - event handlers on an EventEmitter. The event handlers are - methods of this class that start with 'on_' or 'once_' and - end with the name of the packet type. - - Methods must have the following signature: - ``` - def on_hello(self, packet: Packet) -> None: - ... - def once_hello(self, packet: Packet) -> None: - ... - ``` - - The method `on_hello` will be called when a `MessageType.HELLO` - packet is received and the method `once_hello` will be called - once when a `MessageType.HELLO` packet is received. - - :param event_emitter: The EventEmitter to register the packet - handlers on. The event_emitter must have the wildcard option - enabled. See `pymitter.EventEmitter` for more information. - """ - for (name, method) in inspect.getmembers(self, inspect.ismethod): - for prefix in ('on_', 'once_'): - if name.startswith(prefix): - if name == prefix + 'packet': - event = 'packet' - else: - message_type = name[len(prefix):].upper() - if message_type not in MessageType._member_names_: - raise ValueError( - f'Invalid packet type: {message_type}' - ) - event = f'packet.{message_type}' - - parameters = inspect.signature(method).parameters - - if len(parameters) != 1: - raise ValueError( - f'Invalid signature for "{name}". ' - 'Expected 2 parameters: self, packet', - ) - - (param_name, parameter), *_ = parameters.items() - - if parameter.annotation != Packet: - raise ValueError( - f'Invalid signature for {name}. ' - f'Expected "{param_name}" to be of type Packet' - ) - - if prefix == 'on_': - event_emitter.on(event, method) - elif prefix == 'once_': - event_emitter.once(event, method) - - -if __name__ == '__main__': - class Test(PacketHandler): - def __init__(self, event_emitter: AsyncIOEventEmitter) -> None: - super().__init__(event_emitter) - - def on_hello(self, foo: Packet) -> None: - print('hello', foo) - - def other_method(self) -> None: - print('other_method') - - ee = AsyncIOEventEmitter() - test = Test(ee) - ee.emit('packet.*', Packet(command=MessageType.HELLO)) diff --git a/ror_server_bot/ror_bot/ror_client.py b/ror_server_bot/ror_bot/ror_client.py index 08fc4a0..6bac2d0 100644 --- a/ror_server_bot/ror_bot/ror_client.py +++ b/ror_server_bot/ror_bot/ror_client.py @@ -1,62 +1,16 @@ import asyncio import logging -import struct -import time -from enum import Enum from typing import Callable -from pyee.asyncio import AsyncIOEventEmitter - -from .enums import MessageType, StreamType -from .models import ( - Packet, - RoRClientConfig, - stream_data_factory, - stream_register_factory, - UserInfo, -) -from .packet_handler import PacketHandler -from .ror_connection import RoRConnection, UserNotFoundError -from .user import StreamNotFoundError +from .models import RoRClientConfig +from .ror_connection import RoRClientEvents, RoRConnection logger = logging.getLogger(__name__) -def check_packet_type(packet: Packet, message_type: MessageType) -> bool: - """Check if the packet is of the specified type. - - :param packet: The packet to check. - :param message_type: The type to check the packet against. - :return: True if the packet is of the specified type, False - """ - return packet.command is message_type - - -class RoRClientEvents(Enum): - FRAME_STEP = 'frame_step' - NET_QUALITY = 'net_quality' - CHAT = 'chat' - PRIVATE_CHAT = 'private_chat' - USER_JOIN = 'user_join' - USER_INFO = 'user_info' - USER_LEAVE = 'user_leave' - GAME_CMD = 'game_cmd' - STREAM_REGISTER = 'stream_register' - STREAM_REGISTER_RESULT = 'stream_register_result' - STREAM_DATA = 'stream_data' - STREAM_UNREGISTER = 'stream_unregister' - - -class RoRClient(PacketHandler): - STABLE_FPS = 20 - +class RoRClient: def __init__(self, client_config: RoRClientConfig) -> None: - """Create a new RoRClient. This class is used to connect to a - RoR server and handle packets received from the server. It - inherits from PacketHandler and registers event handlers on - an AsyncIOEventEmitter. - - See `ror_client.packet_handler.PacketHandler` for more information. + """Create a new RoRClient. :param client_config: The configuration to use for the client. """ @@ -70,15 +24,6 @@ def __init__(self, client_config: RoRClientConfig) -> None: port=self.config.server.port, ) - super().__init__(self.server) - - self._frame_step_task: asyncio.Task - - self.event_emitter = AsyncIOEventEmitter() - - self.event_emitter.add_listener('new_listener', self._new_listener) - self.event_emitter.add_listener('error', self._error) - async def __aenter__(self): for attempt in range(self.config.reconnection_tries): try: @@ -109,45 +54,18 @@ async def __aenter__(self): f'after {self.config.reconnection_tries} attempts', ) - self._frame_step_task = asyncio.create_task( - self._frame_step_loop(), - name='frame_step_loop' - ) - return self async def __aexit__(self, exc_type, exc, tb): await self.server.__aexit__(exc_type, exc, tb) - if self._frame_step_task: - self._frame_step_task.cancel() - - def _new_listener(self, event: str, listener: Callable): - """Handles new listener events. - - :param event: The event that was added. - :param listener: The listener that was added. - """ - logger.debug( - 'New listener added: event="%s" listener="%s"', - event, - listener.__name__ - ) - - def _error(self, error: Exception): - """Handles error events. - - :param error: The error that was emitted. - """ - logger.error('%r', error, exc_info=True) - def on(self, event: RoRClientEvents, listener: Callable | None = None): """Decorator to register an event handler on the event emitter. :param event: The event to register the handler on. :param listener: The listener to register. """ - return self.event_emitter.on(event.value, listener) + return self.server.on(event, listener) def once(self, event: RoRClientEvents, listener: Callable | None = None): """Decorator to register a one-time event handler on the event @@ -156,285 +74,4 @@ def once(self, event: RoRClientEvents, listener: Callable | None = None): :param event: The event to register the handler on. :param listener: The listener to register. """ - return self.event_emitter.once(event.value, listener) - - def emit(self, event: RoRClientEvents, *args, **kwargs): - """Emit an event on the event emitter. - - :param event: The event to emit. - :param args: The arguments to pass to the event handler. - :param kwargs: The keyword arguments to pass to the event handler. - """ - if event is not RoRClientEvents.FRAME_STEP: - # we do not need to log every frame_step event emit - logger.debug( - '[EMIT] event=%r listeners=%d', - event.value, - len(self.event_emitter.listeners(event.value)) - ) - self.event_emitter.emit(event.value, *args, **kwargs) - - async def _frame_step_loop(self): - """Send frame_step events at a stable rate.""" - start_time = time.time() - current_time = start_time - delta = 0 - while True: - prev_time = current_time - current_time = time.time() - delta += current_time - prev_time - - if delta >= (self.STABLE_FPS / 60): - self.emit(RoRClientEvents.FRAME_STEP, delta) - delta = 0 - - await asyncio.sleep(0.01) - - def on_packet(self, packet: Packet) -> None: - """Handle packets received from the server. - - :param packet: The packet to handle. - """ - if ( - packet.command in (MessageType.HELLO, MessageType.WELCOME) - or f'packet.{packet.command.name}' in self.server.event_names() - ): - return - - logger.warning('Unhandled packet command: %r', packet.command) - - def on_net_quality(self, packet: Packet) -> None: - """Handle net_quality packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.NET_QUALITY): - return - - net_quality: int = struct.unpack('I', packet.data)[0] - - if self.server.net_quality != net_quality: - logger.info( - 'Net quality for uid=%d changed: %d -> %d', - packet.source, - self.server.net_quality, - net_quality - ) - self.emit(RoRClientEvents.NET_QUALITY, packet.source, net_quality) - else: - logger.debug( - 'Net quality unchanged for uid=%d: %d', - packet.source, - net_quality - ) - - self.server.net_quality = net_quality - - def on_chat(self, packet: Packet) -> None: - """Handle chat packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.CHAT): - return - - message = packet.data.decode().strip('\x00') - - logger.info('[CHAT] from_uid=%d message=%r', packet.source, message) - - if message and packet.source != self.server.unique_id: - self.emit(RoRClientEvents.CHAT, packet.source, message) - - def on_private_chat(self, packet: Packet) -> None: - """Handle private_chat packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.PRIVATE_CHAT): - return - - message = packet.data.decode().strip('\x00') - - logger.info('[PRIV] from_uid=%d message=%r', packet.source, message) - - if message and packet.source != self.server.unique_id: - self.emit(RoRClientEvents.PRIVATE_CHAT, packet.source, message) - - def on_user_join(self, packet: Packet) -> None: - """Handle user_join packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.USER_JOIN): - return - - logger.info('User join received') - - user_info = UserInfo.from_bytes(packet.data) - - self.server.add_user(user_info) - - if user_info.unique_id != self.server.unique_id: - self.emit(RoRClientEvents.USER_JOIN, packet.source, user_info) - - def on_user_info(self, packet: Packet) -> None: - """Handle user_info packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.USER_INFO): - return - - logger.info('User info received') - - user_info = UserInfo.from_bytes(packet.data) - - self.server.add_user(user_info) - - if user_info.unique_id != self.server.unique_id: - self.emit(RoRClientEvents.USER_INFO, packet.source, user_info) - - def on_user_leave(self, packet: Packet) -> None: - """Handle user_leave packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.USER_LEAVE): - return - - logger.info( - 'uid=%r left with reason: %s', - packet.source, - packet.data - ) - - if packet.source == self.server.unique_id: - raise ConnectionError('RoRClient disconnected from the server') - - self.server.delete_user(packet.source) - - self.emit(RoRClientEvents.USER_LEAVE, packet.source) - - def on_game_cmd(self, packet: Packet) -> None: - """Handle game_cmd packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.GAME_CMD): - return - - if packet.source == self.server.unique_id: - return - - game_cmd = packet.data.decode().strip('\x00') - - logger.debug('[GAME_CMD] from_uid=%d cmd=%r', packet.source, game_cmd) - - if game_cmd and packet.source != self.server.unique_id: - self.emit(RoRClientEvents.GAME_CMD, packet.source, game_cmd) - - async def on_stream_register(self, packet: Packet) -> None: - """Handle stream_register packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.STREAM_REGISTER): - return - - stream = stream_register_factory(packet.data) - - logger.info('Stream register received: %s', stream) - - self.server.add_stream(stream) - - if stream.type is StreamType.ACTOR: - # why? - await self.server.reply_to_stream_register(stream, status=-1) - - self.emit(RoRClientEvents.STREAM_REGISTER, packet.source, stream) - - def on_stream_register_result(self, packet: Packet) -> None: - """Handle stream_register_result packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.STREAM_REGISTER_RESULT): - return - - stream = stream_register_factory(packet.data) - - logger.info('Stream register result received: %s', stream) - - self.emit( - RoRClientEvents.STREAM_REGISTER_RESULT, - packet.source, - stream, - ) - - def on_stream_data(self, packet: Packet) -> None: - """Handle stream_data packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.STREAM_DATA): - return - - try: - stream = self.server.get_stream( - packet.source, - packet.stream_id - ) - except UserNotFoundError: - logger.error( - 'Could not find user! uid=%d', - packet.source, - exc_info=True - ) - except StreamNotFoundError: - logger.error( - 'Could not find stream! sid=%d', - packet.stream_id, - exc_info=True - ) - else: - if stream.type in (StreamType.CHARACTER, StreamType.ACTOR): - logger.info('%s stream received', stream.type.name.title()) - stream_data = stream_data_factory(stream.type, packet.data) - - logger.debug('Stream data: %s', stream_data) - elif stream.type is StreamType.CHAT: - logger.info('Chat stream received') - stream_data = None - else: - raise ValueError(f'Unknown stream type: {stream.type!r}') - - self.emit( - RoRClientEvents.STREAM_DATA, - packet.source, - stream, - stream_data - ) - - def on_stream_unregister(self, packet: Packet) -> None: - """Handle stream_unregister packets. - - :param packet: The packet to handle. - """ - if not check_packet_type(packet, MessageType.STREAM_UNREGISTER): - return - - if len(packet.data) != 0: - raise ValueError('Stream unregister packet has data') - - logger.info('Stream unregister received: %s', packet) - - self.server.delete_stream( - packet.source, - packet.stream_id - ) - - self.emit( - RoRClientEvents.STREAM_UNREGISTER, - packet.source, - packet.stream_id - ) + return self.server.once(event, listener) diff --git a/ror_server_bot/ror_bot/ror_connection.py b/ror_server_bot/ror_bot/ror_connection.py index cd3f4da..5e4bf86 100644 --- a/ror_server_bot/ror_bot/ror_connection.py +++ b/ror_server_bot/ror_bot/ror_connection.py @@ -1,18 +1,22 @@ import asyncio +import contextlib import hashlib import logging import math import struct import time from datetime import datetime +from enum import Enum +from functools import singledispatchmethod from itertools import chain from typing import Callable -from pyee.asyncio import AsyncIOEventEmitter +from pyee import AsyncIOEventEmitter from ror_server_bot import pformat, RORNET_VERSION from .enums import ( + ActorStreamStatus, AuthLevels, CharacterAnimation, CharacterCommand, @@ -21,36 +25,83 @@ ) from .models import ( ActorStreamRegister, + BannedPacket, CharacterPositionStreamData, CharacterStreamRegister, + ChatPacket, ChatStreamRegister, + GameCmdPacket, GlobalStats, Packet, + packet_factory, + HelloPacket, + NetQualityPacket, + PrivateChatPacket, + ServerFullPacket, ServerInfo, + stream_data_factory, + stream_register_factory, + StreamData, + StreamDataPacket, StreamRegister, + StreamRegisterPacket, + StreamRegisterResultPacket, + StreamUnregisterPacket, UserInfo, + UserInfoPacket, + UserJoinPacket, + UserLeavePacket, Vector3, Vector4, + WelcomePacket, + WrongPasswordPacket, + WrongVersionPacket, ) -from .user import User +from .user import StreamNotFoundError, User logger = logging.getLogger(__name__) -class PacketError(Exception): - """An error that occurs when a packet is malformed.""" - - -class UnexpectedCommandError(Exception): - """An error that occurs when a packet with an unexpected command is +class UnexpectedMessageError(Exception): + """An error that occurs when a header with an unexpected message is received.""" + def __init__(self, *args: object) -> None: + super().__init__(*args) + class UserNotFoundError(Exception): """Raised when a user is not found.""" + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class UserAlreadyExistsError(Exception): + """Raised when a user already exists.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class RoRClientEvents(Enum): + FRAME_STEP = 'frame_step' + NET_QUALITY = 'net_quality' + CHAT = 'chat' + PRIVATE_CHAT = 'private_chat' + USER_JOIN = 'user_join' + USER_INFO = 'user_info' + USER_LEAVE = 'user_leave' + GAME_CMD = 'game_cmd' + STREAM_REGISTER = 'stream_register' + STREAM_REGISTER_RESULT = 'stream_register_result' + STREAM_DATA = 'stream_data' + STREAM_UNREGISTER = 'stream_unregister' + + +class RoRConnection: + STABLE_FPS = 20 -class RoRConnection(AsyncIOEventEmitter): def __init__( self, username: str, @@ -75,37 +126,39 @@ def __init__( :param host: The IP address of the server. :param port: The port the server is running on. :param heartbeat_interval: The interval, in seconds, to send - heartbeat packets to the server, defaults to 1.0. + heartbeat packets to the server, defaults to 10.0. """ - super().__init__() - - self.add_listener('new_listener', self._new_listener) - self.add_listener('error', self._error) + self._connect_time: datetime self._reader: asyncio.StreamReader self._writer: asyncio.StreamWriter self._writer_lock: asyncio.Lock - self._reader_task: asyncio.Task self._heartbeat_task: asyncio.Task + self._frame_step_task: asyncio.Task self._task_group = asyncio.TaskGroup() - self._connect_time: datetime - - self.stream_id = 10 # stream ids under 10 are reserved - self.net_quality = 0 - self._host = host self._port = port self._password = hashlib.sha1(password.encode()).hexdigest().upper() - self._heartbeat_interval = heartbeat_interval + self._net_quality = 0 + self._stream_id = 10 # stream ids under 10 are reserved self._is_connected = False + self._heartbeat_interval = heartbeat_interval + + self._users: dict[int, User] = {} + self._global_stats = GlobalStats() + + self._event_emitter = AsyncIOEventEmitter() + self._event_emitter.add_listener('new_listener', self._new_listener) + self._event_emitter.add_listener('error', self._error) - self.server_info: ServerInfo - self.user_info = UserInfo( + self._server_info: ServerInfo + self._user_info = UserInfo( auth_status=AuthLevels.BOT, + slot_num=-2, username=username, user_token=user_token, server_password=self._password, @@ -116,8 +169,6 @@ def __init__( session_type='bot', session_options='', ) - self.users: dict[int, User] = {} - self.global_stats = GlobalStats() @property def is_connected(self) -> bool: @@ -137,23 +188,23 @@ def address(self) -> str: @property def unique_id(self) -> int: """Gets the unique id of the client.""" - return self.user_info.unique_id + return self._user_info.unique_id @property def user_count(self) -> int: """Gets the number of users.""" - return len(self.users) - 1 # subtract 1 for the server client + return len(self._users) - 1 # subtract 1 for the server client @property def user_ids(self) -> list[int]: """Gets the ids of the users.""" - return list(self.users.keys()) + return list(self._users.keys()) @property def stream_ids(self) -> list[int]: """Gets the ids of every stream for every user.""" return list(chain.from_iterable( - user.stream_ids for user in self.users.values() + user.stream_ids for user in self._users.values() )) async def __aenter__(self) -> 'RoRConnection': @@ -176,19 +227,12 @@ async def __aenter__(self) -> 'RoRConnection': self._reader_task = self._task_group.create_task( self.__reader_loop(), - name='reader_loop' + name=self.__reader_loop.__name__ ) - hello_packet = await self.__send_hello() - self.server_info = ServerInfo.from_bytes(hello_packet.data) + await self.__send_hello() - logger.info('Received Server Info: %s', self.server_info) - - welcome_packet = await self.__send_welcome() - self.user_info = UserInfo.from_bytes(welcome_packet.data) - - logger.info('Received User Info: %s', self.user_info) - self.add_user(self.user_info) + await self.__send_welcome() await self.__register_streams() @@ -200,7 +244,13 @@ async def __aenter__(self) -> 'RoRConnection': self._heartbeat_task = self._task_group.create_task( self.__heartbeat_loop(), - name='heartbeat_loop' + name=self.__heartbeat_loop.__name__ + ) + + logger.info('Starting frame step loop') + self._frame_step_task = self._task_group.create_task( + self.__frame_step_loop(), + name=self.__frame_step_loop.__name__ ) return self @@ -214,12 +264,14 @@ async def __aexit__(self, exc_type, exc, tb): """ logger.info('Disconnecting from %s', self.address) - await self._send(Packet( - command=MessageType.USER_LEAVE, - source=self.unique_id, - stream_id=self.stream_id, - size=0, - )) + await self._send( + UserLeavePacket( + type=MessageType.USER_LEAVE, + source=self.unique_id, + stream_id=self._stream_id, + size=0, + ) + ) await self._task_group.__aexit__(exc_type, exc, tb) @@ -229,6 +281,9 @@ async def __aexit__(self, exc_type, exc, tb): if self._heartbeat_task is not None: self._heartbeat_task.cancel() + if self._frame_step_task is not None: + self._frame_step_task.cancel() + async with self._writer_lock: self._reader.feed_eof() self._writer.close() @@ -236,71 +291,49 @@ async def __aexit__(self, exc_type, exc, tb): self._is_connected = False - async def __send_hello(self) -> Packet: + async def __send_hello(self): logger.info('Sending Hello Message') - hello_packet = Packet( - command=MessageType.HELLO, - source=0, # we do not have a unique id yet - stream_id=self.stream_id, - size=len(RORNET_VERSION), - data=RORNET_VERSION - ) - - future: asyncio.Future[Packet] = asyncio.Future() - - @self.once('packet') - async def hello(packet: Packet): - if packet.command is not MessageType.HELLO: - raise UnexpectedCommandError('Did not recieve hello response') - future.set_result(packet) + self._server_info = None - await self._send(hello_packet) - - return await future - - async def __send_welcome(self) -> Packet: - logger.info('Sending User Info: %s', self.user_info) - - data = self.user_info.pack() - welcome_packet = Packet( - command=MessageType.USER_INFO, - source=self.unique_id, - stream_id=self.stream_id, - size=len(data), - data=data + await self._send( + HelloPacket( + type=MessageType.HELLO, + source=0, # we do not have a unique id yet + stream_id=self._stream_id, + size=len(RORNET_VERSION), + payload=RORNET_VERSION.encode() + ) ) - future: asyncio.Future[Packet] = asyncio.Future() + while True: + if self._server_info is not None: + break + await asyncio.sleep(0.1) - @self.once('packet') - async def welcome(packet: Packet): - if packet.command is MessageType.WELCOME: - future.set_result(packet) - elif packet.command is MessageType.SERVER_FULL: - raise ConnectionError('Server is full') - elif packet.command is MessageType.BANNED: - raise ConnectionError('RoR Client is banned') - elif packet.command is MessageType.WRONG_PASSWORD: - raise ConnectionError('Wrong password') - elif packet.command is MessageType.WRONG_VERSION: - raise ConnectionError('Wrong version') - else: - raise UnexpectedCommandError( - 'Invalid response: %r', - packet.command - ) + async def __send_welcome(self): + logger.info('Sending User Info: %s', self._user_info) - await self._send(welcome_packet) + payload = self._user_info.pack() + await self._send(UserInfoPacket( + type=MessageType.USER_INFO, + source=self.unique_id, + stream_id=self._stream_id, + size=len(payload), + payload=payload + )) - return await future + while True: + if self._user_info.color_num != -1: + break + await asyncio.sleep(0.1) async def __register_streams(self): chat_stream_reg = ChatStreamRegister( type=StreamType.CHAT, status=0, origin_source_id=self.unique_id, - origin_stream_id=self.stream_id, + origin_stream_id=self._stream_id, name='chat', reg_data='0', ) @@ -312,7 +345,7 @@ async def __register_streams(self): type=StreamType.CHARACTER, status=0, origin_source_id=self.unique_id, - origin_stream_id=self.stream_id, + origin_stream_id=self._stream_id, name='default', reg_data=b'\x02', ) @@ -322,57 +355,37 @@ async def __register_streams(self): await self.register_stream(char_stream_reg) async def __reader_loop(self): - """The main reader loop. Reads packets from the server and emits - events. + """The main reader loop. Handles packets sent by the server. This function should not be called directly. - - This function will emit the following events when a packet is - received: - - `packet.*`: Emits for every packet received. - - `packet.`: Emits an event with the name of the - command from the packet received. For example, if a packet with - the command `MessageType.CHAT` is received, the event - `packet.CHAT` will be emitted. """ + header_format = 'IIII' + header_size = struct.calcsize(header_format) while True: - header = await self._reader.readexactly(Packet.calc_size()) + header = await self._reader.readexactly(header_size) logger.debug('[HEAD] %s', header) - try: - packet = Packet.from_bytes(header) - except struct.error as e: - raise PacketError( - f'Failed to read packet header: {header}' - ) from e + packet = packet_factory(*struct.unpack(header_format, header)) if ( - packet.command is not MessageType.STREAM_UNREGISTER + packet.type is not MessageType.STREAM_UNREGISTER and packet.size == 0 ): - raise PacketError(f'No data to read: {packet}') + raise ValueError(f'No data to read: {packet}') payload = await self._reader.read(packet.size) - if len(payload) != packet.size: - logger.warning( - 'Packet size mismatch: data=%s packet=%s', - payload, - packet - ) - logger.debug('[RECV] %s', payload) - packet.data = payload - - logger.debug('[PCKT] %s', packet) + if len(payload) != packet.size: + raise ValueError( + f'Packet size mismatch: data={payload} packet={packet}' + ) - # emit to packet wildcard - self.emit('packet', packet) + packet.payload = payload - # command event - self.emit('packet.' + packet.command.name, packet) + await self._parse_packet(packet) await asyncio.sleep(0.01) @@ -396,10 +409,11 @@ async def __heartbeat_loop(self): animation_mode=CharacterAnimation.IDLE_SWAY, ) - packet = Packet( - command=MessageType.STREAM_DATA, + header = StreamDataPacket( + type=MessageType.STREAM_DATA, source=self.unique_id, - stream_id=self.get_character_sid(self.unique_id) + stream_id=self.get_character_sid(self.unique_id), + size=0 ) logger.info( @@ -420,18 +434,302 @@ async def __heartbeat_loop(self): stream.animation_time = delta delta = 0 + # avoid spamming logs if self._heartbeat_interval >= 1: - # avoid spamming logs - logger.info('Sending heartbeat character stream data.') + logger.info('Sending heartbeat') - data = stream.pack() - packet.data = data - packet.size = len(data) + payload = stream.pack() + header.size = len(payload) + header.payload = payload - await self._send(packet) + await self._send(header) await asyncio.sleep(0.1) + async def __frame_step_loop(self): + """Send frame_step events at a stable rate.""" + start_time = time.time() + curr_time = start_time + delta = 0 + while True: + prev_time = curr_time + curr_time = time.time() + delta += curr_time - prev_time + + if delta >= (self.STABLE_FPS / 60): + self._emit(RoRClientEvents.FRAME_STEP, delta) + delta = 0 + + await asyncio.sleep(0.01) + + async def _send(self, packet: Packet): + """Sends a message to the server. + + :param header: The packet of the message. + """ + async with self._writer_lock: + if packet.size != len(packet.payload): + raise ValueError( + f'Packet size mismatch: data={packet.payload} ' + f'packet={packet}' + ) + + data = packet.pack() + + logger.debug('[SEND] %s', data) + + self._writer.write(data) + await self._writer.drain() + + @singledispatchmethod + async def _parse_packet(self, packet: Packet): + """Parses a packet from the server. + + :param packet: The packet to parse. + """ + raise NotImplementedError(f'No parse method for packet {packet.type}') + + @_parse_packet.register + async def _(self, packet: HelloPacket): + self._server_info = ServerInfo.from_bytes(packet.payload) + logger.info('Received Server Info: %s', self._server_info) + + @_parse_packet.register + async def _( + self, + packet: ( + WelcomePacket + | ServerFullPacket + | WrongPasswordPacket + | WrongVersionPacket + | BannedPacket + ), + ): + match packet.type: + case MessageType.WELCOME: + self._user_info = UserInfo.from_bytes(packet.payload) + logger.info('Received User Info: %s', self._user_info) + self.add_user(self._user_info) + case MessageType.SERVER_FULL: + raise ConnectionError('Server is full') + case MessageType.WRONG_PASSWORD: + raise ConnectionError('Wrong password') + case MessageType.WRONG_VERSION: + raise ConnectionError('Wrong version') + case MessageType.BANNED: + raise ConnectionError('RoR Client is banned') + case _: + raise UnexpectedMessageError( + f'Unexpected message: {packet.type}' + ) + + @_parse_packet.register + async def _(self, packet: NetQualityPacket): + prev_nq = self._net_quality + + curr_nq, *_ = struct.unpack('I', packet.payload) + + if not isinstance(curr_nq, int): + raise TypeError( + 'Expected net_quality to be an int, got ' + f'{type(curr_nq)}' + ) + + net_quality_changed = prev_nq != curr_nq + + logger.debug( + '[NETQ] uid=%d net_quality=(%d -> %d) changed=%s', + packet.source, + prev_nq, + curr_nq, + net_quality_changed + ) + + self._net_quality = curr_nq + + if net_quality_changed: + self._emit(RoRClientEvents.NET_QUALITY, curr_nq) + + @_parse_packet.register + async def _(self, packet: UserJoinPacket): + if packet.source == self.unique_id: + return + + user_info = UserInfo.from_bytes(packet.payload) + + logger.info( + 'User %r with uid %d joined the server', + user_info.client_name, + packet.source + ) + + self.add_user(user_info) + + self._emit(RoRClientEvents.USER_JOIN, packet.source, user_info) + + @_parse_packet.register + async def _(self, packet: UserInfoPacket): + user_info = UserInfo.from_bytes(packet.payload) + + self.update_user(user_info) + + logger.info( + 'Recieved user info from user %r uid=%d', + self.get_username(packet.source), + packet.source + ) + + self._emit(RoRClientEvents.USER_INFO, packet.source, user_info) + + @_parse_packet.register + async def _(self, packet: UserLeavePacket): + user = self.get_user(packet.source) + + logger.info( + 'User %r with uid %d left with reason: %r', + user.client_name, + packet.source, + packet.payload.decode() + ) + + if packet.source == self.unique_id: + raise ConnectionError('Disconnected from the server!') + + self.delete_user(packet.source) + + self._emit(RoRClientEvents.USER_LEAVE, packet.source, user) + + @_parse_packet.register + async def _(self, packet: ChatPacket | PrivateChatPacket): + message = packet.payload.decode().strip('\x00') + + logger.info( + '[%s] from_uid=%d message=%r', + 'CHAT' if isinstance(packet, ChatPacket) else 'PRIV', + packet.source, + message + ) + + if message and packet.source != self.unique_id: + event = ( + RoRClientEvents.CHAT + if isinstance(packet, ChatPacket) + else RoRClientEvents.PRIVATE_CHAT + ) + self._emit(event, packet.source, message) + + @_parse_packet.register + async def _(self, packet: GameCmdPacket): + if packet.source == self.unique_id: + return + + game_cmd = packet.payload.decode().strip('\x00') + + logger.debug( + '[GCMD] [RECV] from_uid=%d cmd=%r', + packet.source, + game_cmd + ) + + if game_cmd: + self._emit(RoRClientEvents.GAME_CMD, packet.source, game_cmd) + + @_parse_packet.register + async def _(self, packet: StreamRegisterPacket): + stream = stream_register_factory(packet.payload) + + self.add_stream(stream) + + logger.info( + 'User %r with uid=%d registered a new %s stream with sid=%d', + self.get_username(packet.source), + packet.source, + stream.type.name.lower(), + stream.origin_stream_id + ) + + if stream.type is StreamType.ACTOR: + await self.reply_to_actor_stream_register( + stream, + status=ActorStreamStatus.SUCCESS + ) + + self._emit(RoRClientEvents.STREAM_REGISTER, packet.source, stream) + + @_parse_packet.register + async def _(self, packet: StreamRegisterResultPacket): + stream = stream_register_factory(packet.payload) + + logger.info( + 'User %r with uid=%d has registered a %s stream with sid=%d', + self.get_username(packet.source), + packet.source, + stream.type.name.lower(), + stream.origin_stream_id + ) + + self._emit( + RoRClientEvents.STREAM_REGISTER_RESULT, + packet.source, + stream + ) + + @_parse_packet.register + async def _(self, packet: StreamDataPacket): + if packet.source == self.unique_id: + return + + with contextlib.suppress(UserNotFoundError, StreamNotFoundError): + stream = self.get_stream(packet.source, packet.stream_id) + + logger.info( + 'User %r with uid=%d sent data for %s stream with sid=%d', + self.get_username(packet.source), + packet.source, + stream.type.name.lower(), + stream.origin_stream_id + ) + + stream_data: StreamData | None + match stream.type: + case StreamType.CHARACTER | StreamType.ACTOR: + stream_data = stream_data_factory( + stream.type, + packet.payload + ) + logger.debug('[STREAM] stream_data=%s', stream_data) + case StreamType.CHAT: + stream_data = None + case _: + raise ValueError(f'Unknown stream type: {stream.type!r}') + + self._emit( + RoRClientEvents.STREAM_DATA, + packet.source, + stream, + stream_data + ) + + @_parse_packet.register + async def _(self, packet: StreamUnregisterPacket): + if len(packet.payload) != 0: + raise ValueError('Stream unregister packet has data') + + logger.info( + 'User %r with uid=%d unregistered a stream with sid=%d', + self.get_username(packet.source), + packet.source, + packet.stream_id + ) + + self.delete_stream(packet.source, packet.stream_id) + + self._emit( + RoRClientEvents.STREAM_UNREGISTER, + packet.source, + packet.stream_id + ) + def _new_listener(self, event: str, listener: Callable): """Handles new listener events. @@ -439,7 +737,7 @@ def _new_listener(self, event: str, listener: Callable): :param listener: The listener that was added. """ logger.debug( - 'New listener added: event="%s" listener="%s"', + '[EVENT] event=%r new_listener=%r', event, listener.__name__ ) @@ -449,21 +747,197 @@ def _error(self, error: Exception): :param error: The error that was emitted. """ - logger.error('Error: %r', error, exc_info=True, stacklevel=2) + logger.error('[EVENT] error=%r', error, exc_info=True, stacklevel=2) - async def _send(self, packet: Packet): - """Sends a packet to the server. + def _emit(self, event: RoRClientEvents, *args, **kwargs): + """Emit an event on the event emitter. - :param packet: The packet to send. + :param event: The event to emit. + :param args: The arguments to pass to the event handler. + :param kwargs: The keyword arguments to pass to the event handler. """ - async with self._writer_lock: - data = packet.pack() + if event is not RoRClientEvents.FRAME_STEP: + # we do not need to log every frame_step event emit + logger.debug( + '[EMIT] event=%r listeners=%d', + event.value, + len(self._event_emitter.listeners(event.value)) + ) + self._event_emitter.emit(event.value, *args, **kwargs) - logger.debug('[SEND] %s', data) + def on(self, event: RoRClientEvents, listener: Callable | None = None): + """Decorator to register an event handler on the event emitter. - self._writer.write(data) + :param event: The event to register the handler on. + :param listener: The listener to register. + """ + return self._event_emitter.on(event.value, listener) - await self._writer.drain() + def once(self, event: RoRClientEvents, listener: Callable | None = None): + """Decorator to register a one-time event handler on the event + emitter. + + :param event: The event to register the handler on. + :param listener: The listener to register. + """ + return self._event_emitter.once(event.value, listener) + + async def register_stream(self, stream: StreamRegister) -> int: + """Registers a stream with the server as the client. + + :param stream: The stream being registered. + :return: The stream id of the stream. + """ + stream.origin_source_id = self.unique_id + stream.origin_stream_id = self._stream_id + + if isinstance(stream, ActorStreamRegister): + stream.timestamp = -1 + + payload = stream.pack() + await self._send( + StreamRegisterPacket( + type=MessageType.STREAM_REGISTER, + source=stream.origin_source_id, + stream_id=stream.origin_stream_id, + size=len(payload), + payload=payload + ) + ) + + self.add_stream(stream) + self._stream_id += 1 + + return stream.origin_stream_id + + async def unregister_stream(self, stream_id: int): + """Unregisters a stream with the server as the client. + + :param stream_id: The stream id of the stream to unregister. + """ + await self._send(StreamUnregisterPacket( + type=MessageType.STREAM_UNREGISTER, + source=self.unique_id, + stream_id=stream_id, + size=0, + )) + self.delete_stream(self.unique_id, stream_id) + + async def reply_to_actor_stream_register( + self, + stream: ActorStreamRegister, + status: ActorStreamStatus + ): + """Replies to an actor stream register request. This will + determine what upstream arrow will be displayed on the client. + + :param stream: The stream to reply to. + :param status: The status to reply with. + """ + stream.status = status + payload = stream.pack() + await self._send(StreamRegisterResultPacket( + type=MessageType.STREAM_REGISTER_RESULT, + source=self.unique_id, + stream_id=stream.origin_stream_id, + size=len(payload), + payload=payload + )) + + async def send_chat(self, message: str): + """Sends a message to the game chat. + + :param message: The message to send. + """ + logger.info('[CHAT] message=%r', message) + + payload = message.encode() + await self._send(ChatPacket( + type=MessageType.CHAT, + source=self.unique_id, + stream_id=self.get_chat_sid(self.unique_id), + size=len(payload), + payload=payload + )) + + async def send_private_chat(self, uid: int, message: str): + """Sends a private message to a user. + + :param uid: The uid of the user to send the message to. + :param msg: The message to send. + """ + logger.info('[PRIV] to_uid=%d message=%r', uid, message) + + payload = struct.pack('I8000s', uid, message.encode()) + await self._send(PrivateChatPacket( + type=MessageType.PRIVATE_CHAT, + source=self.unique_id, + stream_id=self.get_chat_sid(self.unique_id), + size=len(payload), + payload=payload + )) + + async def send_multiline_chat(self, message: str): + """Sends a multiline message to the game chat. + + :param message: The message to send. + """ + max_line_len = 100 + if len(message) > max_line_len: + logger.debug('[CHAT] multiline_message=%r', message) + + total_lines = math.ceil(len(message) / max_line_len) + for i in range(total_lines): + line = message[max_line_len*i:max_line_len*(i+1)] + if i > 0: + line = f'| {line}' + await self.send_chat(line) + else: + await self.send_chat(message) + + async def kick(self, uid: int, reason: str = 'No reason given'): + """Kicks a user from the server. + + :param uid: The uid of the user to kick. + :param reason: The reason for kicking the user, defaults to + 'No reason given' + """ + await self.send_chat(f'!kick {uid} {reason}') + + async def ban(self, uid: int, reason: str = 'No reason given'): + """Bans a user from the server. + + :param uid: The uid of the user to ban. + :param reason: The reason for banning the user, defaults to + 'No reason given' + """ + await self.send_chat(f'!ban {uid} {reason}') + + async def say(self, uid: int, message: str): + """Send a message as a user anonymously. + + :param uid: The uid of the user to send the message to. If -1, + the message will be sent to everyone. + :param message: The message to send. + """ + await self.send_chat(f'!say {uid} {message}') + + async def send_game_cmd(self, command: str): + """Sends a game command (Angelscript) to the server. + + :param command: The command to send. + """ + logger.debug('[GCMD] [SEND] game_cmd=%r', command) + message = command.encode() + await self._send( + GameCmdPacket( + type=MessageType.GAME_CMD, + source=self.unique_id, + stream_id=0, + size=len(message), + ), + message + ) def get_uid_by_username(self, username: str) -> int | None: """Gets the uid of the user by their username. @@ -471,7 +945,7 @@ def get_uid_by_username(self, username: str) -> int | None: :param username: The username of the user. :return: The uid of the user. """ - for uid, user in self.users.items(): + for uid, user in self._users.items(): if user.username == username: return uid return None @@ -483,29 +957,51 @@ def get_user(self, uid: int) -> User: :return: The user. """ try: - return self.users[uid] + return self._users[uid] except KeyError as e: - raise UserNotFoundError(uid, pformat(self.users)) from e + logger.debug( + '[USER] uid=%d not found in users %s', + uid, + pformat(self._users) + ) + raise UserNotFoundError(f'User uid={uid} not found') from e def add_user(self, user_info: UserInfo): """Adds a client to the stream manager. :param user_info: The user info of the client to add. """ - # update global stats if this is a new user - if user_info.unique_id not in self.users: - self.global_stats.add_user(user_info.username) + if user_info.unique_id in self._users: + raise UserAlreadyExistsError( + f'User uid={user_info.unique_id} already exists' + ) - # set the user to a new user if not already set - self.users.setdefault(user_info.unique_id, User(info=user_info)) + self._users[user_info.unique_id] = User(info=user_info) - # update the user info for the user - self.users[user_info.unique_id].info = user_info + self._global_stats.add_user(user_info.username) - logger.info( - 'Added user %r uid=%d', + logger.debug( + '[USER] Added username=%r uid=%d %s', user_info.username, - user_info.unique_id + user_info.unique_id, + user_info + ) + + def update_user(self, user_info: UserInfo): + """Updates a client in the stream manager. + + :param user_info: The user info of the client to update. + """ + if user_info.unique_id not in self._users: + self.add_user(user_info) + else: + self._users[user_info.unique_id].info = user_info + + logger.debug( + '[USER] Updated username=%r uid=%d %s', + user_info.username, + user_info.unique_id, + user_info ) def delete_user(self, uid: int): @@ -513,18 +1009,23 @@ def delete_user(self, uid: int): :param uid: The uid of the client to delete. """ - user = self.users.pop(uid) + user = self._users.pop(uid) - self.global_stats.meters_driven += user.stats.meters_driven - self.global_stats.meters_sailed += user.stats.meters_sailed - self.global_stats.meters_walked += user.stats.meters_walked - self.global_stats.meters_flown += user.stats.meters_flown + self._global_stats.meters_driven += user.stats.meters_driven + self._global_stats.meters_sailed += user.stats.meters_sailed + self._global_stats.meters_walked += user.stats.meters_walked + self._global_stats.meters_flown += user.stats.meters_flown - self.global_stats.connection_times.append( + self._global_stats.connection_times.append( datetime.now() - user.stats.online_since ) - logger.debug('Deleted user %r uid=%d', user.username, uid) + logger.debug( + '[USER] Deleted username=%r uid=%d %s', + user.username, + uid, + user + ) def add_stream(self, stream: StreamRegister): """Adds a stream to the stream manager. @@ -712,157 +1213,3 @@ def get_auth_status(self, uid: int) -> AuthLevels: :return: The authentication status of the user. """ return self.get_user(uid).auth_status - - async def register_stream(self, stream: StreamRegister) -> int: - """Registers a stream with the server as the client. - - :param stream: The stream being registered. - :return: The stream id of the stream. - """ - stream.origin_source_id = self.unique_id - stream.origin_stream_id = self.stream_id - - if isinstance(stream, ActorStreamRegister): - stream.timestamp = -1 - - stream_data = stream.pack() - packet = Packet( - command=MessageType.STREAM_REGISTER, - source=stream.origin_source_id, - stream_id=stream.origin_stream_id, - size=len(stream_data), - data=stream_data - ) - await self._send(packet) - self.add_stream(stream) - self.stream_id += 1 - - return stream.origin_stream_id - - async def unregister_stream(self, stream_id: int): - """Unregisters a stream with the server as the client. - - :param stream_id: The stream id of the stream to unregister. - """ - packet = Packet( - command=MessageType.STREAM_UNREGISTER, - source=self.unique_id, - stream_id=stream_id, - ) - await self._send(packet) - self.delete_stream(self.unique_id, stream_id) - - async def reply_to_stream_register( - self, - stream: StreamRegister, - status: int - ): - """Replies to a stream register request. - - :param stream: The stream to reply to. - :param status: The status to reply with. - """ - stream.status = status - data = stream.pack() - packet = Packet( - command=MessageType.STREAM_REGISTER_RESULT, - source=self.unique_id, - stream_id=stream.origin_stream_id, - size=len(data), - data=data - ) - await self._send(packet) - - async def send_chat(self, message: str): - """Sends a message to the game chat. - - :param message: The message to send. - """ - logger.info('[CHAT] message="%s"', message) - - data = message.encode() - - await self._send(Packet( - command=MessageType.CHAT, - source=self.unique_id, - stream_id=self.get_chat_sid(self.unique_id), - size=len(data), - data=data - )) - - async def send_private_chat(self, uid: int, message: str): - """Sends a private message to a user. - - :param uid: The uid of the user to send the message to. - :param message: The message to send. - """ - logger.info('[PRIV] to_uid=%d message="%s"', uid, message) - - data = struct.pack('I8000s', uid, message.encode()) - await self._send(Packet( - command=MessageType.PRIVATE_CHAT, - source=self.unique_id, - stream_id=self.get_chat_sid(self.unique_id), - size=len(data), - data=data - )) - - async def send_multiline_chat(self, message: str): - """Sends a multiline message to the game chat. - - :param message: The message to send. - """ - max_line_len = 100 - if len(message) > max_line_len: - logger.debug('[CHAT] multiline_message="%s"', message) - - total_lines = math.ceil(len(message) / max_line_len) - for i in range(total_lines): - line = message[max_line_len*i:max_line_len*(i+1)] - if i > 0: - line = f'| {line}' - await self.send_chat(line) - else: - await self.send_chat(message) - - async def kick(self, uid: int, reason: str = 'No reason given'): - """Kicks a user from the server. - - :param uid: The uid of the user to kick. - :param reason: The reason for kicking the user, defaults to - 'No reason given' - """ - await self.send_chat(f'!kick {uid} {reason}') - - async def ban(self, uid: int, reason: str = 'No reason given'): - """Bans a user from the server. - - :param uid: The uid of the user to ban. - :param reason: The reason for banning the user, defaults to - 'No reason given' - """ - await self.send_chat(f'!ban {uid} {reason}') - - async def say(self, uid: int, message: str): - """Send a message as a user anonymously. - - :param uid: The uid of the user to send the message to. If -1, - the message will be sent to everyone. - :param message: The message to send. - """ - await self.send_chat(f'!say {uid} {message}') - - async def send_game_cmd(self, command: str): - """Sends a game command (Angelscript) to the server. - - :param command: The command to send. - """ - logger.debug('[GAME_CMD] cmd="%s"', command) - data = command.encode() - await self._send(Packet( - command=MessageType.GAME_CMD, - source=self.unique_id, - stream_id=0, - size=len(data), - data=data - ))