Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjustments to work better with cloudflare and reverse proxies #110

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
406916b
Make IP property of client again and log it on connection
OmniTroid Nov 22, 2023
eb5b0d3
Add ProxyManager
OmniTroid Nov 22, 2023
9f319c7
Add 127.0.0.1 as whitelisted IP address
OmniTroid Nov 25, 2023
c98d5e4
Clarify whitelist comment
OmniTroid Nov 25, 2023
044de6c
Fix type confusion
OmniTroid Nov 25, 2023
15cea5c
Add clause if WS and behind proxy, check legitimacy
OmniTroid Nov 25, 2023
db3072c
Add debug stuff to proxymanager
OmniTroid Nov 25, 2023
487fc3b
Listen on wss port
OmniTroid Nov 25, 2023
db31ffc
Remove wss listening
OmniTroid Nov 26, 2023
d8ee313
Implement proxy manager as singleton
OmniTroid Nov 26, 2023
36d204b
Rework WS transport
OmniTroid Nov 26, 2023
6c5ce67
Fix instance issues with proxymanager and rename key function
OmniTroid Nov 26, 2023
f9fab05
Move everything related to rejecting connections to client manager
OmniTroid Nov 26, 2023
679a53f
Add authorized_proxies to config
OmniTroid Nov 26, 2023
b253adb
Make sure clientmanager loads ip ranges and put it in try block
OmniTroid Nov 26, 2023
b39b3a5
Remove get_extra_info overload and move x-forwarded-for into client m…
OmniTroid Nov 26, 2023
e79c31d
Load authorized proxies from config too
OmniTroid Nov 26, 2023
20476a4
Remove unauthorizedproxyexception
OmniTroid Nov 26, 2023
eb2952f
Move get client_ip and rangebans into functions
OmniTroid Nov 26, 2023
9b3b2d1
Remove proxymanager as singleton
OmniTroid Nov 26, 2023
4ac9f08
add comment about IP bans
OmniTroid Nov 26, 2023
9745a6a
Remove unused imports
OmniTroid Nov 26, 2023
da12c65
Handle case where X-Forwarded-For contains multiple IPs
OmniTroid Nov 28, 2023
ed6b7b2
Merge branch 'master' into rproxy-adjustments
OmniTroid Nov 28, 2023
ef907df
Merge branch 'master' into rproxy-adjustments
OmniTroid Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config_sample/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,10 @@ need_webhook:
title: "AO NEED CALL"
message: "A user has used the /need command for players in the Attorney Online server!"
url:

# Additional list of proxies that are allowed to connect to the server.
# This is useful if you have a custom reverse proxy in front of the server.
# Localhost and cloudflare are allowed by default (see proxy_manager.py).
# authorized_proxies:
# - 1.1.1.1
# - 2.2.2.2
2 changes: 1 addition & 1 deletion server/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,7 @@ def add_to_judgelog(self, client, msg):
"""
if len(self.judgelog) >= 10:
self.judgelog = self.judgelog[1:]
self.judgelog.append(f"{client.char_name} ({client.ip}) {msg}.")
self.judgelog.append(f"{client.char_name} ({client.ipid}) {msg}.")

def add_music_playing(self, client, name, showname="", autoplay=None):
"""
Expand Down
100 changes: 91 additions & 9 deletions server/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@
import math
import os
from heapq import heappop, heappush
import logging
from pathlib import Path
import json


from server import database
from server.constants import TargetType, encode_ao_packet, contains_URL, derelative
from server.exceptions import ClientError, AreaError, ServerError
from server.network.aoprotocol_ws import AOProtocolWS

import oyaml as yaml # ordered yaml
import json
import geoip2.database


logger = logging.getLogger(__name__)


class ClientManager:
Expand Down Expand Up @@ -49,6 +56,7 @@ def __init__(self, server, transport, user_id, ipid):
self.pm_mute = False
self.mod_call_time = 0
self.ipid = ipid
self.ip = ""
self.version = ""
self.software = ""

Expand Down Expand Up @@ -1787,11 +1795,6 @@ def auth_mod(self, password):
else:
raise ClientError("Invalid password.")

@property
def ip(self):
"""Get an anonymized version of the IP address."""
return self.ipid

@property
def char_name(self):
"""Get the name of the character that the client is using."""
Expand Down Expand Up @@ -2050,6 +2053,18 @@ def __init__(self, server):
self.clients = set()
self.server = server
self.cur_id = [i for i in range(self.server.config["playerlimit"])]
self.ipRange_bans = []

try:
self.geoIpReader = geoip2.database.Reader(
"./storage/GeoLite2-ASN.mmdb")
self.useGeoIp = True
# if you're on debian and the geoip-database-extra package is installed
# you can use /usr/share/GeoIP/GeoIPASNum.dat instead
except FileNotFoundError:
self.useGeoIp = False

