diff --git a/config_sample/config.yaml b/config_sample/config.yaml index 34a77ebd..3f1038b5 100644 --- a/config_sample/config.yaml +++ b/config_sample/config.yaml @@ -174,3 +174,10 @@ need_webhook: title: "AO NEED CALL" message: "A user has used the /need command for players in the Attorney Online server!" url: + +# Additional list of proxies that are allowed to connect to the server. +# This is useful if you have a custom reverse proxy in front of the server. +# Localhost and cloudflare are allowed by default (see proxy_manager.py). +# authorized_proxies: +# - 1.1.1.1 +# - 2.2.2.2 \ No newline at end of file diff --git a/server/area.py b/server/area.py index c0df23f8..46ed56dc 100644 --- a/server/area.py +++ b/server/area.py @@ -1777,7 +1777,7 @@ def add_to_judgelog(self, client, msg): """ if len(self.judgelog) >= 10: self.judgelog = self.judgelog[1:] - self.judgelog.append(f"{client.char_name} ({client.ip}) {msg}.") + self.judgelog.append(f"{client.char_name} ({client.ipid}) {msg}.") def add_music_playing(self, client, name, showname="", autoplay=None): """ diff --git a/server/client_manager.py b/server/client_manager.py index 36f658bc..a927a45a 100644 --- a/server/client_manager.py +++ b/server/client_manager.py @@ -4,14 +4,21 @@ import math import os from heapq import heappop, heappush +import logging +from pathlib import Path +import json from server import database from server.constants import TargetType, encode_ao_packet, contains_URL, derelative from server.exceptions import ClientError, AreaError, ServerError +from server.network.aoprotocol_ws import AOProtocolWS import oyaml as yaml # ordered yaml -import json +import geoip2.database + + +logger = logging.getLogger(__name__) class ClientManager: @@ -49,6 +56,7 @@ def __init__(self, server, transport, user_id, ipid): self.pm_mute = False self.mod_call_time = 0 self.ipid = ipid + self.ip = "" self.version = "" self.software = "" @@ -1787,11 +1795,6 @@ def auth_mod(self, password): else: raise ClientError("Invalid password.") - @property - def ip(self): - """Get an anonymized version of the IP address.""" - return self.ipid - @property def char_name(self): """Get the name of the character that the client is using.""" @@ -2050,6 +2053,18 @@ def __init__(self, server): self.clients = set() self.server = server self.cur_id = [i for i in range(self.server.config["playerlimit"])] + self.ipRange_bans = [] + + try: + self.geoIpReader = geoip2.database.Reader( + "./storage/GeoLite2-ASN.mmdb") + self.useGeoIp = True + # if you're on debian and the geoip-database-extra package is installed + # you can use /usr/share/GeoIP/GeoIPASNum.dat instead + except FileNotFoundError: + self.useGeoIp = False + + self.load_ipranges() def new_client_preauth(self, client): maxclients = self.server.config["multiclient_limit"] @@ -2070,10 +2085,18 @@ def new_client(self, transport): transport.write(b"BD#This server is full.#%") raise ClientError - peername = transport.get_extra_info("peername")[0] + client_ip = self.get_client_ip(transport) + + # TODO: Should probably check if the IP is specifically banned here? + + if self.is_ip_rangebanned(client_ip): + msg = f"BD#Rangebanned IP: {client_ip}#%" + transport.write(msg.encode("utf-8")) + raise ClientError c = self.Client(self.server, transport, user_id, - database.ipid(peername)) + database.ipid(client_ip)) + c.ip = client_ip self.clients.add(c) temp_ipid = c.ipid for client in self.server.client_manager.clients: @@ -2226,7 +2249,66 @@ def get_multiclients(self, ipid=-1, hdid=""): def get_mods(self): return [c for c in self.clients if c.is_mod] - + + def get_client_ip(self, transport) -> str: + """Gets the real IP of the client.""" + if not isinstance(transport, AOProtocolWS.WSTransport): + # This means the client is connecting with TCP, so just return the IP + return transport.get_extra_info("peername")[0] + + # Using websockets, so use property in the websocket object + client_ip = transport.ws.remote_address[0] + if 'X-Forwarded-For' not in transport.ws.request_headers: + # Client doesn't claim to be behind a proxy, so all looks ok + return client_ip + + # This means the client claims to be behind a reverse proxy + # However, we can't trust this information and need to check the proxy IP against a whitelist + proxy_ip = client_ip + # X-Forwarded-For may contain a comma-delimited list of IPs, so get the first one + claimed_client_ip = transport.ws.request_headers['X-Forwarded-For'].split(',')[0].strip() + if not self.server.proxy_manager.is_ip_authorized_as_proxy(proxy_ip): + msg = f"Unauthorized proxy detected. Proxy IP: {proxy_ip}. Client IP: {claimed_client_ip}." + logging.warning( + msg, + proxy_ip, claimed_client_ip) + + ban_msg = f"BD#{msg}#%" + + transport.write(ban_msg.encode("utf-8")) + raise ClientError + + # The proxy is authorized, so we can trust the claimed client IP + return claimed_client_ip + + def is_ip_rangebanned(self, client_ip: str) -> bool: + if self.useGeoIp: + try: + geo_ip_response = self.geoIpReader.asn(client_ip) + asn = str(geo_ip_response.autonomous_system_number) + except geoip2.errors.AddressNotFoundError: + asn = "Loopback" + pass + else: + asn = "Loopback" + + for line, rangeBan in enumerate(self.ipRange_bans): + if rangeBan != "" and ((client_ip.startswith(rangeBan) and (rangeBan.endswith('.') or rangeBan.endswith(':'))) or asn == rangeBan): + return True + + return False + + def load_ipranges(self): + """Load a list of banned IP ranges.""" + path = Path("config/iprange_ban.txt") + + if not path.is_file(): + logger.debug("Cannot find iprange_ban.txt") + return + + with open("config/iprange_ban.txt", "r", encoding="utf-8") as f: + self.ipRange_bans.extend(f.read().splitlines()) + class BattleChar: def __init__(self, client, fighter_name, fighter): self.fighter = fighter_name diff --git a/server/database.py b/server/database.py index ab9e56b2..bc397749 100644 --- a/server/database.py +++ b/server/database.py @@ -472,7 +472,7 @@ def log_area(self, event_subtype, client, area, message=None, target=None): def log_connect(self, client, failed=False): """Log a connect attempt.""" logger.info( - f"{client.ipid} (HDID: {client.hdid}) " + f"(ID: {client.ipid}, HDID: {client.hdid}, IP: {client.ip}) " + f'{"was blocked from connecting" if failed else "connected"}.' ) with self.db as conn: diff --git a/server/logger.py b/server/logger.py index a2a0e083..a28ec37e 100644 --- a/server/logger.py +++ b/server/logger.py @@ -36,7 +36,7 @@ def parse_client_info(client): """Prepend information about a client to a log entry.""" if client is None: return "" - ipid = client.ip + ipid = client.ipid prefix = f"[{ipid:<15}][{client.id:<3}][{client.name}]" if client.is_mod: prefix += "[MOD]" diff --git a/server/network/aoprotocol.py b/server/network/aoprotocol.py index f18a4b43..18767171 100644 --- a/server/network/aoprotocol.py +++ b/server/network/aoprotocol.py @@ -1918,7 +1918,7 @@ def net_cmd_zz(self, args): "[{} UTC] {} ({}) in hub {} [{}]{} without reason (not using 2.6?)".format( current_time, self.client.char_name, - self.client.ip, + self.client.ipid, self.client.area.area_manager.name, self.client.area.abbreviation, self.client.area.name, @@ -1928,7 +1928,7 @@ def net_cmd_zz(self, args): self.client.set_mod_call_delay() database.log_area("modcall", self.client, self.client.area) self.server.webhooks.modcall( - char=self.client.char_name, ipid=self.client.ip, area=self.client.area + char=self.client.char_name, ipid=self.client.ipid, area=self.client.area ) else: self.server.send_all_cmd_pred( @@ -1936,7 +1936,7 @@ def net_cmd_zz(self, args): "[{} UTC] {} ({}) in hub {} [{}]{} with reason: {}".format( current_time, self.client.char_name, - self.client.ip, + self.client.ipid, self.client.area.area_manager.name, self.client.area.abbreviation, self.client.area.name, @@ -1949,7 +1949,7 @@ def net_cmd_zz(self, args): self.client.area, message=args[0]) self.server.webhooks.modcall( char=self.client.char_name, - ipid=self.client.ip, + ipid=self.client.ipid, area=self.client.area, reason=args[0][:100], ) diff --git a/server/network/aoprotocol_ws.py b/server/network/aoprotocol_ws.py index 99bbc084..e60fab0c 100644 --- a/server/network/aoprotocol_ws.py +++ b/server/network/aoprotocol_ws.py @@ -8,32 +8,15 @@ class AOProtocolWS(AOProtocol): """A websocket wrapper around AOProtocol.""" - class TransportWrapper: - """A class to wrap asyncio's Transport class.""" + class WSTransport(asyncio.Transport): + """A subclass of asyncio's Transport class to handle websocket connections.""" def __init__(self, websocket): + super().__init__() self.ws = websocket - def get_extra_info(self, key): - """Get extra info about the client. - Used for getting the remote address. - - :param key: requested key - - """ - remote_address = self.ws.remote_address - if (remote_address[0] == "127.0.0.1"): - # See if proxy - try: - remote_address = ( - self.ws.request_headers['X-Forwarded-For'], 0) - except Exception: - pass - info = {"peername": remote_address} - return info[key] - def write(self, message): - """Write message to the socket. + """Write message to the socket. Overrides asyncio.Transport.write. :param message: message in bytes @@ -42,7 +25,7 @@ def write(self, message): asyncio.ensure_future(self.ws_try_writing_message(message)) def close(self): - """Disconnect the client by force.""" + """Disconnect the client by force. Overrides asyncio.Transport.close.""" asyncio.ensure_future(self.ws.close()) async def ws_try_writing_message(self, message): @@ -64,7 +47,7 @@ def __init__(self, server, websocket): def ws_on_connect(self): """Handle a new client connection.""" - self.connection_made(self.TransportWrapper(self.ws)) + self.connection_made(self.WSTransport(self.ws)) async def ws_handle(self): try: diff --git a/server/network/proxy_manager.py b/server/network/proxy_manager.py new file mode 100644 index 00000000..05a4023a --- /dev/null +++ b/server/network/proxy_manager.py @@ -0,0 +1,73 @@ +import logging +import ipaddress + +import aiohttp + +logger = logging.getLogger(__name__) + + +# Proxy Manager, authorizes IPs that claim to be proxies. Implemented as singleton +class ProxyManager: + def __init__(self, server): + # IP addresses that are whitelisted for use as proxies + self.ip_whitelist = [] + self.server = server + + async def init(self): + # Localhost is always a trusted IP address + # This important for setups using CloudFlare tunnels + # See https://www.cloudflare.com/products/tunnel/ + self.ip_whitelist.append("127.0.0.1") + + cloudflare_ips = await self.get_cloudflare_ips() + self.ip_whitelist.extend(cloudflare_ips) + + if 'authorized_proxies' in self.server.config and \ + isinstance(self.server.config['authorized_proxies'], list): + self.ip_whitelist.extend(self.server.config['authorized_proxies']) + + def is_ip_authorized_as_proxy(self, ip: str) -> bool: + """ + Check if the specified IP address is authorized for use as a proxy. + """ + # Convert the given IP address to an ipaddress.IPv4Address or ipaddress.IPv6Address object + try: + ip_address = ipaddress.ip_address(ip) + except ValueError: + # The given IP is not a valid IP address + return False + + for entry in self.ip_whitelist: + try: + # Try to parse the entry as a CIDR block + network = ipaddress.ip_network(entry, strict=False) + # Check if the IP address is within the CIDR block + if ip_address in network: + logger.debug('IP address %s is approved for use as a proxy. Found CIDR match in %s', + ip, network) + return True + except ValueError: + try: + entry_ip_address = ipaddress.ip_address(entry) + except ValueError: + # The entry is not a valid IP address, skip it + continue + + # Check if the IP address matches the entry + if ip_address == entry_ip_address: + logger.debug('IP address %s is approved for use as a proxy. Found exact match in %s', + ip_address, entry_ip_address) + return True + + return False + + @staticmethod + async def get_cloudflare_ips() -> [str]: + async with aiohttp.ClientSession() as session: + async with session.get('https://www.cloudflare.com/ips-v4/#') as response: + if response.status == 200: + response_data = await response.text() + return response_data.splitlines() + else: + logger.error('Failed to get Cloudflare IPs: %s, %s', response.status, response.text) + return [] diff --git a/server/tsuserver.py b/server/tsuserver.py index 846d0699..ea80427e 100644 --- a/server/tsuserver.py +++ b/server/tsuserver.py @@ -5,7 +5,6 @@ import traceback import websockets -import geoip2.database import yaml import server.logger @@ -14,12 +13,13 @@ from server.client_manager import ClientManager from server.emotes import Emotes from server.discordbot import Bridgebot -from server.exceptions import ClientError, ServerError +from server.exceptions import ServerError from server.network.aoprotocol import AOProtocol from server.network.aoprotocol_ws import new_websocket_client from server.network.masterserverclient import MasterServerClient from server.network.webhooks import Webhooks from server.constants import remove_URL, dezalgo +from server.network.proxy_manager import ProxyManager logger = logging.getLogger("main") @@ -45,7 +45,6 @@ def __init__(self): self.backgrounds_categories = None self.server_links = None self.zalgo_tolerance = None - self.ipRange_bans = [] self.geoIpReader = None self.useGeoIp = False self.need_webhook = False @@ -69,15 +68,7 @@ def __init__(self): "y_offset", ] self.command_aliases = {} - - try: - self.geoIpReader = geoip2.database.Reader( - "./storage/GeoLite2-ASN.mmdb") - self.useGeoIp = True - # on debian systems you can use /usr/share/GeoIP/GeoIPASNum.dat if the geoip-database-extra package is installed - except FileNotFoundError: - self.useGeoIp = False - + self.proxy_manager = ProxyManager(self) self.ms_client = None sys.setrecursionlimit(50) try: @@ -89,7 +80,7 @@ def __init__(self): self.load_music() self.load_backgrounds() self.load_server_links() - self.load_ipranges() + self.client_manager = ClientManager(self) self.hub_manager = HubManager(self) except yaml.YAMLError: print("There was a syntax error parsing a configuration file:") @@ -106,7 +97,6 @@ def __init__(self): print("Please check sample config files for the correct format.") sys.exit(1) - self.client_manager = ClientManager(self) server.logger.setup_logging(debug=self.config["debug"]) self.webhooks = Webhooks(self) @@ -161,6 +151,8 @@ def start(self): asyncio.ensure_future(self.schedule_unbans()) + asyncio.ensure_future(self.proxy_manager.init(), loop=loop) + database.log_misc("start") print("Server started and is listening on port {}".format( self.config["port"])) @@ -194,28 +186,6 @@ def new_client(self, transport): :param transport: asyncio transport :returns: created client object """ - peername = transport.get_extra_info("peername")[0] - - if self.useGeoIp: - try: - geoIpResponse = self.geoIpReader.asn(peername) - asn = str(geoIpResponse.autonomous_system_number) - except geoip2.errors.AddressNotFoundError: - asn = "Loopback" - pass - else: - asn = "Loopback" - - for line, rangeBan in enumerate(self.ipRange_bans): - if rangeBan != "" and ((peername.startswith(rangeBan) and (rangeBan.endswith('.') or rangeBan.endswith(':'))) or asn == rangeBan): - msg = "BD#" - msg += "Abuse\r\n" - msg += f"ID: {line}\r\n" - msg += "Until: N/A" - msg += "#%" - - transport.write(msg.encode("utf-8")) - raise ClientError c = self.client_manager.new_client(transport) c.server = self @@ -353,14 +323,6 @@ def load_iniswaps(self): except Exception: logger.debug("Cannot find iniswaps.yaml") - def load_ipranges(self): - """Load a list of banned IP ranges.""" - try: - with open("config/iprange_ban.txt", "r", encoding="utf-8") as ipranges: - self.ipRange_bans = ipranges.read().splitlines() - except Exception: - logger.debug("Cannot find iprange_ban.txt") - def load_music_list(self): try: with open("config/music.yaml", "r", encoding="utf-8") as music: