From 63ca0f0c918aa4d9d2162453ed4d73a7e12f6fbb Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Mon, 15 Jan 2024 16:02:50 +1000 Subject: [PATCH 1/9] Start cleanup for 1.0.1, namespace update --- config.yaml | 22 +++-- lib/database.py | 16 ++-- lib/diameter.py | 63 +++++---------- lib/diameterAsync.py | 5 +- lib/logtool.py | 6 +- lib/messaging.py | 155 ++++++++++++++---------------------- lib/messagingAsync.py | 108 +++++++++++++++++-------- services/apiService.py | 2 +- services/diameterService.py | 8 +- services/georedService.py | 8 +- services/hssService.py | 8 +- services/logService.py | 6 +- services/metricService.py | 4 +- 13 files changed, 206 insertions(+), 205 deletions(-) diff --git a/config.yaml b/config.yaml index d83374da..5e053087 100644 --- a/config.yaml +++ b/config.yaml @@ -95,7 +95,6 @@ logging: diameter_logging_file: /var/log/pyhss_diameter.log geored_logging_file: /var/log/pyhss_geored.log metric_logging_file: /var/log/pyhss_metrics.log - log_to_terminal: True sqlalchemy_sql_echo: False sqlalchemy_pool_recycle: 15 sqlalchemy_pool_size: 30 @@ -113,7 +112,7 @@ database: webhooks: enabled: False endpoints: - - http://127.0.0.1:8181 + - 'http://127.0.0.1:8181' ## Geographic Redundancy Parameters geored: @@ -123,16 +122,23 @@ geored: - 'http://hss01.mnc001.mcc001.3gppnetwork.org:8080' - 'http://hss02.mnc001.mcc001.3gppnetwork.org:8080' -#Redis is required to run PyHSS. A locally running instance is recommended for production. +#Redis is required to run PyHSS. An instance running on a local network is recommended for production. redis: - # Whether to use a UNIX socket instead of a tcp connection to redis. Host and port is ignored if useUnixSocket is True. - useUnixSocket: False + # Which connection type to attempt. Valid options are: tcp, unix, sentinel + # tcp - Connection via a standard TCP socket to a given host and port. + # unix - Connect to redis via a unix socket, provided by unixSocketPath. + # sentinel - Connect to one or more redis sentinel hosts. 
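+  # For example, a sketch of a sentinel-based setup (all values below are placeholders): with
+  # connectionType: "sentinel", PyHSS is expected to resolve the current master named by
+  # sentinel.masterName (e.g. exampleMaster) via the sentinel hosts listed below, rather than
+  # connecting directly to a single redis host.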
+ connectionType: "tcp" unixSocketPath: '/var/run/redis/redis-server.sock' host: localhost port: 6379 - # [Deprecated] Additional peers to query for roaming SOS subscribers - additionalPeers: - - "redis2.mnc001.mcc001.3gppnetwork.org:6379" + sentinel: + masterName: exampleMaster + hosts: + - exampleSentinel.mnc001.mcc001.3gppnetwork.org: + port: 6379 + password: '' + prometheus: enabled: False diff --git a/lib/database.py b/lib/database.py index ee673fb5..013299a1 100755 --- a/lib/database.py +++ b/lib/database.py @@ -18,6 +18,7 @@ from messaging import RedisMessaging import yaml import json +import socket import traceback with open("../config.yaml", 'r') as stream: @@ -331,7 +332,8 @@ def __init__(self, logTool, redisMessaging=None): db_string = 'postgresql+psycopg2://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database']) else: db_string = 'mysql://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true") - + + self.hostname = socket.gethostname() self.engine = create_engine( db_string, @@ -351,7 +353,7 @@ def __init__(self, logTool, redisMessaging=None): #Load IMEI TAC database into Redis if enabled if ('tac_database_csv' in self.config['eir']): self.load_IMEI_database_into_Redis() - self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase")) + self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database')) else: self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging) self.tacData = {} @@ -388,7 +390,7 @@ def load_IMEI_database_into_Redis(self): if count == 0: self.logTool.log(service='Database', level='info', message="Checking to see if entries are already present...", redisClient=self.redisMessaging) - redis_imei_result = self.redisMessaging.getValue(key="tacDatabase") + redis_imei_result = self.redisMessaging.getValue(key="tacDatabase", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database') if redis_imei_result is not None: if len(redis_imei_result) > 0: self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging) @@ -396,7 +398,7 @@ def load_IMEI_database_into_Redis(self): self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging) tacList['tacList'].append({str(tacPrefix): {'name': name, 'model': model}}) count += 1 - self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList)) + self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database') self.tacData = tacList self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging) except Exception as E: @@ -919,14 +921,14 @@ def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False, georedDict['body'] = jsonData georedDict['operation'] = operation georedDict['timestamp'] = time.time_ns() - 
self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored') if asymmetric: if len(asymmetricUrls) > 0: georedDict['body'] = jsonData georedDict['operation'] = operation georedDict['timestamp'] = time.time_ns() georedDict['urls'] = asymmetricUrls - self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored') return True except Exception as E: @@ -954,7 +956,7 @@ def handleWebhook(self, objectData, operation: str="PATCH"): webhook['headers'] = webhookHeaders webhook['operation'] = operation webhook['timestamp'] = time.time_ns() - self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='webhook') return True def Sanitize_Datetime(self, result): diff --git a/lib/diameter.py b/lib/diameter.py index 684377c6..7c78e5bd 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -13,6 +13,7 @@ import yaml import json import time +import socket import traceback class Diameter: @@ -38,16 +39,7 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 else: self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) - """ - The below handling of additional peers is deprecated and will be replaced with redis sentinel in the next major refactor. 
- """ - self.redisPeerConnections = [] - if self.redisAdditionalPeers: - for additionalPeer in self.redisAdditionalPeers: - additionalPeerHost = additionalPeer.split(':')[0] - additionalPeerPort = additionalPeer.split(':')[1] - redisPeerConnection = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=False, unixSocketPath=self.redisUnixSocketPath) - self.redisPeerConnections.append({"peer": additionalPeer, "connection": Redis(host=additionalPeerHost, port=additionalPeerPort)}) + self.hostname = socket.gethostname() self.database = Database(logTool=logTool) self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) @@ -569,7 +561,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: if peerType not in peerTypes: return [] filteredConnectedPeers = [] - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter').decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': @@ -583,7 +575,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: def getPeerByHostname(self, hostname: str) -> dict: try: hostname = hostname.lower() - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter').decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': @@ -639,7 +631,7 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return request except Exception as e: @@ -672,7 +664,7 @@ def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Queueing for peer type: {peerType} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return connectedPeerList except 
Exception as e:
@@ -716,14 +708,14 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo
             sendTime = time.time_ns()
             outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}"
             outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime})
-            self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout)
+            self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
             self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging)
             startTimer = time.time()
             while True:
                 try:
                     if not time.time() >= startTimer + timeout:
                         if sessionId is None:
-                            queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound")
+                            queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                             self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages(NoSessionId): {queuedMessages}", redisClient=self.redisMessaging)
                             for queuedMessage in queuedMessages:
                                 queuedMessage = json.loads(queuedMessage)
@@ -740,7 +732,7 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo
                                     return messageHex
                             time.sleep(0.02)
                     else:
-                        queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound")
+                        queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                         self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages({sessionId}): {queuedMessages} responseType: {responseType}", redisClient=self.redisMessaging)
                         for queuedMessage in queuedMessages:
                             queuedMessage = json.loads(queuedMessage)
@@ -1050,8 +1042,8 @@ def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, gxSe
             if existingEmergencySubscriber:
                 self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getEmergencySubscriber] Found existing emergency subscriber to overwrite: {existingEmergencySubscriber}", redisClient=self.redisMessaging)
                 for key, value in existingEmergencySubscriber.items():
-                    self.redisMessaging.multiDeleteQueue(queue=f"emergencySubscriber:{value.get('ip')}:{value.get('imsi')}:{value.get('servingPgw')}", redisPeerConnections=self.redisPeerConnections)
-                    result = self.redisMessaging.multiSetValue(key=emergencySubscriberKey, value=json.dumps(subscriberData), keyExpiry=authExpiry, redisPeerConnections=self.redisPeerConnections)
+                    self.redisMessaging.deleteQueue(queue=f"emergencySubscriber:{value.get('ip')}:{value.get('imsi')}:{value.get('servingPgw')}", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
+            result = self.redisMessaging.setValue(key=emergencySubscriberKey, value=json.dumps(subscriberData), keyExpiry=authExpiry, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
             return True
         except Exception as e:
             self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getEmergencySubscriber] Error storing emergency subscriber in redis: 
{traceback.format_exc()}", redisClient=self.redisMessaging)
@@ -1070,13 +1062,13 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
                 return None
 
             if subscriberIp and subscriberImsi:
-                emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:*")
+                emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                 if emergencySubscriberKeyList:
-                    for matchedKey in emergencySubscriberKeyList:
-                        for peerName, keyName in matchedKey.items():
-                            if isinstance(keyName, list):
-                                keyName = keyName[0] if len(keyName) > 0 else ''
-                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName))
+                    for matchedKey in emergencySubscriberKeyList:
+                        for keyName in (matchedKey if isinstance(matchedKey, list) else [matchedKey]):
+                            # Keys returned by getQueues are already fully qualified, so they are read back as-is.
+                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName)
                             if not emergencySubscriberData:
                                 return None
                             emergencySubscriberData = json.loads(emergencySubscriberData)
@@ -1084,13 +1076,13 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
                 return emergencySubscriber
 
             if subscriberIp and not subscriberImsi:
-                emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:*")
+                emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:{subscriberIp}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                 if emergencySubscriberKeyList:
-                    for matchedKey in emergencySubscriberKeyList:
-                        for peerName, keyName in matchedKey.items():
-                            if isinstance(keyName, list):
-                                keyName = keyName[0] if len(keyName) > 0 else ''
-                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName))
+                    for matchedKey in emergencySubscriberKeyList:
+                        for keyName in (matchedKey if isinstance(matchedKey, list) else [matchedKey]):
+                            # Keys returned by getQueues are already fully qualified, so they are read back as-is.
+                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName)
                             if not emergencySubscriberData:
                                 return None
                             emergencySubscriberData = json.loads(emergencySubscriberData)
@@ -1098,13 +1090,13 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
                 return emergencySubscriber
 
             if subscriberImsi and not subscriberIp:
-                emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:{subscriberImsi}:*")
+                emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:*:{subscriberImsi}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                 if emergencySubscriberKeyList:
-                    for matchedKey in emergencySubscriberKeyList:
-                        for peerName, keyName in matchedKey.items():
-                            if isinstance(keyName, list):
-                                keyName = keyName[0] if len(keyName) > 0 else ''
-                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName))
+                    for matchedKey in emergencySubscriberKeyList:
+                        for keyName in (matchedKey if isinstance(matchedKey, list) else [matchedKey]):
+                            # Keys returned by getQueues are already fully qualified, so they are read back as-is.
+                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName)
                             if not emergencySubscriberData:
                                 return None
                             emergencySubscriberData = json.loads(emergencySubscriberData)
@@ -1112,13 +1104,13 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
                 return emergencySubscriber
 
             if gxSessionId:
-                emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:*:{gxSessionId}")
+                emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:*:*:{gxSessionId}", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')
                 if emergencySubscriberKeyList:
-                    for matchedKey in emergencySubscriberKeyList:
-                        for peerName, keyName in matchedKey.items():
-                            if isinstance(keyName, list):
-                                keyName = keyName[0] if len(keyName) > 0 else ''
-                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName))
+                    for matchedKey in emergencySubscriberKeyList:
+                        for keyName in (matchedKey if isinstance(matchedKey, list) else [matchedKey]):
+                            # Keys returned by getQueues are already fully qualified, so they are read back as-is.
+                            emergencySubscriberData = self.redisMessaging.getValue(key=keyName)
                             if not emergencySubscriberData:
                                 return None
                             emergencySubscriberData = json.loads(emergencySubscriberData)
@@ -1130,20 +1122,6 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
         except Exception as e:
             self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getEmergencySubscriber] Error getting emergency subscriber from redis: {traceback.format_exc()}", redisClient=self.redisMessaging)
             return None
-
-    def getRedisPeerConnection(self, peerName: str):
-        """
-        [Deprecated] Returns a redis peer connection given a peerName.
-        Returns None on failure.
-        """
-        try:
-            for peerConnection in self.redisPeerConnections:
-                if str(peerConnection.get('peer').lower()) == str(peerName.lower()):
-                    return peerConnection.get('connection')
-            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getRedisPeerConnection] No redis peers matched for: {peerName}", redisClient=self.redisMessaging)
-            return None
-        except Exception as e:
-            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getRedisPeerConnection] Error matching redis peer: {traceback.format_exc()}", redisClient=self.redisMessaging)
 
     def AVP_278_Origin_State_Incriment(self, avps):                                               #Capabilities Exchange Answer incriment AVP body
         for avp_dicts in avps:
@@ -1332,7 +1310,7 @@ def Answer_257(self, packet_vars, avps):
 
     #Device Watchdog Answer
     def Answer_280(self, packet_vars, avps):
         avp = ''                                                                                    #Initiate empty var AVP
         avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4))                                 #Result Code (DIAMETER_SUCCESS (2001))
         avp += self.generate_avp(264, 40, self.OriginHost)                                          #Origin Host
@@ -1342,8 +1320,6 @@ def Answer_280(self, packet_vars, avps):
         avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps))               #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment))
         response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
         self.logTool.log(service='HSS', level='debug', message="Successfully Generated DWA", redisClient=self.redisMessaging)
-        orignHost = self.get_avp_data(avps, 264)[0]                                                 #Get OriginHost from AVP
-        orignHost = binascii.unhexlify(orignHost).decode('utf-8')                                   #Format it
         return response
 
     #Disconnect Peer Answer
diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py
index f8c5be94..11240f48 100644
--- a/lib/diameterAsync.py
+++ b/lib/diameterAsync.py
@@ -2,6 +2,7 @@
 import math
 import asyncio
 import yaml
+import socket
 from messagingAsync import RedisMessagingAsync
 
 
@@ -41,7 +42,7 @@ def __init__(self, logTool):
         self.redisMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, 
useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
         self.logTool = logTool
-
+        self.hostname = socket.gethostname()
 
     #Generates rounding for calculating padding
     async def myRound(self, n, base=4):
@@ -246,7 +247,7 @@ async def getConnectedPeersByType(self, peerType: str) -> list:
             if peerType not in peerTypes:
                 return []
             filteredConnectedPeers = []
-            activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers"))
+            activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter'))
 
             for key, value in activePeers.items():
                 if activePeers.get(key, {}).get('peerType', '') == 'pgw' and activePeers.get(key, {}).get('connectionStatus', '') == 'connected':
diff --git a/lib/logtool.py b/lib/logtool.py
index 85061139..b5877aa3 100644
--- a/lib/logtool.py
+++ b/lib/logtool.py
@@ -1,6 +1,7 @@
 import logging
 import logging.handlers as handlers
 import os, sys, time
+import socket
 from datetime import datetime
 sys.path.append(os.path.realpath('../'))
 import asyncio
@@ -41,6 +42,7 @@ def __init__(self, config: dict):
 
         self.redisMessagingAsync = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
         self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath)
+        self.hostname = socket.gethostname()
 
     async def logAsync(self, service: str, level: str, message: str, redisClient=None) -> bool:
         """
@@ -55,7 +57,7 @@ async def logAsync(self, service: str, level: str, message: str, redisClient=Non
         timestamp = time.time()
         dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip()
         print(f"[{dateTimeString}] [{level.upper()}] {message}")
-        await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60))
+        await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='log'))
         return True
 
     def log(self, service: str, level: str, message: str, redisClient=None) -> bool:
@@ -71,7 +73,7 @@ def log(self, service: str, level: str, message: str, redisClient=None) -> bool:
         timestamp = time.time()
         dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip()
         print(f"[{dateTimeString}] [{level.upper()}] {message}")
-        redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)
+        redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='log')
         return True
 
     def setupFileLogger(self, loggerName: str, logFilePath: str):
diff --git a/lib/messaging.py b/lib/messaging.py
index 8a8e62d9..9656efe0 100644
--- a/lib/messaging.py
+++ b/lib/messaging.py
@@ -1,4 +1,4 @@
-from redis import Redis
+from redis import Redis, Sentinel
 import time, json, uuid, traceback
 
 class RedisMessaging:
@@ -7,17 +8,33 @@ class RedisMessaging:
     A class for sending and receiving redis messages.
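+
+    When usePrefix is enabled on a call, Keys and Queues are namespaced as
+    '<prefixHostname>:<prefixServiceName>:<key>' - for example (hostname assumed),
+    key 'diameter-inbound' becomes 'hss01:diameter:diameter-inbound'.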
""" - def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'): + def __init__(self, useTcp: bool=False, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock', useSentinel: bool=False, sentinelHosts: list=[]): if useUnixSocket: self.redisClient = Redis(unix_socket_path=unixSocketPath) + elif useSentinel: + sentinelList = [] + for host in sentinelHosts: + for key, value in host.items(): + sentinelList.append((key, int(host.get('port', 6379)))) else: self.redisClient = Redis(host=host, port=port) - def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + def handlePrefix(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Adds a prefix to the Key or Queue name, if enabled. + Returns the same Key or Queue if not enabled. + """ + if usePrefix: + return f"{prefixHostname}:{prefixServiceName}:{key}" + else: + return key + + def sendMessage(self, queue: str, message: str, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key). """ try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.rpush(queue, message) if queueExpiry is not None: self.redisClient.expire(queue, queueExpiry) @@ -25,7 +42,7 @@ def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: except Exception as e: return '' - def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: + def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a prometheus metric in a format readable by the metric service. """ @@ -44,35 +61,36 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA } ]) - metricQueueName = f"metric" + queue = self.handlePrefix(key='metric', usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) try: - self.redisClient.rpush(metricQueueName, prometheusMetricBody) + self.redisClient.rpush(queue, prometheusMetricBody) if metricExpiry is not None: - self.redisClient.expire(metricQueueName, metricExpiry) + self.redisClient.expire(queue, metricExpiry) return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' except Exception as e: return '' - def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key). 
""" try: - logQueueName = f"log" + queue = self.handlePrefix(key='log', usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) - self.redisClient.rpush(logQueueName, logMessage) + self.redisClient.rpush(queue, logMessage) if logExpiry is not None: - self.redisClient.expire(logQueueName, logExpiry) - return f'{message} stored in {logQueueName} successfully.' + self.redisClient.expire(queue, logExpiry) + return f'{message} stored in {queue} successfully.' except Exception as e: return '' - def getMessage(self, queue: str) -> str: + def getMessage(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. """ try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.lpop(queue) if message is None: message = '' @@ -85,101 +103,68 @@ def getMessage(self, queue: str) -> str: except Exception as e: return '' - def getQueues(self, pattern: str='*') -> list: + def getQueues(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Returns all Queues (Keys) in the database. """ try: + pattern = self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) allQueues = self.redisClient.scan_iter(match=pattern) return [x.decode() for x in allQueues] except Exception as e: return f"{traceback.format_exc()}" - def multiGetQueues(self, pattern: str='*', redisPeerConnections: list=[]) -> list: - try: - allQueues = [] - for redisPeerConnection in redisPeerConnections: - try: - peerName = redisPeerConnection.get('peer') - peerConnection = redisPeerConnection.get('connection') - - keys = [key.decode() for key in peerConnection.scan_iter(match=pattern)] - - allQueues.append({peerName: keys}) - except Exception as e: - continue - - localhost_keys = [key.decode() for key in self.redisClient.scan_iter(match=pattern)] - - allQueues.append({"localhost": localhost_keys}) - - return allQueues - except Exception as e: - return f"{traceback.format_exc()}" - - - def getNextQueue(self, pattern: str='*') -> dict: + def getNextQueue(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> dict: """ Returns the next Queue (Key) in the list. """ try: + pattern = self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) for nextQueue in self.redisClient.scan_iter(match=pattern): return nextQueue.decode() except Exception as e: return {} - def awaitMessage(self, key: str): + def awaitMessage(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Blocks until a message is received at the given key, then returns the message. 
""" try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.blpop(key) return tuple(data.decode() for data in message) except Exception as e: return '' - def awaitBulkMessage(self, key: str, count: int=100): + def awaitBulkMessage(self, key: str, count: int=100, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Blocks until one or more messages are received at the given key, then returns the amount of messages specified by count. """ try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count) return message except Exception as e: print(traceback.format_exc()) return '' - def deleteQueue(self, queue: str) -> bool: + def deleteQueue(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> bool: """ Deletes the given Queue (Key) """ try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.delete(queue) return True except Exception as e: return False - def multiDeleteQueue(self, queue: str, redisPeerConnections: list=[]) -> bool: - """ - Deletes the given Queue (Key) on each peer, including the default connection. - """ - try: - for redisPeerConnection in redisPeerConnections: - try: - peerConnection = redisPeerConnection.get('connection') - peerConnection.delete(queue) - except Exception as e: - continue - self.redisClient.delete(queue) - return True - except Exception as e: - return False - - def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + def setValue(self, key: str, value: str, keyExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a value under a given key and sets an expiry (in seconds) if provided. """ try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.set(key, value) if keyExpiry is not None: self.redisClient.expire(key, keyExpiry) @@ -187,61 +172,41 @@ def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: except Exception as e: return '' - def multiSetValue(self, key: str, value: str, keyExpiry: int=None, redisPeerConnections: list=[]) -> str: - """ - Stores a value under a given key and sets an expiry (in seconds) if provided, for each peer, including the default connection. - """ - try: - for redisPeerConnection in redisPeerConnections: - try: - peerConnection = redisPeerConnection.get('connection') - peerConnection.set(key, value) - if keyExpiry is not None: - peerConnection.expire(key, keyExpiry) - except Exception as e: - continue - self.redisClient.set(key, value) - if keyExpiry is not None: - self.redisClient.expire(key, keyExpiry) - return f'{value} stored in {key} successfully.' - except Exception as e: - return '' - - def getValue(self, key: str, redisClient=None) -> str: + def getValue(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the value stored under a given key. 
""" try: - if redisClient: - message = redisClient.get(key) - else: - message = self.redisClient.get(key) - if message is None: - message = '' - else: - return message + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) + message = self.redisClient.get(key) + if message is None: + message = '' + else: + return message except Exception as e: - return f"{traceback.format_exc()}" + return '' - def getList(self, key: str) -> list: + def getList(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Gets the list stored under a given key. """ try: - allResults = self.redisClient.lrange(key, 0, -1) - if allResults is None: - result = [] - else: - return [result.decode() for result in allResults] + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) + allResults = self.redisClient.lrange(key, 0, -1) + if allResults is None: + result = [] + else: + return [result.decode() for result in allResults] except Exception as e: return [] - def RedisHGetAll(self, key: str): + def RedisHGetAll(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Wrapper for Redis HGETALL *Deprecated: will be removed upon completed database cleanup. """ try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) data = self.redisClient.hgetall(key) return data except Exception as e: diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 6c33e0a6..4d56e7f0 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -1,5 +1,8 @@ import asyncio +import traceback +import socket import redis.asyncio as redis +from redis.asyncio.sentinel import Sentinel import time, json, uuid class RedisMessagingAsync: @@ -15,11 +18,22 @@ def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=Fa self.redisClient = redis.Redis(host=host, port=port) pass - async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + async def handlePrefix(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Adds a prefix to the Key or Queue name, if enabled. + Returns the same Key or Queue if not enabled. + """ + if usePrefix: + return f"{prefixHostname}:{prefixServiceName}:{key}" + else: + return key + + async def sendMessage(self, queue: str, message: str, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. """ try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) await(self.redisClient.rpush(queue, message)) if queueExpiry is not None: await(self.redisClient.expire(queue, queueExpiry)) @@ -27,11 +41,12 @@ async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> except Exception as e: return '' - async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None) -> str: + async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Empties a given asyncio queue into a redis pipeline, then sends to redis. 
""" try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) redisPipe = self.redisClient.pipeline() for message in messageList: @@ -46,7 +61,7 @@ async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int= except Exception as e: return '' - async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: + async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a prometheus metric in a format readable by the metric service, asynchronously. """ @@ -66,6 +81,7 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m ]) metricQueueName = f"metric" + metricQueueName = await(self.handlePrefix(key=metricQueueName, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) try: async with self.redisClient.pipeline(transaction=True) as redisPipe: @@ -77,12 +93,13 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m except Exception as e: return '' - async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. """ try: logQueueName = f"log" + logQueueName = await(self.handlePrefix(key=logQueueName, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) async with self.redisClient.pipeline(transaction=True) as redisPipe: await redisPipe.rpush(logQueueName, logMessage) @@ -93,42 +110,49 @@ async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: in except Exception as e: return '' - async def getMessage(self, queue: str) -> str: + async def getMessage(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. Deletes the key if the last message is being removed. 
""" try: - message = await(self.redisClient.lpop(queue)) - if message is None: - message = '' - else: - try: - if message[0] is None: - return '' - else: - message = message[0].decode() - except (UnicodeDecodeError, AttributeError): - pass - return message + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.lpop(queue)) + if message is None: + message = '' + else: + try: + if message[0] is None: + return '' + else: + message = message[0].decode() + except (UnicodeDecodeError, AttributeError): + pass + return message except Exception as e: return '' - async def getQueues(self, pattern: str='*') -> list: + async def getQueues(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Returns all Queues (Keys) in the database, asynchronously. """ try: - allQueuesBinary = await(self.redisClient.scan_iter(match=pattern)) - allQueues = [x.decode() for x in allQueuesBinary] - return allQueues + pattern = await(self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + allQueuesBinary = [] + async for nextQueue in self.redisClient.scan_iter(match=pattern): + if nextQueue: + allQueuesBinary.append(nextQueue) + allQueues = [x.decode() for x in allQueuesBinary] + return allQueues except Exception as e: + print(traceback.format_exc()) return [] - async def getNextQueue(self, pattern: str='*') -> str: + async def getNextQueue(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Returns the next Queue (Key) in the list, asynchronously. """ try: + pattern = await(self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) async for nextQueue in self.redisClient.scan_iter(match=pattern): if nextQueue is not None: return nextQueue.decode('utf-8') @@ -136,50 +160,66 @@ async def getNextQueue(self, pattern: str='*') -> str: print(e) return '' - async def awaitMessage(self, key: str): + async def awaitMessage(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Asynchronously blocks until a message is received at the given key, then returns the message. """ try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) message = (await(self.redisClient.blpop(key))) return tuple(data.decode() for data in message) except Exception as e: return '' - async def deleteQueue(self, queue: str) -> bool: + async def awaitBulkMessage(self, key: str, count: int=100, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Asynchronously blocks until one or more messages are received at the given key, then returns the amount of messages specified by count. + """ + try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count)) + return message + except Exception as e: + print(traceback.format_exc()) + return '' + + async def deleteQueue(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> bool: """ Deletes the given Queue (Key) asynchronously. 
""" try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) deleteQueueResult = await(self.redisClient.delete(queue)) return True except Exception as e: return False - async def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + async def setValue(self, key: str, value: str, keyExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a value under a given key asynchronously and sets an expiry (in seconds) if provided. """ try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) async with self.redisClient.pipeline(transaction=True) as redisPipe: await redisPipe.set(key, value) if keyExpiry is not None: - await redisPipe.expire(key, value) - setValueResult, expireValueResult = await redisPipe.execute() + await redisPipe.expire(key, keyExpiry) + setValueResult = await redisPipe.execute() return f'{value} stored in {key} successfully.' except Exception as e: - return '' + return traceback.format_exc() - async def getValue(self, key: str) -> str: + async def getValue(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the value stored under a given key asynchronously. """ try: - message = await(self.redisClient.get(key)) - if message is None: - message = '' - else: - return message + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.get(key)) + if message is None: + message = '' + else: + return message except Exception as e: return '' diff --git a/services/apiService.py b/services/apiService.py index 2949cadd..aca8d4cd 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -1302,7 +1302,7 @@ class PyHSS_OAM_Peers(Resource): def get(self): '''Get active Diameter Peers''' try: - diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers")) + diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers", usePrefix=True, prefixHostname=originHostname, prefixServiceName='diameter')) return diameterPeers, 200 except Exception as E: logTool.log(service='API', level='error', message=f"[API] An error occurred: {traceback.format_exc()}", redisClient=redisMessaging) diff --git a/services/diameterService.py b/services/diameterService.py index 04629b31..af10e654 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -2,6 +2,7 @@ import sys, os, json import time, yaml, uuid from datetime import datetime +import socket sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameterAsync import DiameterAsync @@ -41,6 +42,7 @@ def __init__(self): self.diameterRequests = 0 self.diameterResponses = 0 self.workerPoolSize = int(self.config.get('hss', {}).get('diameter_service_workers', 10)) + self.hostname = socket.gethostname() async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool: """ @@ -86,7 +88,7 @@ async def handleActiveDiameterPeers(self): del self.activePeers[key] await(self.logActivePeers()) - await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) + await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400, 
usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')) await(asyncio.sleep(1)) except Exception as e: @@ -174,7 +176,7 @@ async def inboundDataWorker(self, coroutineUuid: str) -> bool: break if messageList: - await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout) + await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') messageList = [] except Exception as e: @@ -189,7 +191,7 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s while not writer.transport.is_closing(): try: await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Waiting for messages for host {clientAddress} on port {clientPort}")) - pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}")))[1]) + pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')))[1]) diameterOutboundBinary = bytes.fromhex(pendingOutboundMessage.get('diameter-outbound', '')) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.")) diff --git a/services/georedService.py b/services/georedService.py index 12f16802..b0cb7ae5 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -1,6 +1,7 @@ import os, sys, json, yaml import uuid, time import asyncio, aiohttp +import socket sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from banners import Banners @@ -32,6 +33,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.georedPeers = self.config.get('geored', {}).get('endpoints', []) self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', []) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + self.hostname = socket.gethostname() if not self.config.get('geored', {}).get('enabled'): self.logger.error("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.") @@ -255,7 +257,7 @@ async def handleAsymmetricGeoredQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1]) + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}")) georedOperation = georedMessage['operation'] @@ -286,7 +288,7 @@ async def handleGeoredQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored')))[1]) + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] 
[handleGeoredQueue] Message: {georedMessage}")) georedOperation = georedMessage['operation'] @@ -316,7 +318,7 @@ async def handleWebhookQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook')))[1]) + webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='webhook')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) diff --git a/services/hssService.py b/services/hssService.py index 46abcbca..08393446 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml, time, traceback +import os, sys, json, yaml, time, traceback, socket sys.path.append(os.path.realpath('../lib')) from messaging import RedisMessaging from diameter import Diameter @@ -30,6 +30,8 @@ def __init__(self): self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) self.diameterLibrary = Diameter(logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + self.hostname = socket.gethostname() + def handleQueue(self): """ @@ -40,7 +42,7 @@ def handleQueue(self): if self.benchmarking: startTime = time.perf_counter() - inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound') + inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') if inboundMessageList == None: continue @@ -86,7 +88,7 @@ def handleQueue(self): self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') if self.benchmarking: self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] [{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) diff --git a/services/logService.py b/services/logService.py index 34e7ae08..827662ae 100644 --- a/services/logService.py +++ b/services/logService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml +import os, sys, json, yaml, socket from datetime import datetime import time import logging @@ -37,6 +37,8 @@ def __init__(self): 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, } + self.hostname = socket.gethostname() + print(f"{self.banners.logService()}") def handleLogs(self): @@ -46,7 +48,7 @@ def handleLogs(self): activeLoggers = {} while True: try: - logMessage = json.loads(self.redisMessaging.awaitMessage(key='log')[1]) + logMessage = json.loads(self.redisMessaging.awaitMessage(key='log', usePrefix=True, 
prefixHostname=self.hostname, prefixServiceName='log')[1]) print(f"[Log] Message: {logMessage}") diff --git a/services/metricService.py b/services/metricService.py index 12d51c14..968ddc56 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -1,6 +1,7 @@ import asyncio import sys, os, json import time, json, yaml +import socket from prometheus_client import make_wsgi_app, start_http_server, Counter, Gauge, Summary, Histogram, CollectorRegistry from werkzeug.middleware.dispatcher import DispatcherMiddleware from flask import Flask @@ -25,6 +26,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.logTool = LogTool(config=self.config) self.registry = CollectorRegistry(auto_describe=True) self.logTool.log(service='Metric', level='info', message=f"{self.banners.metricService()}", redisClient=self.redisMessaging) + self.hostname = socket.gethostname() def handleMetrics(self): """ @@ -34,7 +36,7 @@ def handleMetrics(self): actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} prometheusTypes = {'counter': Counter, 'gauge': Gauge, 'histogram': Histogram, 'summary': Summary} - metric = self.redisMessaging.awaitMessage(key='metric')[1] + metric = self.redisMessaging.awaitMessage(key='metric', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='metric')[1] self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) prometheusJsonList = json.loads(metric) From 50eee27668c43d45dca77b3eacbd606a737a7239 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 16 Jan 2024 10:18:00 +1000 Subject: [PATCH 2/9] Pull in RAR fixes --- config.yaml | 10 ++- lib/database.py | 3 +- lib/diameter.py | 152 ++++++++++++++++++++++++++++++++---- lib/messaging.py | 10 +-- lib/messagingAsync.py | 1 - services/diameterService.py | 24 +++++- services/georedService.py | 7 +- 7 files changed, 176 insertions(+), 31 deletions(-) diff --git a/config.yaml b/config.yaml index 5e053087..13f35c3e 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,7 @@ ## HSS Parameters -hss: +hss: + # Transport Type. "TCP" and "SCTP" are valid options. + # Note: SCTP works but is still experimental. TCP has been load-tested and performs in a production environment. transport: "TCP" #IP Addresses to bind on (List) - For TCP only the first IP is used, for SCTP all used for Transport (Multihomed). bind_ip: ["0.0.0.0"] @@ -70,6 +72,12 @@ hss: # Whether or not to a subscriber to connect to an undefined network when outbound roaming. 
allow_undefined_networks: True + # SCTP Socket Parameters + sctp: + rtoMax: 5000 + rtoMin: 500 + rtoInitial: 1000 + api: page_size: 200 diff --git a/lib/database.py b/lib/database.py index 013299a1..6d256867 100755 --- a/lib/database.py +++ b/lib/database.py @@ -1455,8 +1455,7 @@ def Get_Served_IMS_Subscribers(self, get_local_users_only=False): IMS_SUBSCRIBER.scscf.isnot(None)) for result in results: result = result.__dict__ - self.logTool.log(service='Database', level='debug', message="Result: " + str(result, redisClient=self.redisMessaging) + - " type: " + str(type(result))) + self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging) result = self.Sanitize_Datetime(result) result.pop('_sa_instance_state') if get_local_users_only == True: diff --git a/lib/diameter.py b/lib/diameter.py index 7c78e5bd..66fe8c0f 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -15,6 +15,7 @@ import time import socket import traceback +import re class Diameter: @@ -155,7 +156,6 @@ def Reverse(self, str): return (slicedString) def DecodePLMN(self, plmn): - self.logTool.log(service='HSS', level='debug', message="Decoding PLMN: " + str(plmn), redisClient=self.redisMessaging) if "f" in plmn: mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') @@ -168,6 +168,7 @@ def DecodePLMN(self, plmn): return mcc, mnc def EncodePLMN(self, mcc, mnc): + plmn = list('XXXXXX') if len(mnc) == 2: plmn[0] = self.Reverse(mcc)[1] plmn[1] = self.Reverse(mcc)[2] @@ -573,6 +574,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: return [] def getPeerByHostname(self, hostname: str) -> dict: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getPeerByHostname] Looking for peer with hostname {hostname}", redisClient=self.redisMessaging) try: hostname = hostname.lower() activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter').decode()) @@ -582,6 +584,7 @@ def getPeerByHostname(self, hostname: str) -> dict: return(activePeers.get(key, {})) except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getPeerByHostname] Failed to find peer with hostname {hostname}", redisClient=self.redisMessaging) return {} def getDiameterMessageType(self, binaryData: str) -> dict: @@ -626,8 +629,12 @@ def sendDiameterRequest(self, requestType: str, hostname: str, **kwargs) -> str: peerPort = connectedPeer['port'] except Exception as e: return '' - request = diameterApplication["requestMethod"](**kwargs) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) @@ -659,7 +666,12 @@ def 
broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> peerPort = connectedPeer['port'] except Exception as e: return '' - request = diameterApplication["requestMethod"](**kwargs) + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() @@ -696,12 +708,20 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo except Exception as e: continue connectedPeer = self.getPeerByHostname(hostname=hostname) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Sending request via connected peer {connectedPeer} from hostname {hostname}", redisClient=self.redisMessaging) try: peerIp = connectedPeer['ipAddress'] peerPort = connectedPeer['port'] except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Could not get connection information for connectedPeer: {connectedPeer}", redisClient=self.redisMessaging) + return '' + + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) return '' - request = diameterApplication["requestMethod"](**kwargs) responseType = diameterApplication["responseAcronym"] sessionId = kwargs.get('sessionId', None) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) @@ -782,8 +802,12 @@ def generateDiameterResponse(self, binaryData: str) -> str: if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Attempting to generate response", redisClient=self.redisMessaging) - response = diameterApplication["responseMethod"](packet_vars, avps) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging) + try: + response = diameterApplication["responseMethod"](packet_vars, avps) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully 
generated response: {response}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Error generating response: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' break except Exception as e: continue @@ -1158,7 +1182,7 @@ def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleNa #Populate all Flow Information AVPs Flow_Information = '' for tft in ChargingRules['tft']: - self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Adding TFT: " + str(tft), redisClient=self.redisMessaging) #If {{ UE_IP }} in TFT splice in the real UE IP Value try: tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) @@ -2777,9 +2801,30 @@ def Answer_16777236_265(self, packet_vars, avps): msisdn = imsSubscriberDetails.get('msisdn', None) except Exception as e: pass + if identifier == None: + try: + ueIP = subscriptionId.split('@')[1].split(':')[0] + ue = self.database.Get_UE_by_IP(ueIP) + subscriberId = ue.get('subscriber_id', None) + subscriberDetails = self.database.Get_Subscriber(subscriber_id=subscriberId) + imsi = subscriberDetails.get('imsi', None) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Found IMSI {imsi} by IP: {ueIP}", redisClient=self.redisMessaging) + except Exception as e: + pass else: imsi = None msisdn = None + try: + ueIP = subscriptionId.split(':')[0] + ue = self.database.Get_UE_by_IP(ueIP) + subscriberId = ue.get('subscriber_id', None) + subscriberDetails = self.database.Get_Subscriber(subscriber_id=subscriberId) + imsi = subscriberDetails.get('imsi', None) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Found IMSI {imsi} by IP: {ueIP}", redisClient=self.redisMessaging) + except Exception as e: + pass + + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) @@ -2857,6 +2902,81 @@ def Answer_16777236_265(self, packet_vars, avps): except Exception as e: pass + #Extract the SDP for each direction to find the source and destination IP Addresses and Ports used for the RTP streams + try: + sdp1 = self.get_avp_data(avps, 524)[0] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got first SDP body raw: " + str(sdp1), redisClient=self.redisMessaging) + sdp1 = binascii.unhexlify(sdp1).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got first SDP body decoded: " + str(sdp1), redisClient=self.redisMessaging) + sdp2 = self.get_avp_data(avps, 524)[1] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got second SDP body raw: " + str(sdp2), redisClient=self.redisMessaging) + sdp2 = binascii.unhexlify(sdp2).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got second SDP body decoded: " + str(sdp2), redisClient=self.redisMessaging) + + regex_ipv4 = r"IN IP4 (\d*\.\d*\.\d*\.\d*)" + regex_ipv6 = r"IN IP6 ([0-9a-fA-F:]{3,39})" + regex_port_audio = r"m=audio (\d*)" + regex_port_rtcp = 
r"a=rtcp:(\d*)" + + #Check for IPv4 Matches in first SDP Body + matches_ipv4 = re.search(regex_ipv4, sdp1, re.MULTILINE) + if matches_ipv4: + sdp1_ipv4 = str(matches_ipv4.group(1)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP IPv4" + str(sdp1_ipv4), redisClient=self.redisMessaging) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv4 in SDP", redisClient=self.redisMessaging) + if not matches_ipv4: + #Check for IPv6 Matches + matches_ipv6 = re.search(regex_ipv6, sdp1, re.MULTILINE) + if matches_ipv6: + sdp1_ipv6 = str(matches_ipv6.group(1)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP IPv6" + str(sdp1_ipv6), redisClient=self.redisMessaging) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv6 in SDP", redisClient=self.redisMessaging) + + + #Check for IPv4 Matches in second SDP Body + matches_ipv4 = re.search(regex_ipv4, sdp2, re.MULTILINE) + if matches_ipv4: + sdp2_ipv4 = str(matches_ipv4.group(1)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 IPv4 " + str(sdp2_ipv4), redisClient=self.redisMessaging) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv4 in SDP2", redisClient=self.redisMessaging) + if not matches_ipv4: + #Check for IPv6 Matches + matches_ipv6 = re.search(regex_ipv6, sdp2, re.MULTILINE) + if matches_ipv6: + sdp2_ipv6 = str(matches_ipv6.group(1)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 IPv6 " + str(sdp2_ipv6), redisClient=self.redisMessaging) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv6 in SDP", redisClient=self.redisMessaging) + + #Extract RTP Port + matches_rtp_port = re.search(regex_port_audio, sdp2, re.MULTILINE) + if matches_rtp_port: + sdp2_rtp_port = str(matches_rtp_port.group(1)) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 RTP Port " + str(sdp2_rtp_port), redisClient=self.redisMessaging) + + #Extract RTP Port + matches_rtp_port = re.search(regex_port_audio, sdp1, re.MULTILINE) + if matches_rtp_port: + sdp1_rtp_port = str(matches_rtp_port.group(1)) + sdp1_rtcp_port = int(sdp1_rtp_port) - 1 + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP1 RTP Port " + str(sdp1_rtp_port), redisClient=self.redisMessaging) + + + #Extract RTP Port + matches_rtp_port = re.search(regex_port_audio, sdp2, re.MULTILINE) + if matches_rtp_port: + sdp2_rtp_port = str(matches_rtp_port.group(1)) + sdp2_rtcp_port = int(sdp2_rtp_port) - 1 + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 RTP Port " + str(sdp2_rtp_port), redisClient=self.redisMessaging) + + + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Failed to extract SDP due to error" + str(e), redisClient=self.redisMessaging) + + """ The below logic is applied: 1. 
Grab the Flow Rules and bitrates from the PCSCF in the AAR, @@ -2869,14 +2989,14 @@ def Answer_16777236_265(self, packet_vars, avps): chargingRule = { "charging_rule_id": 1000, "qci": 1, - "arp_preemption_capability": True, + "arp_preemption_capability": False, "mbr_dl": dlBandwidth, "mbr_ul": ulBandwidth, "gbr_ul": ulBandwidth, - "precedence": 100, - "arp_priority": 2, + "precedence": 40, + "arp_priority": 15, "rule_name": "GBR-Voice", - "arp_preemption_vulnerability": False, + "arp_preemption_vulnerability": True, "gbr_dl": dlBandwidth, "tft_group_id": 1, "rating_group": None, @@ -2885,20 +3005,20 @@ def Answer_16777236_265(self, packet_vars, avps): "tft_group_id": 1, "direction": 1, "tft_id": 1, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" - }, + "tft_string": "permit out 17 from " + str(sdp2_ipv4) + "/32 " + str(sdp2_rtcp_port) + "-" + str(sdp2_rtp_port) + " to " + str(ueIp) + "/32 " + str(sdp1_rtcp_port) + "-" + str(sdp1_rtp_port) }, { "tft_group_id": 1, "direction": 2, "tft_id": 2, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" - } + "tft_string": "permit out 17 from " + str(sdp2_ipv4) + "/32 " + str(sdp2_rtcp_port) + "-" + str(sdp2_rtp_port) + " to " + str(ueIp) + "/32 " + str(sdp1_rtcp_port) + "-" + str(sdp1_rtp_port) } ] } if not emergencySubscriber: self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId) + + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAR Generated to be sent to serving PGW: {servingPgw} via peer {servingPgwPeer}", redisClient=self.redisMessaging) reAuthAnswer = self.awaitDiameterRequestAndResponse( requestType='RAR', hostname=servingPgwPeer, diff --git a/lib/messaging.py b/lib/messaging.py index 9656efe0..55ada5c4 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -1,5 +1,4 @@ -from unittest.mock import sentinel -from redis import Redis, Sentinel +from redis import Redis import time, json, uuid, traceback class RedisMessaging: @@ -8,14 +7,9 @@ class RedisMessaging: A class for sending and receiving redis messages. 
""" - def __init__(self, useTcp: bool=False, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock', useSentinel: bool=False, sentinelHosts: list=[]): + def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=False, unixSocketPath: str='/var/run/redis/redis-server.sock'): if useUnixSocket: self.redisClient = Redis(unix_socket_path=unixSocketPath) - elif useSentinel: - sentinelList = [] - for host in sentinelHosts: - for key, value in host.items(): - sentinelList.append((key, int(host.get('port', 6379)))) else: self.redisClient = Redis(host=host, port=port) diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 4d56e7f0..98bfc1b0 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -2,7 +2,6 @@ import traceback import socket import redis.asyncio as redis -from redis.asyncio.sentinel import Sentinel import time, json, uuid class RedisMessagingAsync: diff --git a/services/diameterService.py b/services/diameterService.py index af10e654..0555097b 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -2,7 +2,7 @@ import sys, os, json import time, yaml, uuid from datetime import datetime -import socket +import sctp, socket sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameterAsync import DiameterAsync @@ -288,6 +288,28 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): if type.upper() == 'TCP': server = await(asyncio.start_server(self.handleConnection, host, port)) + elif type.upper() == 'SCTP': + self.sctpSocket = sctp.sctpsocket_tcp(socket.AF_INET) + self.sctpSocket.setblocking(False) + self.sctpSocket.events.clear() + self.sctpSocket.bind((host, port)) + self.sctpRtoInfo = self.sctpSocket.get_rtoinfo() + self.sctpRtoMin = self.config.get('hss', {}).get('sctp', {}).get('rtoMin', 500) + self.sctpRtoMax = self.config.get('hss', {}).get('sctp', {}).get('rtoMax', 5000) + self.sctpRtoInitial = self.config.get('hss', {}).get('sctp', {}).get('rtoInitial', 1000) + self.sctpRtoInfo.initial = int(self.sctpRtoInitial) + self.sctpRtoInfo.max = int(self.sctpRtoMax) + self.sctpRtoInfo.min = int(self.sctpRtoMin) + self.sctpSocket.set_rtoinfo(self.sctpRtoInfo) + self.sctpAssociatedParameters = self.sctpSocket.get_assocparams() + sctpInitParameters = { "initialRto": self.sctpRtoInfo.initial, + "rtoMin": self.sctpRtoInfo.min, + "rtoMax": self.sctpRtoInfo.max + } + self.sctpSocket.listen() + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [startServer] SCTP Parameters: {sctpInitParameters}")) + + server = await(asyncio.start_server(self.handleConnection, sock=self.sctpSocket)) else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) diff --git a/services/georedService.py b/services/georedService.py index b0cb7ae5..14d3117c 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -58,7 +58,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr if operation not in requestOperations: return False - headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} + headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId), "User-Agent": f"PyHSS/1.0.1 (Geored)"} for attempt in range(retryCount): try: @@ -147,7 +147,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr return True - async def 
sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool:
+    async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: dict, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool:
         """
         Sends a Webhook HTTP request to a given endpoint.
         """
@@ -161,6 +161,9 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h
 
         if operation not in requestOperations:
             return False
+
+        if 'User-Agent' not in headers:
+            headers['User-Agent'] = f"PyHSS/1.0.1 (Webhook)"
 
         for attempt in range(retryCount):
             try:
From 81afdf5847018e75fda929fa6336d5c565d910d7 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Tue, 16 Jan 2024 12:43:03 +1000
Subject: [PATCH 3/9] Move emergency subscriber from redis model to sqlalchemy

---
 lib/database.py        |  99 ++++++++++++++++++++++++++++++++++++++++
 lib/diameter.py        | 101 +----------------------------------
 services/apiService.py |  63 +++++++++++++++++++++++++
 3 files changed, 164 insertions(+), 99 deletions(-)

diff --git a/lib/database.py b/lib/database.py
index 6d256867..021d5ee0 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -167,6 +167,21 @@ class ROAMING_NETWORK(Base):
     last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
     operation_logs = relationship("ROAMING_NETWORK_OPERATION_LOG", back_populates="roaming_network")
 
+class EMERGENCY_SUBSCRIBER(Base):
+    __tablename__ = 'emergency_subscriber'
+    emergency_subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of EMERGENCY_SUBSCRIBER entry')
+    imsi = Column(String(18), doc='International Mobile Subscriber Identity')
+    serving_pgw = Column(String(512), doc='PGW serving this subscriber')
+    serving_pgw_timestamp = Column(DateTime, doc='Timestamp of Gx CCR')
+    gx_origin_realm = Column(String(512), doc='Origin Realm of the Gx CCR')
+    gx_origin_host = Column(String(512), doc='Origin host of the Gx CCR')
+    rat_type = Column(String(512), doc='Radio access technology type that the emergency subscriber has used')
+    ip = Column(String(512), doc='IP of the emergency subscriber')
+    access_network_gateway_address = Column(String(512), doc='ANGW that the emergency subscriber has used')
+    access_network_charging_address = Column(String(512), doc='AN Charging Address that the emergency subscriber has used')
+    last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification')
+    operation_logs = relationship("EMERGENCY_SUBSCRIBER_OPERATION_LOG", back_populates="emergency_subscriber")
+
 class ROAMING_RULE(Base):
     __tablename__ = 'roaming_rule'
     roaming_rule_id = Column(Integer, primary_key = True, doc='Unique ID of ROAMING_RULE entry')
@@ -285,6 +300,11 @@ class ROAMING_NETWORK_OPERATION_LOG(OPERATION_LOG_BASE):
     roaming_network = relationship("ROAMING_NETWORK", back_populates="operation_logs")
     roaming_network_id = Column(Integer, ForeignKey('roaming_network.roaming_network_id'))
 
+class EMERGENCY_SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE):
+    __mapper_args__ = {'polymorphic_identity': 'emergency_subscriber'}
+    emergency_subscriber = relationship("EMERGENCY_SUBSCRIBER", back_populates="operation_logs")
+    emergency_subscriber_id = Column(Integer, ForeignKey('emergency_subscriber.emergency_subscriber_id'))
+
 class CHARGING_RULE_OPERATION_LOG(OPERATION_LOG_BASE):
     __mapper_args__ = {'polymorphic_identity': 'charging_rule'}
     charging_rule = relationship("CHARGING_RULE",
back_populates="operation_logs") @@ -2127,6 +2147,85 @@ def Get_IMS_Subscriber_By_Session_Id(self, sessionId): result = self.Sanitize_Datetime(result) return result + def Get_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId = None, imsi=None, **kwargs): + self.logTool.log(service='Database', level='debug', message="Getting Emergency_Subscriber " + str(emergencySubscriberId), redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + + """ + Work out what filters we're using for this query + """ + + queryFilters = {} + + if emergencySubscriberId: + queryFilters['emergency_subscriber_id'] = emergencySubscriberId + if subscriberIp: + queryFilters['ip'] = subscriberIp + if gxSessionId: + queryFilters['serving_pgw'] = gxSessionId + if imsi: + queryFilters['imsi'] = imsi + + try: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(**queryFilters).one() + if not result: + return None + except Exception as E: + self.safe_close(session) + raise ValueError(E) + result = result.__dict__ + result.pop('_sa_instance_state') + self.safe_close(session) + return result + + def Update_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId = None, imsi=None, subscriberData: dict={}) -> bool: + """ + First, get at most one emergency subscriber matching the provided identifiers. + Then, update all data with the provided data. + """ + Session = sessionmaker(bind = self.engine) + session = Session() + + queryFilters = {} + + if emergencySubscriberId: + queryFilters['emergency_subscriber_id'] = emergencySubscriberId + if subscriberIp: + queryFilters['ip'] = subscriberIp + if gxSessionId: + queryFilters['serving_pgw'] = gxSessionId + if imsi: + queryFilters['imsi'] = imsi + + self.logTool.log(service='Database', level='debug', message=f"Getting Emergency_Subscriber with provided filters: {queryFilters}", redisClient=self.redisMessaging) + + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(**queryFilters).first() + + if result is None: + result = EMERGENCY_SUBSCRIBER() + session.add(result) + + result.imsi = subscriberData.get('imsi') + result.serving_pgw = subscriberData.get('servingPgw') + result.serving_pgw_timestamp = subscriberData.get('requestTime') + result.gx_origin_realm = subscriberData.get('gxOriginRealm') + result.gx_origin_host = subscriberData.get('gxOriginHost') + result.rat_type = subscriberData.get('ratType') + result.ip = subscriberData.get('ip') + result.access_network_gateway_address = subscriberData.get('accessNetworkGatewayAddress') + result.access_network_charging_address = subscriberData.get('accessNetworkChargingAddress') + + try: + session.commit() + except Exception as E: + self.safe_close(session) + raise ValueError(E) + result = result.__dict__ + result.pop('_sa_instance_state') + self.safe_close(session) + return result + def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True): #IMSI 14-15 Digits #IMEI 15 Digits diff --git a/lib/diameter.py b/lib/diameter.py index 66fe8c0f..b57c670c 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1051,102 +1051,6 @@ def validateSubscriberRoaming(self, subscriber: dict, mcc: str, mnc: str) -> boo return True - - def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, gxSessionId: str, authExpiry: int=3600, subscriberImsi: str="Unknown") -> bool: - """ - Store a given Emergency Subscriber in redis. 
- If there's an existing entry for the same IMSI, then update the record with the new IP and details. - The subscriber entry will expire per authExpiry in seconds. - """ - try: - emergencySubscriberKey = f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:{gxSessionId}" - # Check if our subscriber exists - if subscriberImsi and subscriberImsi != "Unknown": - existingEmergencySubscriber = self.getEmergencySubscriber(subscriberImsi=subscriberImsi) - if existingEmergencySubscriber: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getEmergencySubscriber] Found existing emergency subscriber to overwrite: {existingEmergencySubscriber}", redisClient=self.redisMessaging) - for key, value in existingEmergencySubscriber.items(): - self.redisMessaging.deleteQueue(queue=f"emergencySubscriber:{value.get('ip')}:{value.get('imsi')}:{value.get('servingPgw')}", redisPeerConnections=self.redisPeerConnections, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - result = self.redisMessaging.setValue(key=emergencySubscriberKey, value=json.dumps(subscriberData), keyExpiry=authExpiry, redisPeerConnections=self.redisPeerConnections, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - return True - except Exception as e: - self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getEmergencySubscriber] Error storing emergency subscriber in redis: {traceback.format_exc()}", redisClient=self.redisMessaging) - return False - - - def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=None, gxSessionId: str=None) -> dict: - """ - Retrieves a provided Emergency Subscriber from redis, if it exists. - The first match from any defined redis instance is used. - Returns None on no match found, or failure. 
- """ - try: - - if not subscriberIp and not subscriberImsi: - return None - - if subscriberIp and subscriberImsi: - emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if emergencySubscriberKeyList: - for matchedKey in emergencySubscriberKeyList: - for peerName, keyName in matchedKey.items(): - if isinstance(keyName, list): - keyName = keyName[0] if len(keyName) > 0 else '' - emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.redisMessaging(peerName=peerName), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if not emergencySubscriberData: - return None - emergencySubscriberData = json.loads(emergencySubscriberData) - emergencySubscriber = {peerName: emergencySubscriberData} - return emergencySubscriber - - if subscriberIp and not subscriberImsi: - emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:{subscriberIp}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if emergencySubscriberKeyList: - for matchedKey in emergencySubscriberKeyList: - for peerName, keyName in matchedKey.items(): - if isinstance(keyName, list): - keyName = keyName[0] if len(keyName) > 0 else '' - emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.redisMessaging(peerName=peerName), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if not emergencySubscriberData: - return None - emergencySubscriberData = json.loads(emergencySubscriberData) - emergencySubscriber = {peerName: emergencySubscriberData} - return emergencySubscriber - - if subscriberImsi and not subscriberIp: - emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:*:{subscriberImsi}:*", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if emergencySubscriberKeyList: - for matchedKey in emergencySubscriberKeyList: - for peerName, keyName in matchedKey.items(): - if isinstance(keyName, list): - keyName = keyName[0] if len(keyName) > 0 else '' - emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.redisMessaging(peerName=peerName), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if not emergencySubscriberData: - return None - emergencySubscriberData = json.loads(emergencySubscriberData) - emergencySubscriber = {peerName: emergencySubscriberData} - return emergencySubscriber - - if gxSessionId: - emergencySubscriberKeyList = self.redisMessaging.getQueues(pattern=f"emergencySubscriber:*:*:{gxSessionId}", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if emergencySubscriberKeyList: - for matchedKey in emergencySubscriberKeyList: - for peerName, keyName in matchedKey.items(): - if isinstance(keyName, list): - keyName = keyName[0] if len(keyName) > 0 else '' - emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.redisMessaging(peerName=peerName), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') - if not emergencySubscriberData: - return None - emergencySubscriberData = json.loads(emergencySubscriberData) - emergencySubscriber = {peerName: emergencySubscriberData} - return emergencySubscriber - - return None - - except Exception as e: - self.logTool.log(service='HSS', level='error', message=f"[diameter.py] 
[getEmergencySubscriber] Error getting emergency subscriber from redis: {traceback.format_exc()}", redisClient=self.redisMessaging) - return None - def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: if avp_dicts['avp_code'] == 278: @@ -2038,7 +1942,7 @@ def Answer_16777238_272(self, packet_vars, avps): "accessNetworkChargingAddress": accessNetworkChargingAddress, } - self.storeEmergencySubscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, subscriberImsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw')) + self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw')) avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet @@ -2778,7 +2682,7 @@ def Answer_16777236_265(self, packet_vars, avps): Determine if the AAR for the IP belongs to an inbound roaming emergency subscriber. """ try: - emergencySubscriberData = self.getEmergencySubscriber(subscriberIp=ueIp) + emergencySubscriberData = self.database.Get_Emergency_Subscriber(subscriberIp=ueIp) if emergencySubscriberData: emergencySubscriber = True except Exception as e: @@ -2824,7 +2728,6 @@ def Answer_16777236_265(self, packet_vars, avps): except Exception as e: pass - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) diff --git a/services/apiService.py b/services/apiService.py index aca8d4cd..0f9dd265 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -69,6 +69,8 @@ SUBSCRIBER_ROUTING = database.SUBSCRIBER_ROUTING ROAMING_NETWORK = database.ROAMING_NETWORK ROAMING_RULE = database.ROAMING_RULE +EMERGENCY_SUBSCRIBER = database.EMERGENCY_SUBSCRIBER + apiService.wsgi_app = ProxyFix(apiService.wsgi_app) api = Api(apiService, version='1.0', title=f'{siteName + " - " if siteName else ""}{originHostname} - PyHSS OAM API', @@ -128,6 +130,9 @@ databaseClient.Generate_JSON_Model_for_Flask(ROAMING_RULE) ) +EMERGENCY_SUBSCRIBER_model = api.schema_model('EMERGENCY_SUBSCRIBER JSON', + databaseClient.Generate_JSON_Model_for_Flask(EMERGENCY_SUBSCRIBER) +) #Legacy support for sh_profile. sh_profile is deprecated as of v1.0.1. 
imsSubscriberModel = databaseClient.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER)
@@ -1828,6 +1833,64 @@ def get(self, subscriber_routing):
             print(E)
             return handle_exception(E)
 
+@ns_pcrf.route('/emergency_subscriber/<string:emergency_subscriber_id>')
+class PyHSS_EMERGENCY_SUBSCRIBER_Get(Resource):
+    def get(self, emergency_subscriber_id):
+        '''Get all EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        try:
+            emergency_subscriber_data = databaseClient.GetObj(EMERGENCY_SUBSCRIBER, emergency_subscriber_id)
+            return emergency_subscriber_data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+    def delete(self, emergency_subscriber_id):
+        '''Delete all EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        try:
+            args = parser.parse_args()
+            operation_id = args.get('operation_id', None)
+            data = databaseClient.DeleteObj(EMERGENCY_SUBSCRIBER, emergency_subscriber_id, False, operation_id)
+            return data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+    @ns_pcrf.doc('Update EMERGENCY_SUBSCRIBER Object')
+    @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model)
+    def patch(self, emergency_subscriber_id):
+        '''Update EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        try:
+            json_data = request.get_json(force=True)
+            print("JSON Data sent: " + str(json_data))
+            args = parser.parse_args()
+            operation_id = args.get('operation_id', None)
+            emergency_subscriber_data = databaseClient.UpdateObj(EMERGENCY_SUBSCRIBER, json_data, emergency_subscriber_id, False, operation_id)
+
+            print("Updated object")
+            print(emergency_subscriber_data)
+            return emergency_subscriber_data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+@ns_pcrf.route('/emergency_subscriber/')
+class PyHSS_EMERGENCY_SUBSCRIBER(Resource):
+    @ns_pcrf.doc('Create EMERGENCY_SUBSCRIBER Object')
+    @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model)
+    def put(self):
+        '''Create new EMERGENCY_SUBSCRIBER'''
+        try:
+            json_data = request.get_json(force=True)
+            print("JSON Data sent: " + str(json_data))
+            args = parser.parse_args()
+            operation_id = args.get('operation_id', None)
+            emergency_subscriber_id = databaseClient.CreateObj(EMERGENCY_SUBSCRIBER, json_data, False, operation_id)
+
+            return emergency_subscriber_id, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
 @ns_geored.route('/')
 class PyHSS_Geored(Resource):
     @ns_geored.doc('Receive GeoRed data')
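For reference between these two patches: PATCH 3/9 turns the emergency subscriber into an ordinary SQLAlchemy row and exposes it through the OAM API handlers above. A minimal sketch of exercising those endpoints follows, assuming the API service listens on localhost:8080 and that ns_pcrf is mounted at /pcrf; the base URL, mount point, sample field values, and the shape of the create response are illustrative assumptions, while the field names come from the EMERGENCY_SUBSCRIBER model itself.

import requests

apiBase = "http://localhost:8080/pcrf"  # assumed mount point for ns_pcrf

# Field names mirror the EMERGENCY_SUBSCRIBER model; all values below are made up.
newEntry = {
    "imsi": "001001000000001",
    "serving_pgw": "pgw01.mnc001.mcc001.3gppnetwork.org;1452002;234",  # Gx Session-Id style value
    "gx_origin_realm": "mnc001.mcc001.3gppnetwork.org",
    "gx_origin_host": "pgw01.mnc001.mcc001.3gppnetwork.org",
    "rat_type": "1004",  # EUTRAN, per the 3GPP TS 29.212 RAT-Type AVP
    "ip": "10.45.0.5",
    "access_network_gateway_address": "10.45.0.1",
    "access_network_charging_address": "10.45.0.1",
}

# PUT creates the row, GET reads it back, DELETE removes it (handlers defined above).
created = requests.put(f"{apiBase}/emergency_subscriber/", json=newEntry).json()
subscriberId = created.get("emergency_subscriber_id")
print(requests.get(f"{apiBase}/emergency_subscriber/{subscriberId}").json())
requests.delete(f"{apiBase}/emergency_subscriber/{subscriberId}")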
From 1e957b8504e97c74f42581c23945e3603430a7c2 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Tue, 16 Jan 2024 22:01:19 +1000
Subject: [PATCH 4/9] Emergency subscriber migration complete, pre testing

---
 lib/database.py        | 180 +++++++++++++++++++++-------
 lib/diameter.py        | 263 ++++++++++++++++++++++++++---------------
 services/apiService.py |  97 +++++++++++----
 3 files changed, 381 insertions(+), 159 deletions(-)

diff --git a/lib/database.py b/lib/database.py
index 021d5ee0..f0950172 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -172,7 +172,9 @@ class EMERGENCY_SUBSCRIBER(Base):
     emergency_subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of EMERGENCY_SUBSCRIBER entry')
     imsi = Column(String(18), doc='International Mobile Subscriber Identity')
     serving_pgw = Column(String(512), doc='PGW serving this subscriber')
-    serving_pgw_timestamp = Column(DateTime, doc='Timestamp of Gx CCR')
+    serving_pgw_timestamp = Column(String(512), doc='Timestamp of Gx CCR')
+    serving_pcscf = Column(String(512), doc='PCSCF serving this subscriber')
+    serving_pcscf_timestamp = Column(String(512), doc='Timestamp of Rx Media AAR')
     gx_origin_realm = Column(String(512), doc='Origin Realm of the Gx CCR')
     gx_origin_host = Column(String(512), doc='Origin host of the Gx CCR')
     rat_type = Column(String(512), doc='Radio access technology type that the emergency subscriber has used')
@@ -2147,68 +2149,97 @@ def Get_IMS_Subscriber_By_Session_Id(self, sessionId):
         result = self.Sanitize_Datetime(result)
         return result
 
-    def Get_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId = None, imsi=None, **kwargs):
-        self.logTool.log(service='Database', level='debug', message="Getting Emergency_Subscriber " + str(emergencySubscriberId), redisClient=self.redisMessaging)
+    def Get_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, **kwargs) -> dict:
+        self.logTool.log(service='Database', level='debug', message="Getting Emergency_Subscriber", redisClient=self.redisMessaging)
         Session = sessionmaker(bind = self.engine)
         session = Session()
 
-        """
-        Work out what filters we're using for this query
-        """
-
-        queryFilters = {}
-
-        if emergencySubscriberId:
-            queryFilters['emergency_subscriber_id'] = emergencySubscriberId
-        if subscriberIp:
-            queryFilters['ip'] = subscriberIp
-        if gxSessionId:
-            queryFilters['serving_pgw'] = gxSessionId
-        if imsi:
-            queryFilters['imsi'] = imsi
+        result = None
 
         try:
-            result = session.query(EMERGENCY_SUBSCRIBER).filter_by(**queryFilters).one()
+            while not result:
+                if imsi and not result:
+                    result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first()
+                    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging)
+                    break
+                if emergencySubscriberId and not result:
+                    result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first()
+                    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging)
+                    break
+                if subscriberIp and not result:
+                    result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first()
+                    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging)
+                    break
+                if gxSessionId and not result:
+                    result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first()
+                    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging)
+                    break
+                if rxSessionId and not result:
+                    result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first()
+                    self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging)
+                    break
+                break
+            if not result:
                 return None
+            result = result.__dict__
+            result.pop('_sa_instance_state')
+            self.safe_close(session)
+            return result
+        except Exception as E:
+            self.logTool.log(service='Database', level='error', message=f"[database.py] [Get_Emergency_Subscriber] Error getting emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging)
             self.safe_close(session)
-            raise ValueError(E)
-        result = result.__dict__
-        result.pop('_sa_instance_state')
-        self.safe_close(session)
-        return result
+            return None
 
-    def Update_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId = None, imsi=None, subscriberData: dict={}) -> bool:
+    def Update_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, subscriberData: dict={}, propagate: bool=True) -> dict:
         """
-        First, get at most one emergency subscriber matching the provided identifiers.
-        Then, update all data with the provided data.
+        First, get at most one emergency subscriber.
+        Try and match on IMSI first (To detect an updated IP for an existing record),
+        If IMSI is None or no result was found, then try with a combination of all of the arguments.
+        Then update all data with the provided subscriberData, and push to geored.
         """
         Session = sessionmaker(bind = self.engine)
         session = Session()
 
-        queryFilters = {}
-
-        if emergencySubscriberId:
-            queryFilters['emergency_subscriber_id'] = emergencySubscriberId
-        if subscriberIp:
-            queryFilters['ip'] = subscriberIp
-        if gxSessionId:
-            queryFilters['serving_pgw'] = gxSessionId
-        if imsi:
-            queryFilters['imsi'] = imsi
+        result = None
 
-        self.logTool.log(service='Database', level='debug', message=f"Getting Emergency_Subscriber with provided filters: {queryFilters}", redisClient=self.redisMessaging)
+        while not result:
+            if imsi and not result:
+                result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first()
+                self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging)
+                break
+            if emergencySubscriberId and not result:
+                result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first()
+                self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging)
+                break
+            if subscriberIp and not result:
+                result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first()
+                self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging)
+                break
+            if gxSessionId and not result:
+                result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first()
+                self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging)
+                break
+            if rxSessionId and not result:
+                result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first()
+                self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging)
+                break
+            break
 
-        result = session.query(EMERGENCY_SUBSCRIBER).filter_by(**queryFilters).first()
-
-        if result is None:
+        """
+        If we haven't matched any entries at this point, create a new emergency subscriber.
+ """ + if not result: result = EMERGENCY_SUBSCRIBER() session.add(result) result.imsi = subscriberData.get('imsi') result.serving_pgw = subscriberData.get('servingPgw') result.serving_pgw_timestamp = subscriberData.get('requestTime') + result.serving_pcscf = subscriberData.get('servingPcscf') + result.serving_pcscf_timestamp = subscriberData.get('aarRequestTime') result.gx_origin_realm = subscriberData.get('gxOriginRealm') result.gx_origin_host = subscriberData.get('gxOriginHost') result.rat_type = subscriberData.get('ratType') @@ -2218,14 +2249,83 @@ def Update_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscribe try: session.commit() + emergencySubscriberId = result.emergency_subscriber_id + if propagate: + self.handleGeored({ "emergency_subscriber_id": int(emergencySubscriberId), + "emergency_subscriber_imsi": subscriberData.get('imsi'), + "emergency_subscriber_serving_pgw": subscriberData.get('servingPgw'), + "emergency_subscriber_serving_pgw_timestamp": subscriberData.get('requestTime'), + "emergency_subscriber_serving_pcscf": subscriberData.get('servingPcscf'), + "emergency_subscriber_serving_pcscf_timestamp": subscriberData.get('aarRequestTime'), + "emergency_subscriber_gx_origin_realm": subscriberData.get('gxOriginRealm'), + "emergency_subscriber_gx_origin_host": subscriberData.get('gxOriginHost'), + "emergency_subscriber_rat_type": subscriberData.get('ratType'), + "emergency_subscriber_ip": subscriberData.get('ip'), + "emergency_subscriber_access_network_gateway_address": subscriberData.get('accessNetworkGatewayAddress'), + "emergency_subscriber_access_network_charging_address": subscriberData.get('accessNetworkChargingAddress'), + }) + except Exception as E: self.safe_close(session) - raise ValueError(E) + self.logTool.log(service='Database', level='error', message=f"[database.py] [Update_Emergency_Subscriber] Error updating emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging) + return None result = result.__dict__ result.pop('_sa_instance_state') self.safe_close(session) return result + def Delete_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, subscriberData: dict={}, propagate: bool=True) -> bool: + """ + First, get at most one emergency subscriber matching the provided identifiers. + Then delete the emergency subscriber, and push to geored. 
+ """ + Session = sessionmaker(bind = self.engine) + session = Session() + + result = None + + while not result: + if imsi and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging) + break + if emergencySubscriberId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging) + break + if subscriberIp and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging) + break + if gxSessionId and not result: + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging) + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first() + break + if rxSessionId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging) + break + break + + if not result: + return True + + try: + emergencySubscriberId = result.emergency_subscriber_id + session.delete(result) + session.commit() + if propagate: + self.handleGeored({ "emergency_subscriber_id": int(emergencySubscriberId), + "emergency_subscriber_delete": True, + }) + self.safe_close(session) + return True + except Exception as E: + self.safe_close(session) + self.logTool.log(service='Database', level='error', message=f"[database.py] [Delete_Emergency_Subscriber] Error deleting emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging) + return False + + def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True): #IMSI 14-15 Digits #IMEI 15 Digits diff --git a/lib/diameter.py b/lib/diameter.py index b57c670c..59bfb481 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1855,98 +1855,155 @@ def Answer_16777238_272(self, packet_vars, avps): """ try: if apn.lower() == 'sos': - # Use our defined SOS APN AMBR, if defined. - # Otherwise, use a default value of 128/128kbps. 
diff --git a/lib/diameter.py b/lib/diameter.py
index b57c670c..59bfb481 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -1855,98 +1855,155 @@ def Answer_16777238_272(self, packet_vars, avps):
         """
         try:
             if apn.lower() == 'sos':
-                # Use our defined SOS APN AMBR, if defined.
-                # Otherwise, use a default value of 128/128kbps.
-                try:
-                    sosApn = (self.database.Get_APN_by_Name(apn="sos"))
-                    AMBR = '' #Initiate empty var AVP for AMBR
-                    apn_ambr_ul = int(sosApn['apn_ambr_ul'])
-                    apn_ambr_dl = int(sosApn['apn_ambr_dl'])
-                    AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
-                    AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
-                    APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
-
-                    AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(sosApn['arp_priority']), 4))
-                    AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_capability']), 4))
-                    AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_vulnerability']), 4))
-                    AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
-                    AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(sosApn['qci']), 4))
-                    avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP)
+                if int(CC_Request_Type) == 1:
+                    """
+                    If we've received a CCR-Initial, create an emergency subscriber.
+                    """
+                    # Use our defined SOS APN AMBR, if defined.
+                    # Otherwise, use a default value of 128/128kbps.
+                    try:
+                        sosApn = (self.database.Get_APN_by_Name(apn="sos"))
+                        AMBR = '' #Initiate empty var AVP for AMBR
+                        apn_ambr_ul = int(sosApn['apn_ambr_ul'])
+                        apn_ambr_dl = int(sosApn['apn_ambr_dl'])
+                        AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
+                        AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
+                        APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
+
+                        AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(sosApn['arp_priority']), 4))
+                        AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_capability']), 4))
+                        AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_vulnerability']), 4))
+                        AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability)
+                        AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(sosApn['qci']), 4))
+                        avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP)
 
-                except Exception as e:
-                    AMBR = '' #Initiate empty var AVP for AMBR
-                    apn_ambr_ul = 128000
-                    apn_ambr_dl = 128000
-                    AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
-                    AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
-                    APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
+                    except Exception as e:
+                        AMBR = '' #Initiate empty var AVP for AMBR
+                        apn_ambr_ul = 128000
+                        apn_ambr_dl = 128000
+                        AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL
+                        AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL
+                        APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR)
+
+                        AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(1, 4))
+                        AVP_Preemption_Capability = self.generate_vendor_avp(1047,
"80", 10415, self.int_to_hex(0, 4)) # Pre-Emption Capability Enabled + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(1, 4)) # Pre-Emption Vulnerability Disabled + AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) + AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(5, 4)) # QCI 5 + avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) + + QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) + QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) + avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) # QOS-Information + + #Supported-Features(628) (Gx feature list) + avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") + + """ + Store the Emergency Subscriber + """ + ueIp = self.get_avp_data(avps, 8)[0] + ueIp = str(self.hex_to_ip(ueIp)) + try: + #Get the IMSI + for SubscriptionIdentifier in self.get_avp_data(avps, 443): + for UniqueSubscriptionIdentifier in SubscriptionIdentifier: + if UniqueSubscriptionIdentifier['avp_code'] == 444: + imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') + except Exception as e: + imsi="Unknown" - AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(1, 4)) - AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(0, 4)) # Pre-Emption Capability Enabled - AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(1, 4)) # Pre-Emption Vulnerability Disabled - AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) - AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(5, 4)) # QCI 5 - avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) - - QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) - QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) - avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) # QOS-Information + try: + ratType = self.get_avp_data(avps, 1032)[0] + ratType = int(ratType, 16) + except Exception as e: + ratType = None - #Supported-Features(628) (Gx feature list) - avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") + try: + accessNetworkGatewayAddress = self.get_avp_data(avps, 1050)[0] + accessNetworkGatewayAddress = str(self.hex_to_ip(accessNetworkGatewayAddress[4:])) + except Exception as e: + accessNetworkGatewayAddress = None - """ - Store the Emergency Subscriber in redis - """ - ueIp = self.get_avp_data(avps, 8)[0] - ueIp = str(self.hex_to_ip(ueIp)) - try: - #Get the IMSI - for SubscriptionIdentifier in self.get_avp_data(avps, 443): - for UniqueSubscriptionIdentifier in SubscriptionIdentifier: - if UniqueSubscriptionIdentifier['avp_code'] == 444: - imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') - except Exception as e: - imsi="Unknown" + try: + accessNetworkChargingAddress = self.get_avp_data(avps, 501)[0] + accessNetworkChargingAddress = str(self.hex_to_ip(accessNetworkChargingAddress[4:])) + except Exception as e: + 
accessNetworkChargingAddress = None
+
+                    emergencySubscriberData = {
+                        "servingPgw": binascii.unhexlify(session_id).decode(),
+                        "requestTime": int(time.time()),
+                        "servingPcscf": None,
+                        "aarRequestTime": None,
+                        "gxOriginRealm": OriginRealm,
+                        "gxOriginHost": OriginHost,
+                        "imsi": imsi,
+                        "ip": ueIp,
+                        "ratType": ratType,
+                        "accessNetworkGatewayAddress": accessNetworkGatewayAddress,
+                        "accessNetworkChargingAddress": accessNetworkChargingAddress,
+                    }
+
+                    self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw'))
+
+                    avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
+                    response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
+                    return response

-                try:
-                    ratType = self.get_avp_data(avps, 1032)[0]
-                    ratType = int(ratType, 16)
-                except Exception as e:
-                    ratType = None
-                    pass

+                elif int(CC_Request_Type) == 3:
+                    """
+                    If we've received a CCR-Terminate, delete the emergency subscriber.
+                    """
+                    try:
+                        ueIp = self.get_avp_data(avps, 8)[0]
+                        ueIp = str(self.hex_to_ip(ueIp))
+                    except Exception as e:
+                        ueIp = None
+                    try:
+                        #Get the IMSI
+                        for SubscriptionIdentifier in self.get_avp_data(avps, 443):
+                            for UniqueSubscriptionIdentifier in SubscriptionIdentifier:
+                                if UniqueSubscriptionIdentifier['avp_code'] == 444:
+                                    imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8')
+                    except Exception as e:
+                        imsi="Unknown"

-                try:
-                    accessNetworkGatewayAddress = self.get_avp_data(avps, 1050)[0]
-                    accessNetworkGatewayAddress = str(self.hex_to_ip(accessNetworkGatewayAddress))
-                except Exception as e:
-                    accessNetworkGatewayAddress = None
-                    pass

+                    try:
+                        ratType = self.get_avp_data(avps, 1032)[0]
+                        ratType = int(ratType, 16)
+                    except Exception as e:
+                        ratType = None

-                try:
-                    accessNetworkChargingAddress = self.get_avp_data(avps, 501)[0]
-                    accessNetworkChargingAddress = str(self.hex_to_ip(accessNetworkChargingAddress))
-                except Exception as e:
-                    accessNetworkChargingAddress = None
-                    pass

+                    try:
+                        accessNetworkGatewayAddress = self.get_avp_data(avps, 1050)[0]
+                        accessNetworkGatewayAddress = str(self.hex_to_ip(accessNetworkGatewayAddress))
+                    except Exception as e:
+                        accessNetworkGatewayAddress = None

-                emergencySubscriberData = {
-                    "servingPgw": binascii.unhexlify(session_id).decode(),
-                    "requestTime": int(time.time()),
-                    "gxOriginRealm": OriginRealm,
-                    "gxOriginHost": OriginHost,
-                    "imsi": imsi,
-                    "ip": ueIp,
-                    "ratType": ratType,
-                    "accessNetworkGatewayAddress": accessNetworkGatewayAddress,
-                    "accessNetworkChargingAddress": accessNetworkChargingAddress,
-                }
-
-                self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw'))
-
-                avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
-                response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
-                return response

+                    try:
+                        accessNetworkChargingAddress = self.get_avp_data(avps, 501)[0]
+                        accessNetworkChargingAddress = str(self.hex_to_ip(accessNetworkChargingAddress))
+                    except Exception as e:
+                        accessNetworkChargingAddress = None
+
+                    emergencySubscriberData = {
+                        "servingPgw": binascii.unhexlify(session_id).decode(),
+                        
"requestTime": int(time.time()), + "gxOriginRealm": OriginRealm, + "gxOriginHost": OriginHost, + "imsi": imsi, + "ip": ueIp, + "ratType": ratType, + "accessNetworkGatewayAddress": accessNetworkGatewayAddress, + "accessNetworkChargingAddress": accessNetworkChargingAddress, + } + + self.database.Delete_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=binascii.unhexlify(session_id).decode()) + + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response except Exception as e: self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777238_272] [CCA] Error generating SOS CCA: {traceback.format_exc()}", redisClient=self.redisMessaging) @@ -2774,11 +2831,10 @@ def Answer_16777236_265(self, packet_vars, avps): try: if emergencySubscriber and not imsEnabled: - for key, value in emergencySubscriberData.items(): - servingPgwPeer = emergencySubscriberData[key].get('servingPgw', None).split(';')[0] - pcrfSessionId = emergencySubscriberData[key].get('servingPgw', None) - servingPgwRealm = emergencySubscriberData[key].get('gxOriginRealm', None) - servingPgw = emergencySubscriberData[key].get('servingPgw', None).split(';')[0] + servingPgwPeer = emergencySubscriberData.get('serving_pgw', None).split(';')[0] + pcrfSessionId = emergencySubscriberData.get('serving_pgw', None) + servingPgwRealm = emergencySubscriberData.get('gx_origin_realm', None) + servingPgw = emergencySubscriberData.get('serving_pgw', None).split(';')[0] else: subscriberId = subscriberDetails.get('subscriber_id', None) apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) @@ -2919,7 +2975,22 @@ def Answer_16777236_265(self, packet_vars, avps): if not emergencySubscriber: self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId) - + else: + updatedEmergencySubscriberData = { + "servingPgw": emergencySubscriberData.get('serving_pgw'), + "requestTime": emergencySubscriberData.get('serving_pgw_timestamp'), + "servingPcscf": sessionId, + "aarRequestTime": int(time.time()), + "gxOriginRealm": emergencySubscriberData.get('gx_origin_realm'), + "gxOriginHost": emergencySubscriberData.get('gx_origin_host'), + "imsi": emergencySubscriberData.get('imsi'), + "ip": emergencySubscriberData.get('ip'), + "ratType": emergencySubscriberData.get('rat_type'), + "accessNetworkGatewayAddress": emergencySubscriberData.get('access_network_gateway_address'), + "accessNetworkChargingAddress": emergencySubscriberData.get('access_network_charging_address'), + } + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Updating Emergency Subscriber: {updatedEmergencySubscriberData}", redisClient=self.redisMessaging) + self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=updatedEmergencySubscriberData, imsi=imsi) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAR Generated to be sent to serving PGW: {servingPgw} via peer {servingPgwPeer}", redisClient=self.redisMessaging) reAuthAnswer = self.awaitDiameterRequestAndResponse( @@ -3071,19 +3142,19 @@ def Answer_16777236_275(self, packet_vars, avps): Determine if the Session-ID for the STR 
belongs to an inbound roaming emergency subscriber.
         """
         try:
-            emergencySubscriberData = self.getEmergencySubscriber(gxSessionId=sessionId)
+            emergencySubscriberData = self.database.Get_Emergency_Subscriber(rxSessionId=sessionId)
             if emergencySubscriberData:
                 emergencySubscriber = True
+                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_275] [STA] Found emergency subscriber with Rx Session: {sessionId}", redisClient=self.redisMessaging)
         except Exception as e:
             self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] Error getting Emergency Subscriber Data: {traceback.format_exc()}", redisClient=self.redisMessaging)
             emergencySubscriberData = None
 
         if emergencySubscriberData:
-            for key, value in emergencySubscriberData.items():
-                servingPgwPeer = emergencySubscriberData[key].get('servingPgw', None).split(';')[0]
-                pcrfSessionId = emergencySubscriberData[key].get('servingPgw', None)
-                servingPgwRealm = emergencySubscriberData[key].get('gxOriginRealm', None)
-                servingPgw = emergencySubscriberData[key].get('servingPgw', None).split(';')[0]
+            servingPgwPeer = emergencySubscriberData.get('serving_pgw', None).split(';')[0]
+            pcrfSessionId = emergencySubscriberData.get('serving_pgw', None)
+            servingPgwRealm = emergencySubscriberData.get('gx_origin_realm', None)
+            servingPgw = emergencySubscriberData.get('serving_pgw', None).split(';')[0]
 
         if servingApn is not None or emergencySubscriberData:
             reAuthAnswer = self.awaitDiameterRequestAndResponse(
diff --git a/services/apiService.py b/services/apiService.py
index 0f9dd265..2317822c 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -207,26 +207,20 @@
     'scscf_timestamp' : fields.String(description=IMS_SUBSCRIBER.scscf_timestamp.doc),
     'imei' : fields.String(description=EIR.imei.doc),
     'match_response_code' : fields.String(description=EIR.match_response_code.doc),
+    'emergency_subscriber_id': fields.String(description=EMERGENCY_SUBSCRIBER.emergency_subscriber_id.doc),
+    'emergency_subscriber_imsi': fields.String(description=EMERGENCY_SUBSCRIBER.imsi.doc),
+    'emergency_subscriber_serving_pgw': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pgw.doc),
+    'emergency_subscriber_serving_pgw_timestamp': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pgw_timestamp.doc),
+    'emergency_subscriber_serving_pcscf': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pcscf.doc),
+    'emergency_subscriber_serving_pcscf_timestamp': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pcscf_timestamp.doc),
+    'emergency_subscriber_gx_origin_realm': fields.String(description=EMERGENCY_SUBSCRIBER.gx_origin_realm.doc),
+    'emergency_subscriber_gx_origin_host': fields.String(description=EMERGENCY_SUBSCRIBER.gx_origin_host.doc),
+    'emergency_subscriber_rat_type': fields.String(description=EMERGENCY_SUBSCRIBER.rat_type.doc),
+    'emergency_subscriber_ip': fields.String(description=EMERGENCY_SUBSCRIBER.ip.doc),
+    'emergency_subscriber_access_network_gateway_address': fields.String(description=EMERGENCY_SUBSCRIBER.access_network_gateway_address.doc),
+    'emergency_subscriber_access_network_charging_address': fields.String(description=EMERGENCY_SUBSCRIBER.access_network_charging_address.doc),
 })
 
-Geored_schema = {
-    'serving_mme': "string",
-    'serving_mme_realm': "string",
-    'serving_mme_peer': "string",
-    'serving_mme_timestamp': "string",
-    'serving_apn' : "string",
-    'pcrf_session_id' : "string",
-    'subscriber_routing' : "string",
-    'serving_pgw' : "string",
-    
'serving_pgw_timestamp' : "string",
-    'scscf' : "string",
-    'imei' : "string",
-    'match_response_code' : "string",
-    'auc_id': "int",
-    'sqn': "int",
-}
-
-
 def no_auth_required(f):
     f.no_auth_required = True
     return f
@@ -1833,10 +1827,10 @@ def get(self, subscriber_routing):
             print(E)
             return handle_exception(E)
 
-@ns_pcrf.route('/')
+@ns_pcrf.route('/emergency_subscriber/<string:emergency_subscriber_id>')
 class PyHSS_EMERGENCY_SUBSCRIBER_Get(Resource):
     def get(self, emergency_subscriber_id):
-        '''Get all EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        '''Get all emergency_subscriber data for specified emergency_subscriber ID'''
         try:
             apn_data = databaseClient.GetObj(EMERGENCY_SUBSCRIBER, emergency_subscriber_id)
             return apn_data, 200
@@ -1845,7 +1839,7 @@ def get(self, emergency_subscriber_id):
             return handle_exception(E)
 
     def delete(self, emergency_subscriber_id):
-        '''Delete all EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        '''Delete all emergency_subscriber data for specified emergency_subscriber ID'''
         try:
             args = parser.parse_args()
             operation_id = args.get('operation_id', None)
@@ -1858,7 +1852,7 @@ def delete(self, emergency_subscriber_id):
     @ns_pcrf.doc('Update EMERGENCY_SUBSCRIBER Object')
     @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model)
     def patch(self, emergency_subscriber_id):
-        '''Update EMERGENCY_SUBSCRIBER data for specified EMERGENCY_SUBSCRIBER ID'''
+        '''Update emergency_subscriber data for specified emergency_subscriber ID'''
         try:
             json_data = request.get_json(force=True)
             print("JSON Data sent: " + str(json_data))
@@ -1873,7 +1867,7 @@ def patch(self, emergency_subscriber_id):
             print(E)
             return handle_exception(E)
 
-@ns_pcrf.route('/')
+@ns_pcrf.route('/emergency_subscriber/')
 class PyHSS_EMERGENCY_SUBSCRIBER(Resource):
     @ns_pcrf.doc('Create EMERGENCY_SUBSCRIBER Object')
     @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model)
@@ -1891,6 +1885,19 @@ def put(self):
             print(E)
             return handle_exception(E)
 
+@ns_pcrf.route('/emergency_subscriber/list')
+class PyHSS_ALL_EMERGENCY_SUBSCRIBER(Resource):
+    @ns_apn.expect(paginatorParser)
+    def get(self):
+        '''Get all Emergency Subscribers'''
+        try:
+            args = paginatorParser.parse_args()
+            data = databaseClient.getAllPaginated(EMERGENCY_SUBSCRIBER, args['page'], args['page_size'])
+            return (data), 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
 @ns_geored.route('/')
 class PyHSS_Geored(Resource):
     @ns_geored.doc('Receive GeoRed data')
@@ -1997,6 +2004,50 @@ def patch(self):
                     "geored_host": request.remote_addr,
                 }, metricExpiry=60)
+        if 'emergency_subscriber_id' in json_data:
+            """
+            If we receive a geored payload containing emergency_subscriber_id, create or update the matching emergency subscriber.
+            If emergency_subscriber_delete is set, remove the matching emergency subscriber.
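+
+            Example geored payload fragment (illustrative values only):
+                {"emergency_subscriber_id": 1,
+                 "emergency_subscriber_imsi": "001001000000001",
+                 "emergency_subscriber_ip": "100.64.0.1",
+                 "emergency_subscriber_serving_pgw": "pgw01.mnc001.mcc001.3gppnetwork.org;12345678"}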
+ """ + print("Updating Emergency Subscriber") + subscriberData = { + "imsi": json_data.get('emergency_subscriber_imsi'), + "servingPgw": json_data.get('emergency_subscriber_serving_pgw'), + "requestTime": json_data.get('emergency_subscriber_serving_pgw_timestamp'), + "servingPcscf": json_data.get('emergency_subscriber_serving_pcscf'), + "aarRequestTime": json_data.get('emergency_subscriber_serving_pcscf_timestamp'), + "gxOriginRealm": json_data.get('emergency_subscriber_gx_origin_realm'), + "gxOriginHost": json_data.get('emergency_subscriber_gx_origin_host'), + "ratType": json_data.get('emergency_subscriber_rat_type'), + "ip": json_data.get('emergency_subscriber_ip'), + "accessNetworkGatewayAddress": json_data.get('emergency_subscriber_access_network_gateway_address'), + "accessNetworkChargingAddress": json_data.get('emergency_subscriber_access_network_charging_address'), + } + + if not json_data.get('emergency_subscriber_id', None): + logTool.log(service='API', level='error', message=f"[API] emergency_subscriber_id missing from geored request. No changes to emergency_subscriber made.", redisClient=redisMessaging) + return {'result': 'Failed', 'Reason' : "emergency_subscriber_id missing from geored request"} + + if 'emergency_subscriber_delete' in json_data: + if json_data.get('emergency_subscriber_delete', False): + databaseClient.deleteObj(EMERGENCY_SUBSCRIBER, json_data.get('emergency_subscriber_id')) + return {}, 200 + + response_data.append(databaseClient.Update_Emergency_Subscriber(emergencySubscriberId=json_data['emergency_subscriber_id'], + subscriberData=subscriberData, + imsi=subscriberData.get('imsi'), + subscriberIp=subscriberData.get('ip'), + gxSessionId=subscriberData.get('servingPgw'), + propagate=False)) + + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "EMERGENCY_SUBSCRIBER", + "geored_host": request.remote_addr, + }, + metricExpiry=60) return response_data, 200 except Exception as E: print("Exception when updating: " + str(E)) From 00e7d23b0542486c8731e686e14329ccedc9adf8 Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 16 Jan 2024 22:04:13 +1000 Subject: [PATCH 5/9] Remove stray curly brace --- lib/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/database.py b/lib/database.py index f0950172..8c3179c6 100755 --- a/lib/database.py +++ b/lib/database.py @@ -2150,7 +2150,7 @@ def Get_IMS_Subscriber_By_Session_Id(self, sessionId): return result def Get_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, **kwargs) -> dict: - self.logTool.log(service='Database', level='debug', message=f"Getting Emergency_Subscriber}", redisClient=self.redisMessaging) + self.logTool.log(service='Database', level='debug', message=f"Getting Emergency_Subscriber", redisClient=self.redisMessaging) Session = sessionmaker(bind = self.engine) session = Session() From 1fdcf42d28470e22b3fd9474d6cb46c25fc5fc3f Mon Sep 17 00:00:00 2001 From: davidkneipp Date: Tue, 16 Jan 2024 22:32:07 +1000 Subject: [PATCH 6/9] Rely on IP instead of ID for emergency_subscriber geored --- lib/database.py | 4 +++- services/apiService.py | 11 ++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/database.py b/lib/database.py index 8c3179c6..8799cdf9 100755 --- a/lib/database.py +++ 
b/lib/database.py
@@ -2315,7 +2315,9 @@ def Delete_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscribe
             session.delete(result)
             session.commit()
             if propagate:
-                self.handleGeored({ "emergency_subscriber_id": int(emergencySubscriberId),
+                self.handleGeored({
+                    "emergency_subscriber_imsi": subscriberData.get('imsi'),
+                    "emergency_subscriber_ip": subscriberData.get('ip'),
                     "emergency_subscriber_delete": True,
                 })
         self.safe_close(session)
diff --git a/services/apiService.py b/services/apiService.py
index 2317822c..9bc6c48d 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -219,6 +219,7 @@
     'emergency_subscriber_ip': fields.String(description=EMERGENCY_SUBSCRIBER.ip.doc),
     'emergency_subscriber_access_network_gateway_address': fields.String(description=EMERGENCY_SUBSCRIBER.access_network_gateway_address.doc),
     'emergency_subscriber_access_network_charging_address': fields.String(description=EMERGENCY_SUBSCRIBER.access_network_charging_address.doc),
+    'emergency_subscriber_delete': fields.Boolean(description="Whether to delete the emergency subscriber on receipt"),
 })
 
 def no_auth_required(f):
@@ -2004,7 +2005,7 @@ def patch(self):
                 "geored_host": request.remote_addr,
             }, metricExpiry=60)
 
-        if 'emergency_subscriber_id' in json_data:
+        if 'emergency_subscriber_ip' in json_data:
             """
             If we receive a geored payload containing emergency_subscriber_id, create or update the matching emergency subscriber.
             If emergency_subscriber_delete is set, remove the matching emergency subscriber.
@@ -2024,13 +2025,13 @@ def patch(self):
                 "accessNetworkChargingAddress": json_data.get('emergency_subscriber_access_network_charging_address'),
             }
 
-            if not json_data.get('emergency_subscriber_id', None):
-                logTool.log(service='API', level='error', message=f"[API] emergency_subscriber_id missing from geored request. No changes to emergency_subscriber made.", redisClient=redisMessaging)
-                return {'result': 'Failed', 'Reason' : "emergency_subscriber_id missing from geored request"}
+            if not json_data.get('emergency_subscriber_ip', None):
+                logTool.log(service='API', level='error', message=f"[API] emergency_subscriber_ip missing from geored request. 
No changes to emergency_subscriber made.", redisClient=redisMessaging)
+                return {'result': 'Failed', 'Reason' : "emergency_subscriber_ip missing from geored request"}
 
             if 'emergency_subscriber_delete' in json_data:
                 if json_data.get('emergency_subscriber_delete', False):
-                    databaseClient.deleteObj(EMERGENCY_SUBSCRIBER, json_data.get('emergency_subscriber_id'))
+                    databaseClient.Delete_Emergency_Subscriber(subscriberIp=subscriberData.get('ip'), imsi=subscriberData.get('imsi'), propagate=False)
                     return {}, 200
 
             response_data.append(databaseClient.Update_Emergency_Subscriber(emergencySubscriberId=json_data['emergency_subscriber_id'],

From 11573c62f18b32caa9d684091a4bd989c33a1cda Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Tue, 16 Jan 2024 22:49:37 +1000
Subject: [PATCH 7/9] Working geored for emergency_subscriber

---
 lib/database.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lib/database.py b/lib/database.py
index 8799cdf9..77701188 100755
--- a/lib/database.py
+++ b/lib/database.py
@@ -2314,10 +2314,11 @@ def Delete_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscribe
             emergencySubscriberId = result.emergency_subscriber_id
             session.delete(result)
             session.commit()
+            result = result.__dict__
             if propagate:
                 self.handleGeored({
-                    "emergency_subscriber_imsi": subscriberData.get('imsi'),
-                    "emergency_subscriber_ip": subscriberData.get('ip'),
+                    "emergency_subscriber_imsi": result.get('imsi'),
+                    "emergency_subscriber_ip": result.get('ip'),
                     "emergency_subscriber_delete": True,
                 })
         self.safe_close(session)

From 221ee363ae1bc661e43d49ba72bfaaaea1fe8bc5 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Tue, 16 Jan 2024 23:29:31 +1000
Subject: [PATCH 8/9] Update changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c12ea9a..4799da4f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Gx RAR now dynamically creates TFT up to 512k based on UE request.
 - SQN Resync now propagates via Geored when enabled
 - Renamed sh_profile to xcap_profile in ims_subscriber
+- Rebuilt keys using unique namespace for redis-sentinel / stateless compatibility.
 
 ### Fixed
 
@@ -31,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Control of outbound roaming S6a AIR and ULA responses through roaming_rule and roaming_network objects.
 - Roaming management on a per-subscriber basis, through subscriber.roaming_enabled and subscriber.roaming_rule_list.
 - Support for Gx and Rx auth of unknown subscribers attaching via SOS.
+- Preliminary support for SCTP.
 
 ## [1.0.0] - 2023-09-27

From 5b3a767e9e4452e03029718e3a43a7f68c307c69 Mon Sep 17 00:00:00 2001
From: davidkneipp
Date: Thu, 18 Jan 2024 17:43:26 +1000
Subject: [PATCH 9/9] Fix for Dedicated Bearers on MO and MT call legs

---
 lib/diameter.py        | 138 ++++++++++++++++++++---------------------
 services/apiService.py |   3 +-
 services/hssService.py |   1 -
 3 files changed, 68 insertions(+), 74 deletions(-)

diff --git a/lib/diameter.py b/lib/diameter.py
index 59bfb481..8281eb5a 100644
--- a/lib/diameter.py
+++ b/lib/diameter.py
@@ -1059,6 +1059,23 @@ def AVP_278_Origin_State_Incriment(self, avps):
         origin_state_incriment_hex = format(origin_state_incriment_int,"x").zfill(8)
         return origin_state_incriment_hex
 
+    def Match_SDP(self, regexPattern, sdpBody):
+        """
+        Matches a given regex in a given SDP body.
+        Returns the result, or an empty string if not found.
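+
+        Illustrative example (hypothetical SDP fragment):
+            self.Match_SDP(regexPattern=r"m=audio (\d*)", sdpBody="m=audio 49152 RTP/AVP 104")
+            returns '49152'.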
+        """
+        try:
+            sdpMatch = re.search(regexPattern, sdpBody, re.MULTILINE)
+            if sdpMatch:
+                sdpResult = sdpMatch.group(1)
+                if sdpResult:
+                    return str(sdpResult)
+            return ''
+        except Exception as e:
+            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Match_SDP] Error matching SDP: {traceback.format_exc()}", redisClient=self.redisMessaging)
+            return ''
+
     def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleName=None, action="install"):
         self.logTool.log(service='HSS', level='debug', message=f"Called Charging_Rule_Generator with action: {action}", redisClient=self.redisMessaging)
         if action not in ['install', 'remove']:
@@ -2861,81 +2878,56 @@ def Answer_16777236_265(self, packet_vars, avps):
             except Exception as e:
                 pass
 
-        #Extract the SDP for each direction to find the source and destination IP Addresses and Ports used for the RTP streams
+        # Extract the SDP for both Uplink and Downlink, to create TFTs.
         try:
-            sdp1 = self.get_avp_data(avps, 524)[0]
-            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got first SDP body raw: " + str(sdp1), redisClient=self.redisMessaging)
-            sdp1 = binascii.unhexlify(sdp1).decode('utf-8')
-            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got first SDP body decoded: " + str(sdp1), redisClient=self.redisMessaging)
-            sdp2 = self.get_avp_data(avps, 524)[1]
-            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got second SDP body raw: " + str(sdp2), redisClient=self.redisMessaging)
-            sdp2 = binascii.unhexlify(sdp2).decode('utf-8')
-            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] got second SDP body decoded: " + str(sdp2), redisClient=self.redisMessaging)
-
-            regex_ipv4 = r"IN IP4 (\d*\.\d*\.\d*\.\d*)"
-            regex_ipv6 = r"IN IP6 ([0-9a-fA-F:]{3,39})"
-            regex_port_audio = r"m=audio (\d*)"
-            regex_port_rtcp = r"a=rtcp:(\d*)"
-
-            #Check for IPv4 Matches in first SDP Body
-            matches_ipv4 = re.search(regex_ipv4, sdp1, re.MULTILINE)
-            if matches_ipv4:
-                sdp1_ipv4 = str(matches_ipv4.group(1))
-                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP IPv4" + str(sdp1_ipv4), redisClient=self.redisMessaging)
-            else:
-                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv4 in SDP", redisClient=self.redisMessaging)
-            if not matches_ipv4:
-                #Check for IPv6 Matches
-                matches_ipv6 = re.search(regex_ipv6, sdp1, re.MULTILINE)
-                if matches_ipv6:
-                    sdp1_ipv6 = str(matches_ipv6.group(1))
-                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP IPv6" + str(sdp1_ipv6), redisClient=self.redisMessaging)
-                else:
-                    self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv6 in SDP", redisClient=self.redisMessaging)
-
-
-            #Check for IPv4 Matches in second SDP Body
-            matches_ipv4 = re.search(regex_ipv4, sdp2, re.MULTILINE)
-            if matches_ipv4:
-                sdp2_ipv4 = str(matches_ipv4.group(1))
-                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 IPv4 " + str(sdp2_ipv4), redisClient=self.redisMessaging)
-            else:
-                self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv4 in SDP2", redisClient=self.redisMessaging)
-            if not 
matches_ipv4: - #Check for IPv6 Matches - matches_ipv6 = re.search(regex_ipv6, sdp2, re.MULTILINE) - if matches_ipv6: - sdp2_ipv6 = str(matches_ipv6.group(1)) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 IPv6 " + str(sdp2_ipv6), redisClient=self.redisMessaging) - else: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] No matches for IPv6 in SDP", redisClient=self.redisMessaging) - - #Extract RTP Port - matches_rtp_port = re.search(regex_port_audio, sdp2, re.MULTILINE) - if matches_rtp_port: - sdp2_rtp_port = str(matches_rtp_port.group(1)) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 RTP Port " + str(sdp2_rtp_port), redisClient=self.redisMessaging) - - #Extract RTP Port - matches_rtp_port = re.search(regex_port_audio, sdp1, re.MULTILINE) - if matches_rtp_port: - sdp1_rtp_port = str(matches_rtp_port.group(1)) - sdp1_rtcp_port = int(sdp1_rtp_port) - 1 - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP1 RTP Port " + str(sdp1_rtp_port), redisClient=self.redisMessaging) - - - #Extract RTP Port - matches_rtp_port = re.search(regex_port_audio, sdp2, re.MULTILINE) - if matches_rtp_port: - sdp2_rtp_port = str(matches_rtp_port.group(1)) - sdp2_rtcp_port = int(sdp2_rtp_port) - 1 - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Matched SDP2 RTP Port " + str(sdp2_rtp_port), redisClient=self.redisMessaging) + sdpOffer = self.get_avp_data(avps, 524)[0] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Offer raw: {sdpOffer}", redisClient=self.redisMessaging) + sdpOffer = binascii.unhexlify(sdpOffer).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Offer decoded: {sdpOffer}", redisClient=self.redisMessaging) + sdpAnswer = self.get_avp_data(avps, 524)[1] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Answer raw: {sdpAnswer}", redisClient=self.redisMessaging) + sdpAnswer = binascii.unhexlify(sdpAnswer).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Answer decoded: {sdpAnswer}", redisClient=self.redisMessaging) + + regexIpv4 = r"IN IP4 (\d*\.\d*\.\d*\.\d*)" + regexIpv6 = r"IN IP6 ([0-9a-fA-F:]{3,39})" + regexRtp = r"m=audio (\d*)" + regexRtcp = r"a=rtcp:(\d+)" + + sdpDownlink = None + sdpUplink = None + sdpDownlinkIpv4 = '' + sdpDownlinkRtpPort = '' + sdpUplinkRtpPort = '' + + # First, work out which side the SDP Downlink is, then do the same for the SDP Uplink. 
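+        # If neither body carries a direction hint, sdpDownlink/sdpUplink remain None;
+        # Match_SDP then returns '' and the int() port-range arithmetic below raises,
+        # which the enclosing except block catches and logs.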
+        if 'downlink' in sdpOffer.lower():
+            sdpDownlink = sdpOffer
+        elif 'downlink' in sdpAnswer.lower():
+            sdpDownlink = sdpAnswer
+
+        if 'uplink' in sdpOffer.lower():
+            sdpUplink = sdpOffer
+        elif 'uplink' in sdpAnswer.lower():
+            sdpUplink = sdpAnswer
 
+        # Grab the SDP Downlink IP
+        sdpDownlinkIpv4 = self.Match_SDP(regexPattern=regexIpv4, sdpBody=sdpDownlink)
+        sdpDownlinkIpv6 = self.Match_SDP(regexPattern=regexIpv6, sdpBody=sdpDownlink)
 
-        except Exception as e:
+        # Get the RTP ports
+        sdpDownlinkRtpPort = self.Match_SDP(regexPattern=regexRtp, sdpBody=sdpDownlink)
+        sdpUplinkRtpPort = self.Match_SDP(regexPattern=regexRtp, sdpBody=sdpUplink)
 
-            self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Failed to extract SDP due to error" + str(e), redisClient=self.redisMessaging)
+        # The RTCP port is always the RTP port + 1. Comma-separated ports aren't used, due to lack of support in open-source PGWs.
+        # We take a blind approach by setting a range of +1 on both sides.
+        sdpDownlinkRtpPorts = f"{sdpDownlinkRtpPort}-{int(sdpDownlinkRtpPort)+1}"
+        sdpUplinkRtpPorts = f"{sdpUplinkRtpPort}-{int(sdpUplinkRtpPort)+1}"
+
+        except Exception as e:
+            self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_265] [AAA] Failed to extract SDP due to error: {traceback.format_exc()}", redisClient=self.redisMessaging)
 
         """
         The below logic is applied:
         1. Grab the Flow Rules and bitrates from the PCSCF in the AAR,
@@ -2964,12 +2956,14 @@ def Answer_16777236_265(self, packet_vars, avps):
                     "tft_group_id": 1,
                     "direction": 1,
                     "tft_id": 1,
-                    "tft_string": "permit out 17 from " + str(sdp2_ipv4) + "/32 " + str(sdp2_rtcp_port) + "-" + str(sdp2_rtp_port) + " to " + str(ueIp) + "/32 " + str(sdp1_rtcp_port) + "-" + str(sdp1_rtp_port)                },
+                    "tft_string": f"permit out 17 from {sdpDownlinkIpv4}/32 {sdpDownlinkRtpPorts} to {ueIp}/32 {sdpUplinkRtpPorts}"
+                },
                 {
                     "tft_group_id": 1,
                     "direction": 2,
                     "tft_id": 2,
-                    "tft_string": "permit out 17 from " + str(sdp2_ipv4) + "/32 " + str(sdp2_rtcp_port) + "-" + str(sdp2_rtp_port) + " to " + str(ueIp) + "/32 " + str(sdp1_rtcp_port) + "-" + str(sdp1_rtp_port)                }
+                    "tft_string": f"permit out 17 from {sdpDownlinkIpv4}/32 {sdpDownlinkRtpPorts} to {ueIp}/32 {sdpUplinkRtpPorts}"
+                }
                 ]
             }
 
diff --git a/services/apiService.py b/services/apiService.py
index 9bc6c48d..e371e848 100644
--- a/services/apiService.py
+++ b/services/apiService.py
@@ -1770,7 +1770,7 @@ def put(self):
                 return result, 400
 
             activeSubscribers = databaseClient.Get_Subscribers_By_Pcscf(pcscf=pcscf)
-            logTool.log(service='API', level='debug', message=f"[API] Active Subscribers for {pcscf}: {activeSubscribers}", redisClient=redisMessaging)
+            logTool.log(service='API', level='debug', message=f"[API] [pcscf_restoration] Active Subscribers for {pcscf}: {activeSubscribers}", redisClient=redisMessaging)
 
             if len(activeSubscribers) > 0:
                 for imsSubscriber in activeSubscribers:
@@ -1796,6 +1796,7 @@ def put(self):
                         )
 
                     except Exception as e:
+                        logTool.log(service='API', level='error', message=f"[API] [pcscf_restoration] Error sending CLR for subscriber: {traceback.format_exc()}", redisClient=redisMessaging)
                         continue
 
             result = {"Result": f"Successfully sent PCSCF Restoration request for PCSCF: {pcscf}"}
diff --git a/services/hssService.py b/services/hssService.py
index 08393446..cb74bdb9 100644
--- a/services/hssService.py
+++ b/services/hssService.py
@@ -32,7 +32,6 @@ def __init__(self):
         self.benchmarking = self.config.get('hss').get('enable_benchmarking', False)
         self.hostname = 
socket.gethostname() - def handleQueue(self): """ Gets and parses inbound diameter requests, processes them and queues the response.
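
A minimal, standalone sketch of the SDP-to-TFT flow introduced by PATCH 9/9, for reference.
The names match_sdp and build_tft_strings, and the sample SDP fragments, are illustrative
assumptions; the real implementation lives on the Diameter class in lib/diameter.py and
logs failures via logTool instead of raising.

import re

def match_sdp(regex_pattern: str, sdp_body: str) -> str:
    """Return the first capture group matched in the SDP body, or '' if absent."""
    try:
        sdp_match = re.search(regex_pattern, sdp_body, re.MULTILINE)
        if sdp_match and sdp_match.group(1):
            return str(sdp_match.group(1))
        return ''
    except Exception:
        return ''

def build_tft_strings(sdp_offer: str, sdp_answer: str, ue_ip: str) -> list:
    """Derive the dedicated bearer TFT permit rules from a pair of SDP bodies."""
    regex_ipv4 = r"IN IP4 (\d*\.\d*\.\d*\.\d*)"
    regex_rtp = r"m=audio (\d*)"

    # Work out which SDP body describes the downlink, and which the uplink.
    sdp_downlink = sdp_offer if 'downlink' in sdp_offer.lower() else sdp_answer
    sdp_uplink = sdp_offer if 'uplink' in sdp_offer.lower() else sdp_answer

    downlink_ip = match_sdp(regex_ipv4, sdp_downlink)
    downlink_rtp = match_sdp(regex_rtp, sdp_downlink)
    uplink_rtp = match_sdp(regex_rtp, sdp_uplink)

    # RTCP rides on RTP port + 1, so a two-port range is permitted on each side.
    downlink_ports = f"{downlink_rtp}-{int(downlink_rtp) + 1}"
    uplink_ports = f"{uplink_rtp}-{int(uplink_rtp) + 1}"

    # The patch applies the same filter string to both TFT directions (1 and 2).
    tft = f"permit out 17 from {downlink_ip}/32 {downlink_ports} to {ue_ip}/32 {uplink_ports}"
    return [tft, tft]

# Example with hypothetical SDP fragments:
#   build_tft_strings("c=IN IP4 10.45.0.1 downlink\r\nm=audio 49152 RTP/AVP 104",
#                     "c=IN IP4 10.45.0.2 uplink\r\nm=audio 50000 RTP/AVP 104",
#                     "100.64.0.1")
# -> ['permit out 17 from 10.45.0.1/32 49152-49153 to 100.64.0.1/32 50000-50001', ...]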