self.load_ipranges()

def new_client_preauth(self, client):
maxclients = self.server.config["multiclient_limit"]
Expand All @@ -2070,10 +2085,18 @@ def new_client(self, transport):
transport.write(b"BD#This server is full.#%")
raise ClientError

peername = transport.get_extra_info("peername")[0]
client_ip = self.get_client_ip(transport)

# TODO: Should probably check if the IP is specifically banned here?

if self.is_ip_rangebanned(client_ip):
msg = f"BD#Rangebanned IP: {client_ip}#%"
transport.write(msg.encode("utf-8"))
raise ClientError

c = self.Client(self.server, transport, user_id,
database.ipid(peername))
database.ipid(client_ip))
c.ip = client_ip
self.clients.add(c)
temp_ipid = c.ipid
for client in self.server.client_manager.clients:
Expand Down Expand Up @@ -2226,7 +2249,66 @@ def get_multiclients(self, ipid=-1, hdid=""):

def get_mods(self):
return [c for c in self.clients if c.is_mod]


def get_client_ip(self, transport) -> str:
"""Gets the real IP of the client."""
if not isinstance(transport, AOProtocolWS.WSTransport):
# This means the client is connecting with TCP, so just return the IP
return transport.get_extra_info("peername")[0]

# Using websockets, so use property in the websocket object
client_ip = transport.ws.remote_address[0]
if 'X-Forwarded-For' not in transport.ws.request_headers:
# Client doesn't claim to be behind a proxy, so all looks ok
return client_ip

# This means the client claims to be behind a reverse proxy
# However, we can't trust this information and need to check the proxy IP against a whitelist
proxy_ip = client_ip
# X-Forwarded-For may contain a comma-delimited list of IPs, so get the first one
claimed_client_ip = transport.ws.request_headers['X-Forwarded-For'].split(',')[0].strip()
if not self.server.proxy_manager.is_ip_authorized_as_proxy(proxy_ip):
msg = f"Unauthorized proxy detected. Proxy IP: {proxy_ip}. Client IP: {claimed_client_ip}."
logging.warning(
msg,
proxy_ip, claimed_client_ip)

ban_msg = f"BD#{msg}#%"

transport.write(ban_msg.encode("utf-8"))
raise ClientError

# The proxy is authorized, so we can trust the claimed client IP
return claimed_client_ip

def is_ip_rangebanned(self, client_ip: str) -> bool:
if self.useGeoIp:
try:
geo_ip_response = self.geoIpReader.asn(client_ip)
asn = str(geo_ip_response.autonomous_system_number)
except geoip2.errors.AddressNotFoundError:
asn = "Loopback"
pass
else:
asn = "Loopback"

for line, rangeBan in enumerate(self.ipRange_bans):
if rangeBan != "" and ((client_ip.startswith(rangeBan) and (rangeBan.endswith('.') or rangeBan.endswith(':'))) or asn == rangeBan):
return True

return False

def load_ipranges(self):
"""Load a list of banned IP ranges."""
path = Path("config/iprange_ban.txt")

if not path.is_file():
logger.debug("Cannot find iprange_ban.txt")
return

with open("config/iprange_ban.txt", "r", encoding="utf-8") as f:
self.ipRange_bans.extend(f.read().splitlines())

