From 83126b61592446d99d8f9fe3c6e386df7c9bf6fe Mon Sep 17 00:00:00 2001 From: Danny Mackey Date: Thu, 31 Aug 2023 02:00:12 -0400 Subject: [PATCH] Begin conversion to asyncio communication --- ror_server_bot/__init__.py | 86 ++- ror_server_bot/__main__.py | 46 ++ ror_server_bot/ror_bot/__init__.py | 7 + ror_server_bot/ror_bot/enums.py | 226 +++++++ ror_server_bot/ror_bot/models/__init__.py | 44 ++ ror_server_bot/ror_bot/models/config.py | 62 ++ ror_server_bot/ror_bot/models/sendable.py | 660 ++++++++++++++++++++ ror_server_bot/ror_bot/models/stats.py | 25 + ror_server_bot/ror_bot/models/truck_file.py | 70 +++ ror_server_bot/ror_bot/models/validators.py | 12 + ror_server_bot/ror_bot/models/vector.py | 246 ++++++++ ror_server_bot/ror_bot/packet_handler.py | 80 +++ ror_server_bot/ror_bot/ror_client.py | 424 +++++++++++++ ror_server_bot/ror_bot/ror_connection.py | 595 ++++++++++++++++++ ror_server_bot/ror_bot/stream_manager.py | 286 +++++++++ ror_server_bot/ror_bot/user.py | 201 ++++++ 16 files changed, 3069 insertions(+), 1 deletion(-) create mode 100644 ror_server_bot/ror_bot/__init__.py create mode 100644 ror_server_bot/ror_bot/enums.py create mode 100644 ror_server_bot/ror_bot/models/__init__.py create mode 100644 ror_server_bot/ror_bot/models/config.py create mode 100644 ror_server_bot/ror_bot/models/sendable.py create mode 100644 ror_server_bot/ror_bot/models/stats.py create mode 100644 ror_server_bot/ror_bot/models/truck_file.py create mode 100644 ror_server_bot/ror_bot/models/validators.py create mode 100644 ror_server_bot/ror_bot/models/vector.py create mode 100644 ror_server_bot/ror_bot/packet_handler.py create mode 100644 ror_server_bot/ror_bot/ror_client.py create mode 100644 ror_server_bot/ror_bot/ror_connection.py create mode 100644 ror_server_bot/ror_bot/stream_manager.py create mode 100644 ror_server_bot/ror_bot/user.py diff --git a/ror_server_bot/__init__.py b/ror_server_bot/__init__.py index 6853c36..0379cfe 100644 --- a/ror_server_bot/__init__.py +++ b/ror_server_bot/__init__.py @@ -1 +1,85 @@ -__version__ = '0.0.0' \ No newline at end of file +import logging +import sys +from logging import handlers +from pathlib import Path + +from devtools import PrettyFormat +from rich.logging import RichHandler + +__version__ = '0.0.0' + +RORNET_VERSION = 'RoRnet_2.44' + +PROJECT_DIRECTORY = Path(__file__).parent + +TRUCK_TO_NAME_FILE = Path('truck_to_name.json') + + +class PFormat(PrettyFormat): + def _format_str_bytes( + self, + value: str | bytes, + value_repr: str, + indent_current: int, + indent_new: int + ) -> None: + if isinstance(value, bytes): + value = value.replace(b'\x00', b'') + return super()._format_str_bytes( + value, + value_repr, + indent_current, + indent_new + ) + + +pformat = PFormat(indent_step=2) + +stream_handler = logging.StreamHandler(sys.stdout) +file_handler = handlers.RotatingFileHandler( + filename='ror_server_bot.log', + mode='at', + maxBytes=1024*1024*16, + backupCount=5, + encoding='utf-8' +) +rich_handler = RichHandler( + omit_repeated_times=False, + keywords=[ + '[CHAT]', + '[EMIT]' + '[GAME_CMD]', + '[HEAD]', + '[PCKT]', + '[PRIV]', + '[RECV]', + '[SEND]', + ] +) + +stream_fmt = logging.Formatter( + fmt='{asctime} | {levelname} | {filename}:{lineno} | {message}', + style='{' +) +file_fmt = logging.Formatter( + fmt='{asctime} | {levelname} | {filename}:{lineno} | {message}', + style='{' +) +rich_fmt = logging.Formatter(fmt='{message}', style='{') + +msec_fmt = '%s.%04d' +stream_fmt.default_msec_format = msec_fmt +file_fmt.default_msec_format = msec_fmt + +stream_handler.setFormatter(stream_fmt) +stream_handler.setLevel(logging.INFO) +file_handler.setFormatter(file_fmt) +file_handler.setLevel(logging.DEBUG) +rich_handler.setFormatter(rich_fmt) +rich_handler.setLevel(logging.DEBUG) + + +logging.basicConfig( + level=logging.DEBUG, + handlers=[file_handler, rich_handler] +) diff --git a/ror_server_bot/__main__.py b/ror_server_bot/__main__.py index e69de29..ad1e019 100644 --- a/ror_server_bot/__main__.py +++ b/ror_server_bot/__main__.py @@ -0,0 +1,46 @@ +import asyncio +import logging + +import discord + +from .ror_bot import RoRClient, RoRClientConfig + +logger = logging.getLogger(__name__) + +if __name__ == '__main__': + def start(): + class DiscordClient: + pass + + intents = discord.Intents.default() + intents.message_content = True + + client = DiscordClient(intents=intents) + logger.warning( + 'Expect a slowdown when requesting guild information from Discord!' + ) + client.run(client.config.discord_bot_token) + + config = RoRClientConfig( + id='1', + enabled=True, + server=RoRClientConfig.ServerConfig( + host='10.90.1.64', + port=12000, + password='' + ), + user=RoRClientConfig.UserConfig( + token='' + ), + discord_channel_id=-1, + announcements=RoRClientConfig.Announcements(), + reconnection_interval=1, + reconnection_tries=3, + ) + + async def main(): + async with RoRClient(config): + while True: + await asyncio.sleep(0.1) + + asyncio.run(main()) diff --git a/ror_server_bot/ror_bot/__init__.py b/ror_server_bot/ror_bot/__init__.py new file mode 100644 index 0000000..0a07bc5 --- /dev/null +++ b/ror_server_bot/ror_bot/__init__.py @@ -0,0 +1,7 @@ +from .models import RoRClientConfig +from .ror_client import RoRClient + +__all__ = [ + 'RoRClientConfig', + 'RoRClient', +] diff --git a/ror_server_bot/ror_bot/enums.py b/ror_server_bot/ror_bot/enums.py new file mode 100644 index 0000000..046ce12 --- /dev/null +++ b/ror_server_bot/ror_bot/enums.py @@ -0,0 +1,226 @@ +from enum import auto, Enum, IntEnum + + +# names from https://www.color-name.com/ +class PlayerColor(Enum): + """The color assigned to each player.""" + + #! DO NOT REORDER + GREEN = "#00CC00" + BLUE = "#0066B3" + ORANGE = "#FF8000" + YELLOW = "#FFCC00" + LIME = "#CCFF00" + RED = "#FF0000" + GRAY = "#808080" + DARK_GREEN = "#008F00" + WINDSOR_TAN = "#B35A00" + LIGHT_GOLD = "#B38F00" + APPLE_GREEN = "#8FB300" + UE_RED = "#B30000" + DARK_GRAY = "#BEBEBE" + LIGHT_GREEN = "#80FF80" + LIGHT_SKY_BLUE = "#80C9FF" + MAC_AND_CHEESE = "#FFC080" + YELLOW_CRAYOLA = "#FFE680" + LAVENDER_FLORAL = "#AA80FF" + ELECTRIC_PINK = "#EE00CC" + CONGO_PINK = "#FF8080" + BRONZE_YELLOW = "#666600" + BRILLIANT_LAVENDER = "#FFBFFF" + SEA_GREEN = "#00FFCC" + WILD_ORCHID = "#CC6699" + DARK_YELLOW = "#999900" + + +class Color(Enum): + BLACK = "#000000" + GREY = "#999999" + RED = "#FF0000" + YELLOW = "#FFFF00" + WHITE = "#FFFFFF" + CYAN = "#00FFFF" + BLUE = "#0000FF" + GREEN = "#00FF00" + MAGENTA = "#FF00FF" + COMMAND = "#941E8D" + WHISPER = "#967417" + SCRIPT = "#32436F" + + +class ActorType(Enum): + TRUCK = 'truck' + CAR = 'car' + LOAD = 'load' + AIRPLANE = 'airplane' + BOAT = 'boat' + TRAILER = 'trailer' + TRAIN = 'train' + FIXED = 'fixed' + + +class MessageType(IntEnum): + HELLO = 1025 + """Client sends its version as the first message.""" + + # Hello Responses + SERVER_FULL = auto() + """Server is full.""" + WRONG_PASSWORD = auto() + """Wrong password.""" + WRONG_VERSION = auto() + """Wrong version.""" + BANNED = auto() + """Client not allowed to join (banned).""" + WELCOME = auto() + """Client accepted.""" + + # Technical + SERVER_VERSION = auto() + """Server sends its version.""" + SERVER_SETTINGS = auto() + """Server sends client the terrain name.""" + USER_INFO = auto() + """User data that is sent from the server to clients.""" + MASTER_SERVER_INFO = auto() + """Server sends master server info.""" + NET_QUALITY = auto() + """Server sends network quality information.""" + + # Gameplay + GAME_CMD = auto() + """Script message. Can be sent in both directions.""" + USER_JOIN = auto() + """New user joined.""" + USER_LEAVE = auto() + """User leaves.""" + CHAT = auto() + """Chat line in UTF8 encoding.""" + PRIVATE_CHAT = auto() + """Private chat line in UTF8 encoding.""" + + # Stream Functions + STREAM_REGISTER = auto() + """Create new stream.""" + STREAM_REGISTER_RESULT = auto() + """Result of a stream creation.""" + STREAM_UNREGISTER = auto() + """Remove stream.""" + STREAM_DATA = auto() + """Stream data.""" + STREAM_DATA_DISCARDABLE = auto() + """Stream data that is allowed to be discarded.""" + + # Legacy + USER_INFO_LEGACY = 1003 + """Wrong version.""" + + +class StreamType(IntEnum): + ACTOR = 0 + CHARACTER = 1 + AI = 2 + CHAT = 3 + + +class AuthLevels(IntEnum): + NONE = 0 + """no authentication""" + ADMIN = 1 + """admin on the server""" + RANKED = 2 + """ranked status""" + MOD = 4 + """moderator status""" + BOT = 8 + """bot status""" + BANNED = 16 + """banned""" + + @classmethod + def get_auth_str(cls, auth: 'AuthLevels') -> str: + auth_str = '' + if auth is AuthLevels.NONE: + auth_str = '' + if auth is AuthLevels.ADMIN: + auth_str = 'A' + if auth is AuthLevels.MOD: + auth_str = 'M' + if auth is AuthLevels.RANKED: + auth_str = 'R' + if auth is AuthLevels.BOT: + auth_str = 'B' + if auth is AuthLevels.BANNED: + auth_str = 'X' + return auth_str + + @property + def auth_str(self) -> str: + return self.get_auth_str(self) + + +class NetMask(IntEnum): + HORN = 1 + """Horn is in use.""" + POLICE_AUDIO = auto() + """Police siren is on.""" + PARTICLE = auto() + """Custom particles are on.""" + PARKING_BRAKE = auto() + """Parking brake is on.""" + TRACTION_CONTROL_ACTIVE = auto() + """Traction control is on.""" + ANTI_LOCK_BRAKES_ACTIVE = auto() + """Anti-lock brakes are on.""" + ENGINE_CONTACT = auto() + """Ignition is on.""" + ENGINE_RUN = auto() + """Engine is running.""" + ENGINE_MODE_AUTOMATIC = auto() + """Using automatic transmission.""" + ENGINE_MODE_SEMIAUTO = auto() + """Using semi-automatic transmission.""" + ENGINE_MODE_MANUAL = auto() + """Using manual transmission.""" + ENGINE_MODE_MANUAL_STICK = auto() + """Using manual transmission with stick.""" + ENGINE_MODE_MANUAL_RANGES = auto() + """Using manual transmission with ranges.""" + + +class LightMask(IntEnum): + CUSTOM_LIGHT_1 = 1 + CUSTOM_LIGHT_2 = auto() + CUSTOM_LIGHT_3 = auto() + CUSTOM_LIGHT_4 = auto() + CUSTOM_LIGHT_5 = auto() + CUSTOM_LIGHT_6 = auto() + CUSTOM_LIGHT_7 = auto() + CUSTOM_LIGHT_8 = auto() + CUSTOM_LIGHT_9 = auto() + CUSTOM_LIGHT_10 = auto() + HEADLIGHT = auto() + HIGHBEAMS = auto() + FOGLIGHTS = auto() + SIDELIGHTS = auto() + BRAKES = auto() + REVERSE = auto() + BEACONS = auto() + BLINK_LEFT = auto() + BLINK_RIGHT = auto() + BLINK_WARN = auto() + +class CharacterCommand(IntEnum): + INVALID = 0 + POSITION = auto() + ATTACH = auto() + DETACH = auto() + +class CharacterAnimation(Enum): + IDLE_SWAY = "Idle_sway" + SPOT_SWIM = "Spot_swim" + WALK = "Walk" + RUN = "Run" + SWIM_LOOP = "Swim_loop" + TURN = "Turn" + DRIVING = "Driving" diff --git a/ror_server_bot/ror_bot/models/__init__.py b/ror_server_bot/ror_bot/models/__init__.py new file mode 100644 index 0000000..29464a0 --- /dev/null +++ b/ror_server_bot/ror_bot/models/__init__.py @@ -0,0 +1,44 @@ +from .config import Config, RoRClientConfig +from .sendable import ( + ActorStreamRegister, + CharacterAttachStreamData, + CharacterPositionStreamData, + CharacterStreamRegister, + ChatStreamRegister, + Packet, + Sendable, + ServerInfo, + stream_data_factory, + stream_register_factory, + StreamRegister, + UserInfo, + VehicleStreamData, +) +from .stats import DistanceStats, GlobalStats, UserStats +from .truck_file import TruckFile, TruckFilenames +from .vector import Vector3, Vector4 + +__all__ = [ + 'ActorStreamRegister', + 'CharacterAttachStreamData', + 'CharacterPositionStreamData', + 'CharacterStreamRegister', + 'ChatStreamRegister', + 'Config', + 'DistanceStats', + 'GlobalStats', + 'Packet', + 'RoRClientConfig', + 'Sendable', + 'ServerInfo', + 'StreamRegister', + 'TruckFile', + 'TruckFilenames', + 'UserInfo', + 'UserStats', + 'Vector3', + 'Vector4', + 'VehicleStreamData', + 'stream_data_factory', + 'stream_register_factory', +] diff --git a/ror_server_bot/ror_bot/models/config.py b/ror_server_bot/ror_bot/models/config.py new file mode 100644 index 0000000..dce9a06 --- /dev/null +++ b/ror_server_bot/ror_bot/models/config.py @@ -0,0 +1,62 @@ +from pydantic import BaseModel, Field, model_validator + +from ror_server_bot import RORNET_VERSION + + +class RoRClientConfig(BaseModel): + class ServerConfig(BaseModel): + host: str = '' + port: int = Field(12000, ge=12000, le=12999) + password: str = '' + + class UserConfig(BaseModel): + name: str = 'RoR Server Bot' + token: str = '' + language: str = 'en_US' + + class Announcements(BaseModel): + delay: int = 300 + """Delay between announcements in seconds.""" + enabled: bool = False + messages: list[str] = Field(default_factory=list) + + @model_validator(mode='after') + def set_enabled(self) -> 'RoRClientConfig.Announcements': # noqa: N804 + self.enabled = bool(self.messages) + return self + + def get_next_announcement(self, time_sec: float) -> str: + idx = int((time_sec / self.delay) % len(self.messages)) + return self.messages[idx] + + id: str + enabled: bool + server: ServerConfig + user: UserConfig + discord_channel_id: int + announcements: Announcements | None + reconnection_interval: int = 5 + """Interval between reconnection attempts in seconds.""" + reconnection_tries: int = 3 + """Number of reconnection attempts before giving up.""" + + +class Config(BaseModel): + """Represents a configuration used to build RoR server bots""" + + client_name: str = '2022.04' + version_num: str = RORNET_VERSION + discord_bot_token: str + ror_clients: list[RoRClientConfig] + + def get_channel_id_by_client_id(self, id: str) -> int | None: + for client in self.ror_clients: + if client.id == id: + return client.discord_channel_id + return None + + def get_ror_client_by_id(self, id: str) -> RoRClientConfig | None: + for client in self.ror_clients: + if client.id == id: + return client + return None diff --git a/ror_server_bot/ror_bot/models/sendable.py b/ror_server_bot/ror_bot/models/sendable.py new file mode 100644 index 0000000..330b221 --- /dev/null +++ b/ror_server_bot/ror_bot/models/sendable.py @@ -0,0 +1,660 @@ +import logging +import struct +from datetime import datetime +from typing import Annotated, Any, ClassVar, Literal, Self + +from pydantic import BaseModel, Field, field_validator + +from ror_server_bot import pformat, RORNET_VERSION +from ror_server_bot.ror_bot.enums import ( + ActorType, + AuthLevels, + CharacterAnimation, + CharacterCommand, + Color, + MessageType, + PlayerColor, + StreamType, +) + +from .validators import strip_nulls_after +from .vector import Vector3, Vector4 + +logger = logging.getLogger(__name__) + + +class Sendable(BaseModel): + """A sendable object.""" + + STRUCT_FORMAT: ClassVar[str] = '' + """The struct format of the object.""" + + @classmethod + def calc_size(cls) -> int: + """The expected size of the `cls.STRUCT_FORMAT` in bytes.""" + return struct.calcsize(cls.STRUCT_FORMAT) + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + """Creates an object from the bytes. + + :param data: The bytes to create the object from. + :return: The object created from the bytes. + """ + return cls.model_validate( + dict(zip( + cls.model_fields.keys(), + struct.unpack(cls.STRUCT_FORMAT, data) + )) + ) + + def __str__(self) -> str: + return pformat(self) + + def pack(self) -> bytes: + """Packs the object into bytes. + + :return: The object packed into bytes. + """ + data = [] + for value in self.model_dump().values(): + if isinstance(value, str): + data.append(value.encode()) + else: + data.append(value) + return struct.pack(self.STRUCT_FORMAT, *data) + + +class Packet(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'IIII' + """The struct format of the packet header. + ``` + I: command + I: source + I: stream_id + I: size + ``` + """ + + command: MessageType + """The command of this packet.""" + source: int = 0 + """The source of this packet (0 = server).""" + stream_id: int = Field(default=0, ge=0) + """The stream id of this packet.""" + size: int = Field(default=0, ge=0) + """The size of the data in this packet.""" + data: bytes = b'' + + time: datetime = Field(default_factory=datetime.now) + + @classmethod + def from_bytes(cls, header: bytes) -> 'Packet': + """Creates a packet from the header data. + + :param header: The bytes of the header. + :return: The packet created from the header data. + """ + return super().from_bytes(header) + + def pack(self) -> bytes: + """Packs the packet into bytes. + + :return: The packet packed into bytes. + """ + return struct.pack( + f'{self.STRUCT_FORMAT}{self.size}s', + self.command, + self.source, + self.stream_id, + self.size, + self.data + ) + + +class ServerInfo(Sendable): + STRUCT_FORMAT: ClassVar[str] = '20s128s128s?4096s' + """The struct format of the server info data. + ``` + 20s: protocol_version + 128s: terrain_name + 128s: server_name + ?: has_password + 4096s: info + ``` + """ + + protocol_version: str = Field(default=RORNET_VERSION, max_length=20) + """The protocol version of the server.""" + terrain_name: str = Field(default='', max_length=128) + """The name of the terrain.""" + server_name: str = Field(default='', max_length=128) + """The name of the server.""" + has_password: bool = False + """Whether the server has a password.""" + info: str = Field(default='', max_length=4096) + """Info text (MOTD file contents).""" + + # validators + _strip_null_character = strip_nulls_after( + 'protocol_version', + 'terrain_name', + 'server_name', + 'info', + ) + + @classmethod + def from_bytes(cls, data: bytes) -> 'ServerInfo': + """Creates a server info from the bytes. + + :param data: The bytes to create the server info from. + :return: The server info created from the bytes. + """ + return super().from_bytes(data) + + def pack(self) -> bytes: + """Packs the server info into bytes. + + :return: The server info packed into bytes. + """ + return super().pack() + + +class UserInfo(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'Iiii40s40s40s10s10s25s40s10s128s' + """The struct format of the user info data. + ``` + I: uid + i: auth_status + i: slot_num + i: color_idx + 40s: username + 40s: token + 40s: server_password + 10s: language + 10s: client_name + 25s: client_version + 40s: client_guid + 10s: session_type + 128s: session_options + ``` + """ + + unique_id: int = Field(default=0, ge=0) + """The unique id of the user (set by the server).""" + auth_status: AuthLevels + """The authentication status of the user (set by the server).""" + slot_num: int = -1 + """The slot number the user occupies in the server (set by the + server).""" + color_num: int = -1 + """The color number of the user (set by the server).""" + username: str = Field(max_length=40) + user_token: str = Field(max_length=40) + server_password: str = Field(max_length=40) + language: str = Field(max_length=10) + """The language of the user (e.g. "de-DE" or "en-US").""" + client_name: str + """The name and version of the client.""" + client_version: str = Field(max_length=25) + """The version of the client (e.g. "2022.12").""" + client_guid: str = Field(max_length=40) + session_type: str = Field(max_length=10) + """The requested session type (e.g. "normal" "bot" "rcon")""" + session_options: str = Field(max_length=128) + """Reserved for future options.""" + + @property + def user_color(self) -> str: + """Get the hex color of the username.""" + colors = list(PlayerColor) + if -1 < self.color_num < len(colors): + return colors[self.color_num].value + return Color.WHITE.value + + # validators + _strip_nulls = strip_nulls_after( + 'username', + 'user_token', + 'server_password', + 'language', + 'client_name', + 'client_version', + 'client_guid', + 'session_type', + 'session_options', + ) + + @classmethod + def from_bytes(cls, data: bytes) -> 'UserInfo': + """Creates a user info from the bytes. + + :param data: The bytes to create the user info from. + :return: The user info created from the bytes. + """ + return super().from_bytes(data) + + def pack(self) -> bytes: + """Packs the user info into bytes. + + :return: The user info packed into bytes. + """ + return super().pack() + + +class BaseStreamRegister(BaseModel): + STRUCT_FORMAT: ClassVar[str] = 'iiii128s' + """The struct format of the stream register data. + ``` + i: type + i: status + i: origin_source_id + i: origin_stream_id + 128s: name + ``` + """ + + type: StreamType + status: int + origin_source_id: int + origin_stream_id: int + name: str = Field(max_length=128) + """The name of the stream.""" + + def __str__(self) -> str: + return pformat(self) + + @field_validator('name', mode='before') + def __strip_null_character(cls, v: str | bytes) -> str: + if isinstance(v, bytes): + return v.strip(b'\x00').decode() + return v + + +class GenericStreamRegister(BaseStreamRegister, Sendable): + STRUCT_FORMAT: ClassVar[str] = (BaseStreamRegister.STRUCT_FORMAT + '128s') + """The struct format of the generic stream register data. + ``` + i: type + i: status + i: origin_source_id + i: origin_stream_id + 128s: name + 128s: reg_data + ``` + """ + + type: Literal[StreamType.CHAT] | Literal[StreamType.CHARACTER] + name: Literal['chat', 'default'] + reg_data: str = Field(max_length=128) + + position: Vector3 = Vector3() + """The position of the actor.""" + rotation: Vector4 = Vector4() + """The rotation of the actor.""" + + _strip_nulls = strip_nulls_after('reg_data') + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + """Creates a stream register from the bytes. + + :param data: The bytes to create the stream register from. + :return: The stream register created from the bytes. + """ + return super().from_bytes(data) + + def pack(self) -> bytes: + """Packs the stream register into bytes. + + :return: The stream register packed into bytes. + """ + return struct.pack( + self.STRUCT_FORMAT, + self.type, + self.status, + self.origin_source_id, + self.origin_stream_id, + self.name.encode(), + self.reg_data.encode(), + ) + + +class ChatStreamRegister(GenericStreamRegister): + type: Literal[StreamType.CHAT] + name: Literal['chat'] + reg_data: str = Field(max_length=128) + + +class CharacterStreamRegister(GenericStreamRegister): + type: Literal[StreamType.CHARACTER] + name: Literal['default'] + reg_data: str = Field(max_length=128) + + +class ActorStreamRegister(BaseStreamRegister): + STRUCT_FORMAT: ClassVar[str] = ( + BaseStreamRegister.STRUCT_FORMAT + 'ii60s60s' + ) + """The struct format of the actor stream register data. + ``` + i: type + i: status + i: origin_source_id + i: origin_stream_id + 128s: name + i: buffer_size + i: timestamp + 60s: skin + 60s: section_config + ``` + """ + + type: Literal[StreamType.ACTOR] + buffer_size: int + timestamp: int + skin: str = Field(max_length=60) + section_config: str = Field(max_length=60) + + actor_type: ActorType | None = None + """The type of the actor (parsed from the actor filename).""" + + position: Vector3 = Vector3() + """The position of the actor.""" + rotation: Vector4 = Vector4() + """The rotation of the actor.""" + + _strip_nulls = strip_nulls_after('skin', 'section_config') + + @classmethod + def from_bytes(cls, data: bytes) -> 'ActorStreamRegister': + """Creates a stream register from the bytes. + + :param data: The bytes to create the stream register from. + :return: The stream register created from the bytes. + """ + return cls.model_validate( + dict(zip( + cls.model_fields.keys(), + struct.unpack(cls.STRUCT_FORMAT, data) + )) + ) + + def pack(self) -> bytes: + """Packs the stream register into bytes. + + :return: The stream register packed into bytes. + """ + return struct.pack( + self.STRUCT_FORMAT, + self.type, + self.status, + self.origin_source_id, + self.origin_stream_id, + self.name.encode(), + self.buffer_size, + self.timestamp, + self.skin.encode(), + self.section_config.encode(), + ) + + +StreamRegister = Annotated[ + ChatStreamRegister | CharacterStreamRegister | ActorStreamRegister, + Field(discriminator='type') +] + + +def stream_register_factory(data: bytes) -> StreamRegister: + """Creates a stream register of the given type. + + :param data: The bytes to create the stream register from. + :return: The stream register of the given type. + """ + uint = 'I' + uint_size = struct.calcsize(uint) + + stream_type = StreamType(struct.unpack(uint, data[:uint_size])[0]) + if stream_type is StreamType.CHAT: + return ChatStreamRegister.from_bytes(data) + elif stream_type is StreamType.CHARACTER: + return CharacterStreamRegister.from_bytes(data) + elif stream_type is StreamType.ACTOR: + return ActorStreamRegister.from_bytes(data) + raise ValueError(f'Invalid stream type: {type!r}') + + +class CharacterPositionStreamData(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'i3fff10s' + """The struct format of the character position stream data. + ``` + i: command + 3f: position + f: rotation + f: animation_time + 10s: animation_mode + ``` + """ + + command: Literal[CharacterCommand.POSITION] + position: Vector3 + rotation: float + """The rotation in radians.""" + animation_time: float + animation_mode: CharacterAnimation + + @field_validator('animation_mode', mode='before') + def __strip_null_character(cls, v: Any) -> Any: + if isinstance(v, bytes): + return v.strip(b'\x00').decode() + return v + + @classmethod + def from_bytes(cls, data: bytes) -> 'CharacterPositionStreamData': + """Creates a character position from the bytes. + + :param data: The bytes to create the character position from. + :return: The character position created from the bytes. + """ + command, x, y, z, *values = struct.unpack( + cls.STRUCT_FORMAT, + data + ) + return cls.model_validate( + dict(zip( + cls.model_fields.keys(), + (command, Vector3(x=x, y=y, z=z), *values) + )) + ) + + def pack(self) -> bytes: + """Packs the character position into bytes. + + :return: The character position packed into bytes. + """ + return struct.pack( + self.STRUCT_FORMAT, + self.command, + *self.position, + self.rotation, + self.animation_time, + self.animation_mode.value.encode(), + ) + + +class CharacterAttachStreamData(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'iiii' + """The struct format of the character attach stream data. + ``` + i: command + i: source_id + i: stream_id + i: position + ``` + """ + + command: Literal[CharacterCommand.ATTACH] + source_id: int + stream_id: int + position: int + + @classmethod + def from_bytes(cls, data: bytes) -> 'CharacterAttachStreamData': + """Creates a character attach from the bytes. + + :param data: The bytes to create the character attach from. + :return: The character attach created from the bytes. + """ + return cls.model_validate( + dict(zip( + cls.model_fields.keys(), + struct.unpack(cls.STRUCT_FORMAT, data) + )) + ) + + def pack(self) -> bytes: + """Packs the character attach into bytes. + + :return: The character attach packed into bytes. + """ + return struct.pack( + self.STRUCT_FORMAT, + self.source_id, + self.stream_id, + self.position, + ) + + +class CharacterDetachStreamData(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'i' + """The struct format of the character detach stream data. + ``` + i: command + ``` + """ + + command: Literal[CharacterCommand.DETACH] + + @classmethod + def from_bytes(cls, data: bytes) -> 'CharacterDetachStreamData': + """Creates a character detach from the bytes. + + :param data: The bytes to create the character detach from. + :return: The character detach created from the bytes. + """ + return super().from_bytes(data) + + def pack(self) -> bytes: + """Packs the character detach into bytes. + + :return: The character detach packed into bytes. + """ + return super().pack() + + +class VehicleStreamData(Sendable): + STRUCT_FORMAT: ClassVar[str] = 'IfffIfffI3f' + """The struct format of the vehicle state data. + ``` + I: time + f: engine_speed + f: engine_force + f: engine_clutch + I: engine_gear + f: steering + f: brake + f: wheel_speed + I: flag_mask + 3f: position + Xs: node_data + ``` + """ + + time: int + engine_rpm: float + engine_accerlation: float + engine_clutch: float + engine_gear: int + steering: float + brake: float + wheel_speed: float + flag_mask: int + position: Vector3 + node_data: bytes + + @classmethod + def from_bytes(cls, data: bytes) -> 'VehicleStreamData': + """Creates a vehicle state from the bytes. + + :param data: The bytes to create the vehicle state from. + :return: The vehicle state created from the bytes. + """ + *values, x, y, z = struct.unpack( + cls.STRUCT_FORMAT, + data[:cls.calc_size()] + ) + + node_data, *_ = struct.unpack( + f'{len(data) - cls.calc_size()}s', + data[cls.calc_size():] + ) + + return cls.model_validate( + dict(zip( + cls.model_fields.keys(), + (*values, Vector3(x=x, y=y, z=z), node_data) + )) + ) + + def pack(self) -> bytes: + """Packs the vehicle state into bytes. + + :return: The vehicle state packed into bytes. + """ + return struct.pack( + f'{self.STRUCT_FORMAT}{len(self.node_data)}s', + self.time, + self.engine_rpm, + self.engine_accerlation, + self.engine_clutch, + self.engine_gear, + self.steering, + self.brake, + self.wheel_speed, + self.flag_mask, + *self.position, + self.node_data, + ) + + +StreamData = ( + CharacterAttachStreamData + | CharacterPositionStreamData + | CharacterDetachStreamData + | VehicleStreamData +) + + +def stream_data_factory(type: StreamType, data: bytes) -> StreamData: + """Creates a stream data of the given type. + + :param type: The type of the stream data. + :param data: The bytes to create the stream data from. + :return: The stream data of the given type. + """ + stream_data: StreamData | None = None + if type is StreamType.CHARACTER: + command = CharacterCommand(struct.unpack('i', data[:4])[0]) + if command is CharacterCommand.ATTACH: + stream_data = CharacterAttachStreamData.from_bytes(data) + elif command is CharacterCommand.POSITION: + stream_data = CharacterPositionStreamData.from_bytes(data) + elif command is CharacterCommand.DETACH: + stream_data = CharacterDetachStreamData.from_bytes(data) + else: + raise ValueError(f'Invalid character command: {command!r}') + elif type is StreamType.ACTOR: + stream_data = VehicleStreamData.from_bytes(data) + else: + raise ValueError(f'Invalid stream type: {type!r}') + return stream_data diff --git a/ror_server_bot/ror_bot/models/stats.py b/ror_server_bot/ror_bot/models/stats.py new file mode 100644 index 0000000..f3c0971 --- /dev/null +++ b/ror_server_bot/ror_bot/models/stats.py @@ -0,0 +1,25 @@ +from datetime import datetime, timedelta + +from pydantic import BaseModel, Field + + +class DistanceStats(BaseModel): + meters_driven: float = 0 + meters_sailed: float = 0 + meters_walked: float = 0 + meters_flown: float = 0 + + +class GlobalStats(DistanceStats): + connected_at: datetime = Field(default_factory=datetime.now) + usernames: set[str] = Field(default=set()) + user_count: int = 0 + connection_times: list[timedelta] = Field(default=[]) + + def add_user(self, username: str): + self.usernames.add(username) + self.user_count += 1 + + +class UserStats(DistanceStats): + online_since: datetime = Field(default_factory=datetime.now) diff --git a/ror_server_bot/ror_bot/models/truck_file.py b/ror_server_bot/ror_bot/models/truck_file.py new file mode 100644 index 0000000..4ed2e62 --- /dev/null +++ b/ror_server_bot/ror_bot/models/truck_file.py @@ -0,0 +1,70 @@ +import json +import logging +import re +from pathlib import Path + +from pydantic import BaseModel, RootModel + +from ror_server_bot.ror_bot.enums import ActorType + +logger = logging.getLogger(__name__) + +truckfile_re = re.compile( + r'((?P[a-z0-9]*)\-)?((.*)UID\-)?(?P.*)' + r'\.(?Ptruck|car|load|airplane|boat|trailer|train|fixed)' +) + + +class TruckFilenames(RootModel): + root: dict[str, str] + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, key: str) -> str: + return self.root[key] + + def get(self, key, default: str | None = None) -> str | None: + return self.root.get(key, default) + + @classmethod + def from_json(cls, filename: Path) -> 'TruckFilenames': + with open(filename) as file: + return cls.model_validate(json.load(file)) + + +class TruckFile(BaseModel): + filename: Path + """The full filename of the .truck file including the extension.""" + guid: str | None = None + """The guid included in the filename (optional).""" + name: str + """The display name of the actor.""" + type: ActorType + """The type of the actor.""" + + @classmethod + def from_filename( + cls, + json_file: Path, + truck_filename: str + ) -> 'TruckFile': + """Creates a truck file from the filename. + + :param json_file: The json file to get the truck file name from. + :param filename: The filename to create the truck file from. + :return: The truck file created from the filename. + """ + name = TruckFilenames.from_json(json_file).get(truck_filename) + match = truckfile_re.search(truck_filename) + if name is None and match is not None: + return cls( + filename=truck_filename, + **match.groupdict() + ) + else: + return cls( + filename=truck_filename, + name=name, + type=truck_filename.rsplit('.', maxsplit=1)[-1].lower() + ) diff --git a/ror_server_bot/ror_bot/models/validators.py b/ror_server_bot/ror_bot/models/validators.py new file mode 100644 index 0000000..dcc8b16 --- /dev/null +++ b/ror_server_bot/ror_bot/models/validators.py @@ -0,0 +1,12 @@ +from pydantic import field_validator + + +def strip_nulls_after(*fields: str): + """A validator that strips null characters from provided fields.""" + def __strip_null_character(v: str) -> str: + return v.strip('\x00') + return field_validator( + *fields, + mode='after', + check_fields=False + )(__strip_null_character) diff --git a/ror_server_bot/ror_bot/models/vector.py b/ror_server_bot/ror_bot/models/vector.py new file mode 100644 index 0000000..37f60a4 --- /dev/null +++ b/ror_server_bot/ror_bot/models/vector.py @@ -0,0 +1,246 @@ +import math + +from pydantic import BaseModel + + +class Vector3(BaseModel): + x: float = 0.0 + y: float = 0.0 + z: float = 0.0 + + def __getitem__(self, index: int) -> float: + return [self.x, self.y, self.z][index] + + def __setitem__(self, index: int, value: float): + if index == 0: + self.x = value + elif index == 1: + self.y = value + elif index == 2: + self.z = value + else: + raise IndexError(index) + + def __iter__(self): + return iter((self.x, self.y, self.z)) + + def __len__(self): + return len(Vector3.model_fields) + + def __hash__(self): + return hash((self.x, self.y, self.z)) + + def __eq__(self, __value: object) -> bool: + if isinstance(__value, Vector3): + return ( + self.x == __value.x + and self.y == __value.y + and self.z == __value.z + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector3.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x == __value[0] + and self.y == __value[1] + and self.z == __value[2] + ) + return NotImplemented + + def __lt__(self, __value: object) -> bool: + if isinstance(__value, Vector3): + return ( + self.x < __value.x + and self.y < __value.y + and self.z < __value.z + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector3.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x < __value[0] + and self.y < __value[1] + and self.z < __value[2] + ) + return NotImplemented + + def __le__(self, __value: object) -> bool: + return self.__lt__(__value) or self.__eq__(__value) + + def __gt__(self, __value: object) -> bool: + if isinstance(__value, Vector3): + return ( + self.x > __value.x + and self.y > __value.y + and self.z > __value.z + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector3.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x > __value[0] + and self.y > __value[1] + and self.z > __value[2] + ) + return NotImplemented + + def __ge__(self, __value: object) -> bool: + return self.__gt__(__value) or self.__eq__(__value) + + def __repr__(self) -> str: + return f'Vector3({self.x}, {self.y}, {self.z})' + + def __str__(self) -> str: + return f'({self.x}, {self.y}, {self.z})' + + def __format__(self, format_spec: str) -> str: + return ( + f'({self.x:{format_spec}}, ' + f'{self.y:{format_spec}}, ' + f'{self.z:{format_spec}})' + ) + + def distance(self, other: 'Vector3') -> float: + """Calculates the distance to another Vector3 + + :param other: A Vector3 to calculate the distance to + :return: The distance to the other Vector3 + """ + return math.sqrt( + (self.x - other.x) ** 2 + + (self.y - other.y) ** 2 + + (self.z - other.z) ** 2 + ) + + +class Vector4(BaseModel): + x: float = 0.0 + y: float = 0.0 + z: float = 0.0 + w: float = 0.0 + + def __getitem__(self, index: int) -> float: + return [self.x, self.y, self.z, self.w][index] + + def __setitem__(self, index: int, value: float): + if index == 0: + self.x = value + elif index == 1: + self.y = value + elif index == 2: + self.z = value + elif index == 3: + self.w = value + else: + raise IndexError(index) + + def __iter__(self): + return iter((self.x, self.y, self.z, self.w)) + + def __len__(self): + return len(Vector4.model_fields) + + def __hash__(self): + return hash((self.x, self.y, self.z, self.w)) + + def __eq__(self, __value: object) -> bool: + if isinstance(__value, Vector4): + return ( + self.x == __value.x + and self.y == __value.y + and self.z == __value.z + and self.w == __value.w + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector4.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x == __value[0] + and self.y == __value[1] + and self.z == __value[2] + and self.w == __value[3] + ) + return NotImplemented + + def __lt__(self, __value: object) -> bool: + if isinstance(__value, Vector4): + return ( + self.x < __value.x + and self.y < __value.y + and self.z < __value.z + and self.w < __value.w + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector4.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x < __value[0] + and self.y < __value[1] + and self.z < __value[2] + and self.w < __value[3] + ) + return NotImplemented + + def __le__(self, __value: object) -> bool: + return self.__lt__(__value) or self.__eq__(__value) + + def __gt__(self, __value: object) -> bool: + if isinstance(__value, Vector4): + return ( + self.x > __value.x + and self.y > __value.y + and self.z > __value.z + and self.w > __value.w + ) + elif ( + isinstance(__value, tuple) + and len(__value) == len(Vector4.model_fields) + and all(isinstance(v, (int, float)) for v in __value) + ): + return bool( + self.x > __value[0] + and self.y > __value[1] + and self.z > __value[2] + and self.w > __value[3] + ) + return NotImplemented + + def __ge__(self, __value: object) -> bool: + return self.__gt__(__value) or self.__eq__(__value) + + def __repr__(self) -> str: + return f'Vector4({self.x}, {self.y}, {self.z}, {self.w})' + + def __str__(self) -> str: + return f'({self.x}, {self.y}, {self.z}, {self.w})' + + def __format__(self, format_spec: str) -> str: + return ( + f'({self.x:{format_spec}}, ' + f'{self.y:{format_spec}}, ' + f'{self.z:{format_spec}}, ' + f'{self.w:{format_spec}})' + ) + + def distance(self, other: 'Vector4') -> float: + """Calculates the distance to another Vector4 + + :param other: A Vector4 to calculate the distance to + :return: The distance to the other Vector4 + """ + return math.sqrt( + (self.x - other.x) ** 2 + + (self.y - other.y) ** 2 + + (self.z - other.z) ** 2 + + (self.w - other.w) ** 2 + ) diff --git a/ror_server_bot/ror_bot/packet_handler.py b/ror_server_bot/ror_bot/packet_handler.py new file mode 100644 index 0000000..5b5aba8 --- /dev/null +++ b/ror_server_bot/ror_bot/packet_handler.py @@ -0,0 +1,80 @@ +import inspect + +from pyee.asyncio import AsyncIOEventEmitter + +from .enums import MessageType +from .models import Packet + + +class PacketHandler: + def __init__(self, event_emitter: AsyncIOEventEmitter) -> None: + """Create a new PacketHandler. This class is used to register + event handlers on an EventEmitter. The event handlers are + methods of this class that start with 'on_' or 'once_' and + end with the name of the packet type. + + Methods must have the following signature: + ``` + def on_hello(self, packet: Packet) -> None: + ... + def once_hello(self, packet: Packet) -> None: + ... + ``` + + The method `on_hello` will be called when a `MessageType.HELLO` + packet is received and the method `once_hello` will be called + once when a `MessageType.HELLO` packet is received. + + :param event_emitter: The EventEmitter to register the packet + handlers on. The event_emitter must have the wildcard option + enabled. See `pymitter.EventEmitter` for more information. + """ + for (name, method) in inspect.getmembers(self, inspect.ismethod): + for prefix in ('on_', 'once_'): + if name.startswith(prefix): + if name == prefix + 'packet': + event = 'packet' + else: + message_type = name[len(prefix):].upper() + if message_type not in MessageType._member_names_: + raise ValueError( + f'Invalid packet type: {message_type}' + ) + event = f'packet.{message_type}' + + parameters = inspect.signature(method).parameters + + if len(parameters) != 1: + raise ValueError( + f'Invalid signature for "{name}". ' + 'Expected 2 parameters: self, packet', + ) + + (param_name, parameter), *_ = parameters.items() + + if parameter.annotation != Packet: + raise ValueError( + f'Invalid signature for {name}. ' + f'Expected "{param_name}" to be of type Packet' + ) + + if prefix == 'on_': + event_emitter.on(event, method) + elif prefix == 'once_': + event_emitter.once(event, method) + + +if __name__ == '__main__': + class Test(PacketHandler): + def __init__(self, event_emitter: AsyncIOEventEmitter) -> None: + super().__init__(event_emitter) + + def on_hello(self, foo: Packet) -> None: + print('hello', foo) + + def other_method(self) -> None: + print('other_method') + + ee = AsyncIOEventEmitter() + test = Test(ee) + ee.emit('packet.*', Packet(command=MessageType.HELLO)) diff --git a/ror_server_bot/ror_bot/ror_client.py b/ror_server_bot/ror_bot/ror_client.py new file mode 100644 index 0000000..cc760e1 --- /dev/null +++ b/ror_server_bot/ror_bot/ror_client.py @@ -0,0 +1,424 @@ +import asyncio +import logging +import struct +import time +from enum import Enum +from typing import Callable + +from pyee.asyncio import AsyncIOEventEmitter + +from .enums import MessageType, StreamType +from .models import ( + Packet, + RoRClientConfig, + stream_data_factory, + stream_register_factory, + UserInfo, +) +from .packet_handler import PacketHandler +from .ror_connection import RoRConnection +from .stream_manager import UserNotFoundError +from .user import StreamNotFoundError + +logger = logging.getLogger(__name__) + + +def check_packet_type(packet: Packet, message_type: MessageType) -> bool: + """Check if the packet is of the specified type. + + :param packet: The packet to check. + :param message_type: The type to check the packet against. + :return: True if the packet is of the specified type, False + """ + return packet.command is message_type + + +class RoRClientEvents(Enum): + FRAME_STEP = 'frame_step' + NET_QUALITY = 'net_quality' + CHAT = 'chat' + PRIVATE_CHAT = 'private_chat' + USER_JOIN = 'user_join' + USER_INFO = 'user_info' + USER_LEAVE = 'user_leave' + GAME_CMD = 'game_cmd' + STREAM_REGISTER = 'stream_register' + STREAM_REGISTER_RESULT = 'stream_register_result' + STREAM_DATA = 'stream_data' + STREAM_UNREGISTER = 'stream_unregister' + + +class RoRClient(PacketHandler): + STABLE_FPS = 20 + + def __init__(self, client_config: RoRClientConfig) -> None: + """Create a new RoRClient. This class is used to connect to a + RoR server and handle packets received from the server. It + inherits from PacketHandler and registers event handlers on + an AsyncIOEventEmitter. + + See `ror_client.packet_handler.PacketHandler` for more information. + + :param client_config: The configuration to use for the client. + """ + self.config = client_config + + self.server = RoRConnection( + username=self.config.user.name, + user_token=self.config.user.token, + password=self.config.server.password, + host=self.config.server.host, + port=self.config.server.port, + ) + + super().__init__(self.server) + + self._frame_step_task: asyncio.Task + + self.event_emitter = AsyncIOEventEmitter() + + self.event_emitter.add_listener('new_listener', self._new_listener) + self.event_emitter.add_listener('error', self._error) + + async def __aenter__(self): + for attempt in range(self.config.reconnection_tries): + try: + logger.info( + 'Attempt %d/%d to connect to RoR server: %s', + attempt + 1, + self.config.reconnection_tries, + self.server.address + ) + self.server = await self.server.__aenter__() + except ConnectionRefusedError: + logger.warning('Connection refused!') + + if attempt < self.config.reconnection_tries - 1: + logger.info( + 'Waiting %.2f seconds before next attempt', + self.config.reconnection_interval + ) + await asyncio.sleep(self.config.reconnection_interval) + else: + break + + if self.server.is_connected: + logger.info('Connected to RoR server: %s', self.server.address) + else: + raise ConnectionError( + f'Could not connect to RoR server {self.server.address} ' + f'after {self.config.reconnection_tries} attempts', + ) + + self._frame_step_task = asyncio.create_task( + self._frame_step_loop(), + name='frame_step_loop' + ) + + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.server.__aexit__(exc_type, exc, tb) + + if self._frame_step_task: + self._frame_step_task.cancel() + + def _new_listener(self, event: str, listener: Callable): + """Handles new listener events. + + :param event: The event that was added. + :param listener: The listener that was added. + """ + logger.debug( + 'New listener added: event="%s" listener="%s"', + event, + listener.__name__ + ) + + def _error(self, error: Exception): + """Handles error events. + + :param error: The error that was emitted. + """ + logger.error('%r', error, exc_info=True) + + def emit(self, event: RoRClientEvents, *args, **kwargs): + """Emit an event on the event emitter. + + :param event: The event to emit. + :param args: The arguments to pass to the event handler. + :param kwargs: The keyword arguments to pass to the event handler. + """ + if event is not RoRClientEvents.FRAME_STEP: + # we do not need to log every frame_step event emit + logger.debug( + '[EMIT] event=%r listeners=%d', + event.value, + len(self.event_emitter.listeners(event.value)) + ) + self.event_emitter.emit(event.value, *args, **kwargs) + + async def _frame_step_loop(self): + """Send frame_step events at a stable rate.""" + start_time = time.time() + current_time = start_time + delta = 0 + while True: + prev_time = current_time + current_time = time.time() + delta += current_time - prev_time + + if delta >= (self.STABLE_FPS / 60): + self.emit(RoRClientEvents.FRAME_STEP, delta) + delta = 0 + + await asyncio.sleep(0.01) + + def on_packet(self, packet: Packet) -> None: + """Handle packets received from the server. + + :param packet: The packet to handle. + """ + if ( + packet.command in (MessageType.HELLO, MessageType.WELCOME) + or f'packet.{packet.command.name}' in self.server.event_names() + ): + return + + logger.warning('Unhandled packet command: %r', packet.command) + + def on_net_quality(self, packet: Packet) -> None: + """Handle net_quality packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.NET_QUALITY): + return + + net_quality: int = struct.unpack('I', packet.data)[0] + + if self.server.net_quality != net_quality: + logger.info( + 'Net quality for uid=%d changed: %d -> %d', + packet.source, + self.server.net_quality, + net_quality + ) + self.emit(RoRClientEvents.NET_QUALITY, packet.source, net_quality) + else: + logger.debug( + 'Net quality unchanged for uid=%d: %d', + packet.source, + net_quality + ) + + self.server.net_quality = net_quality + + def on_chat(self, packet: Packet) -> None: + """Handle chat packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.CHAT): + return + + message = packet.data.decode().strip('\x00') + + logger.info('[CHAT] from_uid=%d message=%r', packet.source, message) + + if message and packet.source != self.server.unique_id: + self.emit(RoRClientEvents.CHAT, packet.source, message) + + def on_private_chat(self, packet: Packet) -> None: + """Handle private_chat packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.PRIVATE_CHAT): + return + + message = packet.data.decode().strip('\x00') + + logger.info('[PRIV] from_uid=%d message=%r', packet.source, message) + + if message and packet.source != self.server.unique_id: + self.emit(RoRClientEvents.PRIVATE_CHAT, packet.source, message) + + def on_user_join(self, packet: Packet) -> None: + """Handle user_join packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.USER_JOIN): + return + + logger.info('User join received') + + user_info = UserInfo.from_bytes(packet.data) + + self.server.stream_manager.add_user(user_info) + + if user_info.unique_id != self.server.unique_id: + self.emit(RoRClientEvents.USER_JOIN, packet.source, user_info) + + def on_user_info(self, packet: Packet) -> None: + """Handle user_info packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.USER_INFO): + return + + logger.info('User info received') + + user_info = UserInfo.from_bytes(packet.data) + + self.server.stream_manager.add_user(user_info) + + if user_info.unique_id != self.server.unique_id: + self.emit(RoRClientEvents.USER_INFO, packet.source, user_info) + + def on_user_leave(self, packet: Packet) -> None: + """Handle user_leave packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.USER_LEAVE): + return + + logger.info( + 'uid=%r left with reason: %s', + packet.source, + packet.data + ) + + if packet.source == self.server.unique_id: + raise ConnectionError('RoRClient disconnected from the server') + + self.server.stream_manager.delete_user(packet.source) + + self.emit(RoRClientEvents.USER_LEAVE, packet.source) + + def on_game_cmd(self, packet: Packet) -> None: + """Handle game_cmd packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.GAME_CMD): + return + + if packet.source == self.server.unique_id: + return + + game_cmd = packet.data.decode().strip('\x00') + + logger.debug('[GAME_CMD] from_uid=%d cmd=%r', packet.source, game_cmd) + + if game_cmd and packet.source != self.server.unique_id: + self.emit(RoRClientEvents.GAME_CMD, packet.source, game_cmd) + + async def on_stream_register(self, packet: Packet) -> None: + """Handle stream_register packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.STREAM_REGISTER): + return + + stream = stream_register_factory(packet.data) + + logger.info('Stream register received: %s', stream) + + self.server.stream_manager.add_stream(stream) + + if stream.type is StreamType.ACTOR: + # why? + await self.server.reply_to_stream_register(stream, status=-1) + + self.emit(RoRClientEvents.STREAM_REGISTER, packet.source, stream) + + def on_stream_register_result(self, packet: Packet) -> None: + """Handle stream_register_result packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.STREAM_REGISTER_RESULT): + return + + stream = stream_register_factory(packet.data) + + logger.info('Stream register result received: %s', stream) + + self.emit( + RoRClientEvents.STREAM_REGISTER_RESULT, + packet.source, + stream, + ) + + def on_stream_data(self, packet: Packet) -> None: + """Handle stream_data packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.STREAM_DATA): + return + + try: + stream = self.server.stream_manager.get_stream( + packet.source, + packet.stream_id + ) + except UserNotFoundError: + logger.error( + 'Could not find user! uid=%d', + packet.source, + exc_info=True + ) + except StreamNotFoundError: + logger.error( + 'Could not find stream! sid=%d', + packet.stream_id, + exc_info=True + ) + else: + if stream.type in (StreamType.CHARACTER, StreamType.ACTOR): + logger.info('%s stream received', stream.type.name.title()) + stream_data = stream_data_factory(stream.type, packet.data) + + logger.debug('Stream data: %s', stream_data) + elif stream.type is StreamType.CHAT: + logger.info('Chat stream received') + stream_data = None + else: + raise ValueError(f'Unknown stream type: {stream.type!r}') + + self.emit( + RoRClientEvents.STREAM_DATA, + packet.source, + stream, + stream_data + ) + + def on_stream_unregister(self, packet: Packet) -> None: + """Handle stream_unregister packets. + + :param packet: The packet to handle. + """ + if not check_packet_type(packet, MessageType.STREAM_UNREGISTER): + return + + if len(packet.data) != 0: + raise ValueError('Stream unregister packet has data') + + logger.info('Stream unregister received: %s', packet) + + self.server.stream_manager.delete_stream( + packet.source, + packet.stream_id + ) + + self.emit( + RoRClientEvents.STREAM_UNREGISTER, + packet.source, + packet.stream_id + ) diff --git a/ror_server_bot/ror_bot/ror_connection.py b/ror_server_bot/ror_bot/ror_connection.py new file mode 100644 index 0000000..a3340f5 --- /dev/null +++ b/ror_server_bot/ror_bot/ror_connection.py @@ -0,0 +1,595 @@ +import asyncio +import hashlib +import logging +import math +import struct +import time +from datetime import datetime +from typing import Callable + +from pyee.asyncio import AsyncIOEventEmitter + +from ror_server_bot import RORNET_VERSION + +from .enums import ( + AuthLevels, + CharacterAnimation, + CharacterCommand, + MessageType, + StreamType, +) +from .models import ( + ActorStreamRegister, + CharacterStreamRegister, + ChatStreamRegister, + Packet, + ServerInfo, + StreamRegister, + UserInfo, +) +from .models.sendable import CharacterPositionStreamData +from .models.vector import Vector3 +from .stream_manager import StreamManager + +logger = logging.getLogger(__name__) + + +class PacketError(Exception): + """An error that occurs when a packet is malformed.""" + + +class UnexpectedCommandError(Exception): + """An error that occurs when a packet with an unexpected command is + received.""" + + +class RoRConnection(AsyncIOEventEmitter): + def __init__( + self, + username: str, + user_token: str, + password: str, + host: str, + port: int, + heartbeat_interval: float = 1.0 + ) -> None: + """Creates a new RoRConnection object. This object should be used + with the async with statement. + + For example: + ``` + >>> async with RoRConnection(...) as conn: + >>> await conn.send_chat('Hello World!') + ``` + + :param username: The username to connect with. + :param user_token: The user token to connect with. + :param password: The password to the server. + :param host: The IP address of the server. + :param port: The port the server is running on. + :param heartbeat_interval: The interval to send heartbeat + packets to the server, defaults to 1.0. + """ + super().__init__() + + self.add_listener('new_listener', self._new_listener) + self.add_listener('error', self._error) + + self._reader: asyncio.StreamReader + self._writer: asyncio.StreamWriter + self._writer_lock: asyncio.Lock + + self._reader_task: asyncio.Task + self._heartbeat_task: asyncio.Task + + self._task_group = asyncio.TaskGroup() + + self._connect_time: datetime + + self.stream_id = 10 # stream ids under 10 are reserved + self.net_quality = 0 + + self._host = host + self._port = port + self._password = hashlib.sha1(password.encode()).hexdigest().upper() + self._heartbeat_interval = heartbeat_interval + + self._is_connected = False + + self.server_info: ServerInfo + self.user_info = UserInfo( + auth_status=AuthLevels.BOT, + username=username, + user_token=user_token, + server_password=self._password, + language='en-US', + client_name='bot', + client_version='2022.12', + client_guid='', + session_type='bot', + session_options='', + ) + self.stream_manager = StreamManager() + + @property + def is_connected(self) -> bool: + """Gets if the client is connected to the server.""" + return self._is_connected + + @property + def connect_time(self) -> datetime: + """Gets the time the client connected to the server.""" + return self._connect_time + + @property + def address(self) -> str: + """Gets the address of the server.""" + return f'{self._host}:{self._port}' + + @property + def unique_id(self) -> int: + """Gets the unique id of the client.""" + return self.user_info.unique_id + + async def __aenter__(self) -> 'RoRConnection': + """Connects to the server. + + :return: The connected RoRConnection object. + """ + await self._task_group.__aenter__() + + logger.info('Connecting to %s', self.address) + + self._reader, self._writer = await asyncio.open_connection( + self._host, + self._port + ) + + self._writer_lock = asyncio.Lock() + + logger.info('Starting reader loop') + + self._reader_task = self._task_group.create_task( + self.__reader_loop(), + name='reader_loop' + ) + + hello_packet = await self.__send_hello() + self.server_info = ServerInfo.from_bytes(hello_packet.data) + + logger.info('Received Server Info: %s', self.server_info) + + welcome_packet = await self.__send_welcome() + self.user_info = UserInfo.from_bytes(welcome_packet.data) + + logger.info('Received User Info: %s', self.user_info) + self.stream_manager.add_user(self.user_info) + + await self.__register_streams() + + self._connect_time = datetime.now() + + self._is_connected = True + + logger.info('Starting heartbeat loop') + + self._heartbeat_task = self._task_group.create_task( + self.__heartbeat_loop(), + name='heartbeat_loop' + ) + + return self + + async def __aexit__(self, exc_type, exc, tb): + """Disconnects from the server. + + :param exc_type: The exception type. + :param exc: The exception. + :param tb: The traceback. + """ + logger.info('Disconnecting from %s', self.address) + + await self._send(Packet( + command=MessageType.USER_LEAVE, + source=self.unique_id, + stream_id=self.stream_id, + size=0, + )) + + await self._task_group.__aexit__(exc_type, exc, tb) + + if self._reader_task is not None: + self._reader_task.cancel() + + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + + async with self._writer_lock: + self._reader.feed_eof() + self._writer.close() + await self._writer.wait_closed() + + self._is_connected = False + + async def __send_hello(self) -> Packet: + logger.info('Sending Hello Message') + + hello_packet = Packet( + command=MessageType.HELLO, + source=0, # we do not have a unique id yet + stream_id=self.stream_id, + size=len(RORNET_VERSION), + data=RORNET_VERSION + ) + + future: asyncio.Future[Packet] = asyncio.Future() + + @self.once('packet') + async def hello(packet: Packet): + if packet.command is not MessageType.HELLO: + raise UnexpectedCommandError('Did not recieve hello response') + future.set_result(packet) + + await self._send(hello_packet) + + return await future + + async def __send_welcome(self) -> Packet: + logger.info('Sending User Info: %s', self.user_info) + + data = self.user_info.pack() + welcome_packet = Packet( + command=MessageType.USER_INFO, + source=self.unique_id, + stream_id=self.stream_id, + size=len(data), + data=data + ) + + future: asyncio.Future[Packet] = asyncio.Future() + + @self.once('packet') + async def welcome(packet: Packet): + if packet.command is MessageType.WELCOME: + future.set_result(packet) + elif packet.command is MessageType.SERVER_FULL: + raise ConnectionError('Server is full') + elif packet.command is MessageType.BANNED: + raise ConnectionError('RoR Client is banned') + elif packet.command is MessageType.WRONG_PASSWORD: + raise ConnectionError('Wrong password') + elif packet.command is MessageType.WRONG_VERSION: + raise ConnectionError('Wrong version') + else: + raise UnexpectedCommandError( + 'Invalid response: %r', + packet.command + ) + + await self._send(welcome_packet) + + return await future + + async def __register_streams(self): + chat_stream_reg = ChatStreamRegister( + type=StreamType.CHAT, + status=0, + origin_source_id=self.unique_id, + origin_stream_id=self.stream_id, + name='chat', + reg_data='0', + ) + logger.info('Sending Chat Stream Register: %s', chat_stream_reg) + + await self.register_stream(chat_stream_reg) + + char_stream_reg = CharacterStreamRegister( + type=StreamType.CHARACTER, + status=0, + origin_source_id=self.unique_id, + origin_stream_id=self.stream_id, + name='default', + reg_data=b'\x02', + ) + + logger.info('Sending Character Stream Register: %s', char_stream_reg) + + await self.register_stream(char_stream_reg) + + async def __reader_loop(self): + """The main reader loop. Reads packets from the server and emits + events. + + This function should not be called directly. + + This function will emit the following events when a packet is + received: + - `packet.*`: Emits for every packet received. + - `packet.`: Emits an event with the name of the + command from the packet received. For example, if a packet with + the command `MessageType.CHAT` is received, the event + `packet.CHAT` will be emitted. + """ + while True: + header = await self._reader.readexactly(Packet.calc_size()) + + logger.debug('[HEAD] %s', header) + + try: + packet = Packet.from_bytes(header) + except struct.error as e: + raise PacketError( + f'Failed to read packet header: {header}' + ) from e + + if ( + packet.command is not MessageType.STREAM_UNREGISTER + and packet.size == 0 + ): + raise PacketError(f'No data to read: {packet}') + + payload = await self._reader.read(packet.size) + + if len(payload) != packet.size: + logger.warning( + 'Packet size mismatch: data=%s packet=%s', + payload, + packet + ) + + logger.debug('[RECV] %s', payload) + + packet.data = payload + + logger.debug('[PCKT] %s', packet) + + # emit to packet wildcard + self.emit('packet', packet) + + # command event + self.emit('packet.' + packet.command.name, packet) + + await asyncio.sleep(0.01) + + async def __heartbeat_loop(self): + """The heartbeat loop. Sends a character position stream packet + to the server on a constant interval. This is done to prevent + the server from kicking the client for inactivity. + + This function should not be called directly. + """ + if not self.is_connected: + raise ConnectionError( + 'Cannot start heartbeat loop when not connected' + ) + + stream = CharacterPositionStreamData( + command=CharacterCommand.POSITION, + position=Vector3(), + rotation=0, + animation_time=self._heartbeat_interval, + animation_mode=CharacterAnimation.IDLE_SWAY, + ) + + packet = Packet( + command=MessageType.STREAM_DATA, + source=self.unique_id, + stream_id=self.stream_manager.get_character_sid(self.unique_id) + ) + + logger.info( + 'Sending character stream data every %f seconds. %s', + self._heartbeat_interval, + stream + ) + + start_time = time.time() + current_time = start_time + delta = 0 + while self._is_connected: + prev_time = current_time + current_time = time.time() + delta += current_time - prev_time + + if delta >= self._heartbeat_interval: + stream.animation_time = delta + delta = 0 + + if self._heartbeat_interval >= 1: + # avoid spamming logs + logger.info('Sending heartbeat character stream data.') + + data = stream.pack() + packet.data = data + packet.size = len(data) + + await self._send(packet) + + await asyncio.sleep(0.1) + + def _new_listener(self, event: str, listener: Callable): + """Handles new listener events. + + :param event: The event that was added. + :param listener: The listener that was added. + """ + logger.debug( + 'New listener added: event="%s" listener="%s"', + event, + listener.__name__ + ) + + def _error(self, error: Exception): + """Handles error events. + + :param error: The error that was emitted. + """ + logger.error('Error: %r', error, exc_info=True, stacklevel=2) + + async def _send(self, packet: Packet): + """Sends a packet to the server. + + :param packet: The packet to send. + """ + async with self._writer_lock: + data = packet.pack() + + logger.debug('[SEND] %s', data) + + self._writer.write(data) + + await self._writer.drain() + + async def register_stream(self, stream: StreamRegister) -> int: + """Registers a stream with the server as the client. + + :param stream: The stream being registered. + :return: The stream id of the stream. + """ + stream.origin_source_id = self.unique_id + stream.origin_stream_id = self.stream_id + + if isinstance(stream, ActorStreamRegister): + stream.timestamp = -1 + + stream_data = stream.pack() + packet = Packet( + command=MessageType.STREAM_REGISTER, + source=stream.origin_source_id, + stream_id=stream.origin_stream_id, + size=len(stream_data), + data=stream_data + ) + await self._send(packet) + self.stream_manager.add_stream(stream) + self.stream_id += 1 + + return stream.origin_stream_id + + async def unregister_stream(self, stream_id: int): + """Unregisters a stream with the server as the client. + + :param stream_id: The stream id of the stream to unregister. + """ + packet = Packet( + command=MessageType.STREAM_UNREGISTER, + source=self.unique_id, + stream_id=stream_id, + ) + await self._send(packet) + self.stream_manager.delete_stream(self.unique_id, stream_id) + + async def reply_to_stream_register( + self, + stream: StreamRegister, + status: int + ): + """Replies to a stream register request. + + :param stream: The stream to reply to. + :param status: The status to reply with. + """ + stream.status = status + data = stream.pack() + packet = Packet( + command=MessageType.STREAM_REGISTER_RESULT, + source=self.unique_id, + stream_id=stream.origin_stream_id, + size=len(data), + data=data + ) + await self._send(packet) + + async def send_chat(self, message: str): + """Sends a message to the game chat. + + :param message: The message to send. + """ + logger.info('[CHAT] message="%s"', message) + + data = message.encode() + + await self._send(Packet( + command=MessageType.CHAT, + source=self.unique_id, + stream_id=self.stream_manager.get_chat_sid(self.unique_id), + size=len(data), + data=data + )) + + async def send_private_chat(self, uid: int, message: str): + """Sends a private message to a user. + + :param uid: The uid of the user to send the message to. + :param message: The message to send. + """ + logger.info('[PRIV] to_uid=%d message="%s"', uid, message) + + data = struct.pack('I8000s', uid, message.encode()) + await self._send(Packet( + command=MessageType.PRIVATE_CHAT, + source=self.unique_id, + stream_id=self.stream_manager.get_chat_sid(self.unique_id), + size=len(data), + data=data + )) + + async def send_multiline_chat(self, message: str): + """Sends a multiline message to the game chat. + + :param message: The message to send. + """ + max_line_len = 100 + if len(message) > max_line_len: + logger.debug('[CHAT] multiline_message="%s"', message) + + total_lines = math.ceil(len(message) / max_line_len) + for i in range(total_lines): + line = message[max_line_len*i:max_line_len*(i+1)] + if i > 0: + line = f'| {line}' + await self.send_chat(line) + else: + await self.send_chat(message) + + async def kick(self, uid: int, reason: str = 'No reason given'): + """Kicks a user from the server. + + :param uid: The uid of the user to kick. + :param reason: The reason for kicking the user, defaults to + 'No reason given' + """ + await self.send_chat(f'!kick {uid} {reason}') + + async def ban(self, uid: int, reason: str = 'No reason given'): + """Bans a user from the server. + + :param uid: The uid of the user to ban. + :param reason: The reason for banning the user, defaults to + 'No reason given' + """ + await self.send_chat(f'!ban {uid} {reason}') + + async def say(self, uid: int, message: str): + """Send a message as a user anonymously. + + :param uid: The uid of the user to send the message to. If -1, + the message will be sent to everyone. + :param message: The message to send. + """ + await self.send_chat(f'!say {uid} {message}') + + async def send_game_cmd(self, command: str): + """Sends a game command (Angelscript) to the server. + + :param command: The command to send. + """ + logger.debug('[GAME_CMD] cmd="%s"', command) + data = command.encode() + await self._send(Packet( + command=MessageType.GAME_CMD, + source=self.unique_id, + stream_id=0, + size=len(data), + data=data + )) diff --git a/ror_server_bot/ror_bot/stream_manager.py b/ror_server_bot/ror_bot/stream_manager.py new file mode 100644 index 0000000..414054d --- /dev/null +++ b/ror_server_bot/ror_bot/stream_manager.py @@ -0,0 +1,286 @@ +import logging +from datetime import datetime +from itertools import chain + +from ror_server_bot import pformat + +from .enums import AuthLevels +from .models import GlobalStats, StreamRegister, UserInfo, Vector3, Vector4 +from .user import User + +logger = logging.getLogger(__name__) + + +class UserNotFoundError(Exception): + """Raised when a user is not found.""" + + +class StreamManager: + def __init__(self) -> None: + self.users: dict[int, User] = {} + self.global_stats = GlobalStats() + + @property + def user_count(self) -> int: + """Gets the number of users.""" + return len(self.users) - 1 # subtract 1 for the server client + + @property + def user_ids(self) -> list[int]: + """Gets the ids of the users.""" + return list(self.users.keys()) + + @property + def stream_ids(self) -> list[int]: + """Gets the ids of the streams.""" + return list(chain.from_iterable( + user.stream_ids for user in self.users.values() + )) + + def get_uid_by_username(self, username: str) -> int | None: + """Gets the uid of the user by their username. + + :param username: The username of the user. + :return: The uid of the user. + """ + for uid, user in self.users.items(): + if user.username == username: + return uid + return None + + def get_user(self, uid: int) -> User: + """Gets a user from the stream manager. + + :param uid: The uid of the user. + :return: The user. + """ + try: + return self.users[uid] + except KeyError as e: + raise UserNotFoundError(uid, pformat(self.users)) from e + + def add_user(self, user_info: UserInfo): + """Adds a client to the stream manager. + + :param user_info: The user info of the client to add. + """ + # update global stats if this is a new user + if user_info.unique_id not in self.users: + self.global_stats.add_user(user_info.username) + + # set the user to a new user if not already set + self.users.setdefault(user_info.unique_id, User(info=user_info)) + + # update the user info for the user + self.users[user_info.unique_id].info = user_info + + logger.info( + 'Added user %r uid=%d', + user_info.username, + user_info.unique_id + ) + + def delete_user(self, uid: int): + """Deletes a client from the stream manager. + + :param uid: The uid of the client to delete. + """ + user = self.users.pop(uid) + + self.global_stats.meters_driven += user.stats.meters_driven + self.global_stats.meters_sailed += user.stats.meters_sailed + self.global_stats.meters_walked += user.stats.meters_walked + self.global_stats.meters_flown += user.stats.meters_flown + + self.global_stats.connection_times.append( + datetime.now() - user.stats.online_since + ) + + logger.debug('Deleted user %r uid=%d', user.username, uid) + + def add_stream(self, stream: StreamRegister): + """Adds a stream to the stream manager. + + :param stream: The stream to add. + """ + self.get_user(stream.origin_source_id).add_stream(stream) + + def delete_stream(self, uid: int, sid: int): + """Deletes a stream from the stream manager. + + :param uid: The uid of the stream to delete. + :param sid: The sid of the stream to delete. + """ + self.get_user(uid).delete_stream(sid) + + def get_stream(self, uid: int, sid: int) -> StreamRegister: + """Gets a stream from the stream manager. + + :param uid: The uid of the stream. + :param sid: The sid of the stream. + :return: The stream. + """ + return self.get_user(uid).get_stream(sid) + + def get_current_stream(self, uid: int) -> StreamRegister: + """Gets the current stream of the user. + + :param uid: The uid of the user. + :return: The current stream of the user. + """ + return self.get_user(uid).get_current_stream() + + def set_current_stream(self, uid: int, actor_uid: int, sid: int): + """Sets the current stream of the user. + + :param uid: The uid of the user. + :param actor_uid: The uid of the actor. + :param sid: The sid of the stream. + """ + self.get_user(uid).set_current_stream(actor_uid, sid) + + def set_character_sid(self, uid: int, sid: int): + """Sets the character stream id of the user. + + :param uid: The uid of the user. + :param sid: The sid of the character stream. + """ + self.get_user(uid).character_stream_id = sid + + def get_character_sid(self, uid: int) -> int: + """Gets the character stream id of the user. + + :param uid: The uid of the user. + :return: The character stream id of the user. + """ + return self.get_user(uid).character_stream_id + + def set_chat_sid(self, uid: int, sid: int): + """Sets the chat stream id of the user. + + :param uid: The uid of the user. + :param sid: The sid of the chat stream. + """ + self.get_user(uid).chat_stream_id = sid + + def get_chat_sid(self, uid: int) -> int: + """Gets the chat stream id of the user. + + :param uid: The uid of the user. + :return: The chat stream id of the user. + """ + return self.get_user(uid).chat_stream_id + + def set_position(self, uid: int, sid: int, position: Vector3): + """Sets the position of the stream. + + :param uid: The uid of the user. + :param sid: The sid of the stream. + :param position: The position to set. + """ + self.get_user(uid).set_position(sid, position) + + def get_position(self, uid: int, sid: int = -1) -> Vector3: + """Gets the position of the stream. + + :param uid: The uid of the user. + :param sid: The sid of the stream, defaults to -1 + :return: The position of the stream. + """ + if sid == -1: + return self.get_current_stream(uid).position + else: + return self.get_user(uid).get_stream(sid).position + + def set_rotation(self, uid: int, sid: int, rotation: Vector4): + """Sets the rotation of the stream. + + :param uid: The uid of the user. + :param sid: The sid of the stream. + :param rotation: The rotation to set. + """ + self.get_user(uid).set_rotation(sid, rotation) + + def get_rotation(self, uid: int, sid: int = -1) -> Vector4: + """Gets the rotation of the stream. + + :param uid: The uid of the user. + :param sid: The sid of the stream, defaults to -1 + :return: The rotation of the stream. + """ + if sid == -1: + return self.get_current_stream(uid).rotation + else: + return self.get_user(uid).get_stream(sid).rotation + + def get_online_since(self, uid: int) -> datetime: + """Gets the online since of the user. + + :param uid: The uid of the user. + :return: The online since of the user. + """ + return self.get_user(uid).stats.online_since + + def total_streams(self, uid: int) -> int: + """Gets the total number of streams. + + :param uid: The uid of the user. + :return: The total number of streams. + """ + return self.get_user(uid).total_streams + + def get_username(self, uid: int) -> str: + """Gets the username of the user. + + :param uid: The uid of the user. + :return: The username of the user. + """ + return self.get_user(uid).username + + def get_username_colored(self, uid: int) -> str: + """Gets the username of the user with color. + + :param uid: The uid of the user. + :return: The username of the user with color. + """ + return self.get_user(uid).username_colored + + def get_language(self, uid: int) -> str: + """Gets the language of the user. + + :param uid: The uid of the user. + :return: The language of the user. + """ + return self.get_user(uid).language + + def get_client_name(self, uid: int) -> str: + """Gets the client name of the user. + + :param uid: The uid of the user. + :return: The client name of the user. + """ + return self.get_user(uid).client_name + + def get_client_version(self, uid: int) -> str: + """Gets the client version of the user. + + :param uid: The uid of the user. + :return: The client version of the user. + """ + return self.get_user(uid).client_version + + def get_client_guid(self, uid: int) -> str: + """Gets the client guid of the user. + + :param uid: The uid of the user. + :return: The client guid of the user. + """ + return self.get_user(uid).client_guid + + def get_auth_status(self, uid: int) -> AuthLevels: + """Gets the authentication status of the user. + + :param uid: The uid of the user. + :return: The authentication status of the user. + """ + return self.get_user(uid).auth_status diff --git a/ror_server_bot/ror_bot/user.py b/ror_server_bot/ror_bot/user.py new file mode 100644 index 0000000..a121e9d --- /dev/null +++ b/ror_server_bot/ror_bot/user.py @@ -0,0 +1,201 @@ +from pydantic import BaseModel, Field + +from ror_server_bot import TRUCK_TO_NAME_FILE + +from .enums import ActorType, AuthLevels, Color, StreamType +from .models import ( + StreamRegister, + TruckFile, + UserInfo, + UserStats, + Vector3, + Vector4, +) + + +class StreamNotFoundError(Exception): + """Raised when a stream is not found.""" + + +class CurrentStream(BaseModel): + unique_id: int = -1 + stream_id: int = -1 + + +class User(BaseModel): + info: UserInfo + streams: dict[int, StreamRegister] = Field(default={}) + stats: UserStats = UserStats() + character_stream_id: int = -1 + chat_stream_id: int = -1 + + current_stream: CurrentStream = CurrentStream() + + @property + def auth_status(self) -> AuthLevels: + """Get the authentication status of the user.""" + return self.info.auth_status + + @property + def username(self) -> str: + """Get the username of the user.""" + return self.info.username + + @property + def username_colored(self) -> str: + """Get the username formatted with the user's color.""" + return ( + f'{self.info.user_color}{self.info.username}{Color.WHITE.value}' + ) + + @property + def language(self) -> str: + """Get the language of the user.""" + return self.info.language + + @property + def client_name(self) -> str: + """Get the client name of the user.""" + return self.info.client_name + + @property + def client_version(self) -> str: + """Get the client version of the user.""" + return self.info.client_version + + @property + def client_guid(self) -> str: + """Get the client guid of the user.""" + return self.info.client_guid + + @property + def total_streams(self) -> int: + """Get the total number of streams the user has.""" + return len(self.streams) + + @property + def stream_ids(self) -> list[int]: + """Get the ids of the streams the user has.""" + return list(self.streams.keys()) + + def add_stream(self, stream: StreamRegister): + """Adds a stream to the user. + + :param stream: The stream to add. + """ + if stream.type is StreamType.ACTOR: + filename = TruckFile.from_filename( + TRUCK_TO_NAME_FILE, + stream.name + ) + stream.actor_type = filename.type + elif stream.type is StreamType.CHARACTER: + self.character_stream_id = stream.origin_stream_id + elif stream.type is StreamType.CHAT: + self.chat_stream_id = stream.origin_stream_id + self.streams[stream.origin_stream_id] = stream + + def delete_stream(self, stream_id: int): + """Deletes a stream from the user. + + :param stream_id: The stream id of the stream to delete. + """ + stream = self.streams.pop(stream_id) + + if stream.origin_stream_id == self.character_stream_id: + self.character_stream_id = -1 + elif stream.origin_stream_id == self.chat_stream_id: + self.chat_stream_id = -1 + + def get_stream(self, stream_id: int) -> StreamRegister: + """Gets a stream from the user. + + :param stream_id: The stream id of the stream to get. + :return: The stream. + """ + try: + return self.streams[stream_id] + except KeyError as e: + raise StreamNotFoundError(stream_id) from e + + def get_current_stream(self) -> StreamRegister: + """Gets the current stream of the user. + + :return: The current stream of the user. + """ + return self.streams[self.current_stream.stream_id] + + def set_current_stream(self, actor_uid: int, stream_id: int): + """Sets the current stream of the user. + + :param actor_uid: The uid of the actor. + :param sid: The sid of the stream. + """ + self.current_stream.unique_id = actor_uid + self.current_stream.stream_id = stream_id + + if ( + stream_id != self.character_stream_id + or self.info.unique_id != actor_uid + ): + self.set_position(self.character_stream_id, Vector3()) + + def set_position(self, sid: int, position: Vector3): + """Sets the position of the user. + + :param sid: The sid of the stream. + :param position: The position of the user. + """ + stream = self.streams[sid] + + if ( + (-1, -1, -1) < position < (1, 1, 1) + or (-1, -1, -1) < stream.position < (1, 1, 1) + ): + stream.position = position + + distance_meters = position.distance(stream.position) + if distance_meters < 10: + if stream.type is StreamType.CHARACTER: + self.stats.meters_walked += distance_meters + elif stream.type is StreamType.ACTOR: + if stream.actor_type in ( + ActorType.CAR, + ActorType.TRUCK, + ActorType.TRAIN + ): + self.stats.meters_driven += distance_meters + elif stream.actor_type is ActorType.BOAT: + self.stats.meters_sailed += distance_meters + elif stream.actor_type is ActorType.AIRPLANE: + self.stats.meters_flown += distance_meters + else: + stream.position = position + + def get_position(self, sid: int | None) -> Vector3: + """Gets the position of the user. + + :param sid: The sid of the stream. + :return: The position of the user. + """ + if sid is None: + return self.get_current_stream().position + return self.streams[sid].position + + def set_rotation(self, sid: int, rotation: Vector4): + """Sets the rotation of the user. + + :param sid: The sid of the stream. + :param rotation: The rotation of the user. + """ + self.streams[sid].rotation = rotation + + def get_rotation(self, sid: int | None) -> Vector4: + """Gets the rotation of the user. + + :param sid: The sid of the stream. + :return: The rotation of the user. + """ + if sid is None: + return self.get_current_stream().rotation + return self.streams[sid].rotation