class BattleChar:
def __init__(self, client, fighter_name, fighter):
self.fighter = fighter_name
Expand Down
2 changes: 1 addition & 1 deletion server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def log_area(self, event_subtype, client, area, message=None, target=None):
def log_connect(self, client, failed=False):
"""Log a connect attempt."""
logger.info(
f"{client.ipid} (HDID: {client.hdid}) "
f"(ID: {client.ipid}, HDID: {client.hdid}, IP: {client.ip}) "
+ f'{"was blocked from connecting" if failed else "connected"}.'
)
with self.db as conn:
Expand Down
2 changes: 1 addition & 1 deletion server/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_client_info(client):
"""Prepend information about a client to a log entry."""
if client is None:
return ""
ipid = client.ip
ipid = client.ipid
prefix = f"[{ipid:<15}][{client.id:<3}][{client.name}]"
if client.is_mod:
prefix += "[MOD]"
Expand Down
8 changes: 4 additions & 4 deletions server/network/aoprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,7 +1918,7 @@ def net_cmd_zz(self, args):
"[{} UTC] {} ({}) in hub {} [{}]{} without reason (not using 2.6?)".format(
current_time,
self.client.char_name,
self.client.ip,
self.client.ipid,
self.client.area.area_manager.name,
self.client.area.abbreviation,
self.client.area.name,
Expand All @@ -1928,15 +1928,15 @@ def net_cmd_zz(self, args):
self.client.set_mod_call_delay()
database.log_area("modcall", self.client, self.client.area)
self.server.webhooks.modcall(
char=self.client.char_name, ipid=self.client.ip, area=self.client.area
char=self.client.char_name, ipid=self.client.ipid, area=self.client.area
)
else:
self.server.send_all_cmd_pred(
"ZZ",
"[{} UTC] {} ({}) in hub {} [{}]{} with reason: {}".format(
current_time,
self.client.char_name,
self.client.ip,
self.client.ipid,
self.client.area.area_manager.name,
self.client.area.abbreviation,
self.client.area.name,
Expand All @@ -1949,7 +1949,7 @@ def net_cmd_zz(self, args):
self.client.area, message=args[0])
self.server.webhooks.modcall(
char=self.client.char_name,
ipid=self.client.ip,
ipid=self.client.ipid,
area=self.client.area,
reason=args[0][:100],
)
Expand Down
29 changes: 6 additions & 23 deletions server/network/aoprotocol_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,15 @@
class AOProtocolWS(AOProtocol):
"""A websocket wrapper around AOProtocol."""

class TransportWrapper:
"""A class to wrap asyncio's Transport class."""
class WSTransport(asyncio.Transport):
"""A subclass of asyncio's Transport class to handle websocket connections."""

def __init__(self, websocket):
super().__init__()
self.ws = websocket

def get_extra_info(self, key):
"""Get extra info about the client.
Used for getting the remote address.

:param key: requested key

"""
remote_address = self.ws.remote_address
if (remote_address[0] == "127.0.0.1"):
# See if proxy
try:
remote_address = (
self.ws.request_headers['X-Forwarded-For'], 0)
except Exception:
pass
info = {"peername": remote_address}
return info[key]

def write(self, message):
"""Write message to the socket.
"""Write message to the socket. Overrides asyncio.Transport.write.

:param message: message in bytes

Expand All @@ -42,7 +25,7 @@ def write(self, message):
asyncio.ensure_future(self.ws_try_writing_message(message))

def close(self):
"""Disconnect the client by force."""
"""Disconnect the client by force. Overrides asyncio.Transport.close."""
asyncio.ensure_future(self.ws.close())

async def ws_try_writing_message(self, message):
Expand All @@ -64,7 +47,7 @@ def __init__(self, server, websocket):

def ws_on_connect(self):
"""Handle a new client connection."""
self.connection_made(self.TransportWrapper(self.ws))
self.connection_made(self.WSTransport(self.ws))

async def ws_handle(self):
try:
Expand Down
73 changes: 73 additions & 0 deletions server/network/proxy_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import ipaddress

import aiohttp

logger = logging.getLogger(__name__)


# Proxy Manager, authorizes IPs that claim to be proxies. Implemented as singleton
class ProxyManager:
def __init__(self, server):
# IP addresses that are whitelisted for use as proxies
self.ip_whitelist = []
self.server = server

async def init(self):
# Localhost is always a trusted IP address
# This important for setups using CloudFlare tunnels
# See https://www.cloudflare.com/products/tunnel/
self.ip_whitelist.append("127.0.0.1")

cloudflare_ips = await self.get_cloudflare_ips()
self.ip_whitelist.extend(cloudflare_ips)

if 'authorized_proxies' in self.server.config and \
isinstance(self.server.config['authorized_proxies'], list):
self.ip_whitelist.extend(self.server.config['authorized_proxies'])

def is_ip_authorized_as_proxy(self, ip: str) -> bool:
"""
Check if the specified IP address is authorized for use as a proxy.
"""
# Convert the given IP address to an ipaddress.IPv4Address or ipaddress.IPv6Address object
try:
ip_address = ipaddress.ip_address(ip)
except ValueError:
# The given IP is not a valid IP address
return False

for entry in self.ip_whitelist:
try:
# Try to parse the entry as a CIDR block
network = ipaddress.ip_network(entry, strict=False)
# Check if the IP address is within the CIDR block
if ip_address in network:
logger.debug('IP address %s is approved for use as a proxy. Found CIDR match in %s',
ip, network)
return True
except ValueError:
try:
entry_ip_address = ipaddress.ip_address(entry)
except ValueError:
# The entry is not a valid IP address, skip it
continue

# Check if the IP address matches the entry
if ip_address == entry_ip_address:
logger.debug('IP address %s is approved for use as a proxy. Found exact match in %s',
ip_address, entry_ip_address)
return True

return False

@staticmethod
async def get_cloudflare_ips() -> [str]:
async with aiohttp.ClientSession() as session:
async with session.get('https://www.cloudflare.com/ips-v4/#') as response:
if response.status == 200:
response_data = await response.text()
return response_data.splitlines()
else:
logger.error('Failed to get Cloudflare IPs: %s, %s', response.status, response.text)
return []
Loading
Loading