diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e6fdb8..4799da4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Gx RAR now dynamically creates TFT up to 512k based on UE request. - SQN Resync now propogates via Geored when enabled - Renamed sh_profile to xcap_profile in ims_subscriber +- Rebuilt keys using unique namespace for redis-sentinel / stateless compatibility. ### Fixed @@ -30,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - generateUpgade.sh for generating alembic upgrade scripts - Control of outbound roaming S6a AIR and ULA responses through roaming_rule and roaming_network objects. - Roaming management on a per-subscriber basis, through subscriber.roaming_enabled and subscriber.roaming_rule_list. +- Support for Gx and Rx auth of unknown subscribers attaching via SOS. +- Preliminary support for SCTP. ## [1.0.0] - 2023-09-27 diff --git a/config.yaml b/config.yaml index 700b0c2..13f35c3 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,7 @@ ## HSS Parameters -hss: +hss: + # Transport Type. "TCP" and "SCTP" are valid options. + # Note: SCTP works but is still experimental. TCP has been load-tested and performs in a production environment. transport: "TCP" #IP Addresses to bind on (List) - For TCP only the first IP is used, for SCTP all used for Transport (Multihomed). bind_ip: ["0.0.0.0"] @@ -70,6 +72,12 @@ hss: # Whether or not to a subscriber to connect to an undefined network when outbound roaming. 
allow_undefined_networks: True + # SCTP Socket Parameters + sctp: + rtoMax: 5000 + rtoMin: 500 + rtoInitial: 1000 + api: page_size: 200 @@ -95,8 +103,7 @@ logging: diameter_logging_file: /var/log/pyhss_diameter.log geored_logging_file: /var/log/pyhss_geored.log metric_logging_file: /var/log/pyhss_metrics.log - log_to_terminal: True - sqlalchemy_sql_echo: True + sqlalchemy_sql_echo: False sqlalchemy_pool_recycle: 15 sqlalchemy_pool_size: 30 sqlalchemy_max_overflow: 0 @@ -113,7 +120,7 @@ database: webhooks: enabled: False endpoints: - - http://127.0.0.1:8181 + - 'http://127.0.0.1:8181' ## Geographic Redundancy Parameters geored: @@ -123,13 +130,23 @@ geored: - 'http://hss01.mnc001.mcc001.3gppnetwork.org:8080' - 'http://hss02.mnc001.mcc001.3gppnetwork.org:8080' -#Redis is required to run PyHSS. A locally running instance is recommended for production. +#Redis is required to run PyHSS. An instance running on a local network is recommended for production. redis: - # Whether to use a UNIX socket instead of a tcp connection to redis. Host and port is ignored if useUnixSocket is True. - useUnixSocket: False + # Which connection type to attempt. Valid options are: tcp, unix, sentinel + # tcp - Connection via a standard TCP socket to a given host and port. + # unix - Connect to redis via a unix socket, provided by unixSocketPath. + # sentinel - Connect to one or more redis sentinel hosts. 
+ connectionType: "tcp" unixSocketPath: '/var/run/redis/redis-server.sock' host: localhost port: 6379 + sentinel: + masterName: exampleMaster + hosts: + - exampleSentinel.mnc001.mcc001.3gppnetwork.org: + port: 6379 + password: '' + prometheus: enabled: False diff --git a/lib/database.py b/lib/database.py index e338c15..7770118 100755 --- a/lib/database.py +++ b/lib/database.py @@ -18,6 +18,7 @@ from messaging import RedisMessaging import yaml import json +import socket import traceback with open("../config.yaml", 'r') as stream: @@ -99,7 +100,6 @@ class SUBSCRIBER(Base): last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') operation_logs = relationship("SUBSCRIBER_OPERATION_LOG", back_populates="subscriber") - class SUBSCRIBER_ROUTING(Base): __tablename__ = 'subscriber_routing' __table_args__ = ( @@ -114,7 +114,6 @@ class SUBSCRIBER_ROUTING(Base): last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') operation_logs = relationship("SUBSCRIBER_ROUTING_OPERATION_LOG", back_populates="subscriber_routing") - class SERVING_APN(Base): __tablename__ = 'serving_apn' serving_apn_id = Column(Integer, primary_key=True, doc='Unique ID of SERVING_APN') @@ -158,15 +157,6 @@ class IMS_SUBSCRIBER(Base): last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') operation_logs = relationship("IMS_SUBSCRIBER_OPERATION_LOG", back_populates="ims_subscriber") -class ROAMING_RULE(Base): - __tablename__ = 'roaming_rule' - roaming_rule_id = Column(Integer, primary_key = True, doc='Unique ID of ROAMING_RULE entry') - roaming_network_id = Column(Integer, ForeignKey('roaming_network.roaming_network_id', ondelete='CASCADE'), doc='ID of the roaming network to apply the rule for') - allow = Column(Boolean, default=1, doc='Whether to allow outbound roaming on the network') - enabled = Column(Boolean, 
default=1, doc='Whether the rule is enabled') - last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') - operation_logs = relationship("ROAMING_RULE_OPERATION_LOG", back_populates="roaming_rule") - class ROAMING_NETWORK(Base): __tablename__ = 'roaming_network' roaming_network_id = Column(Integer, primary_key = True, doc='Unique ID of ROAMING_NETWORK entry') @@ -177,6 +167,32 @@ class ROAMING_NETWORK(Base): last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') operation_logs = relationship("ROAMING_NETWORK_OPERATION_LOG", back_populates="roaming_network") +class EMERGENCY_SUBSCRIBER(Base): + __tablename__ = 'emergency_subscriber' + emergency_subscriber_id = Column(Integer, primary_key = True, doc='Unique ID of EMERGENCY_SUBSCRIBER entry') + imsi = Column(String(18), doc='International Mobile Subscriber Identity') + serving_pgw = Column(String(512), doc='PGW serving this subscriber') + serving_pgw_timestamp = Column(String(512), doc='Timestamp of Gx CCR') + serving_pcscf = Column(String(512), doc='PCSCF serving this subscriber') + serving_pcscf_timestamp = Column(String(512), doc='Timestamp of Rx Media AAR') + gx_origin_realm = Column(String(512), doc='Origin Realm of the Gx CCR') + gx_origin_host = Column(String(512), doc='Origin host of the Gx CCR') + rat_type = Column(String(512), doc='Radio access technology type that the emergency subscriber has used') + ip = Column(String(512), doc='IP of the emergency subscriber') + access_network_gateway_address = Column(String(512), doc='ANGW emergency that the subscriber has used') + access_network_charging_address = Column(String(512), doc='AN Charging Address that the emergency subscriber has used') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("EMERGENCY_SUBSCRIBER_OPERATION_LOG", 
back_populates="emergency_subscriber") + +class ROAMING_RULE(Base): + __tablename__ = 'roaming_rule' + roaming_rule_id = Column(Integer, primary_key = True, doc='Unique ID of ROAMING_RULE entry') + roaming_network_id = Column(Integer, ForeignKey('roaming_network.roaming_network_id', ondelete='CASCADE'), doc='ID of the roaming network to apply the rule for') + allow = Column(Boolean, default=1, doc='Whether to allow outbound roaming on the network') + enabled = Column(Boolean, default=1, doc='Whether the rule is enabled') + last_modified = Column(String(100), default=datetime.datetime.now(tz=timezone.utc), doc='Timestamp of last modification') + operation_logs = relationship("ROAMING_RULE_OPERATION_LOG", back_populates="roaming_rule") + class CHARGING_RULE(Base): __tablename__ = 'charging_rule' charging_rule_id = Column(Integer, primary_key = True, doc='Unique ID of CHARGING_RULE entry') @@ -286,6 +302,11 @@ class ROAMING_NETWORK_OPERATION_LOG(OPERATION_LOG_BASE): roaming_network = relationship("ROAMING_NETWORK", back_populates="operation_logs") roaming_network_id = Column(Integer, ForeignKey('roaming_network.roaming_network_id')) +class EMERGENCY_SUBSCRIBER_OPERATION_LOG(OPERATION_LOG_BASE): + __mapper_args__ = {'polymorphic_identity': 'emergency_subscriber'} + emergency_subscriber = relationship("EMERGENCY_SUBSCRIBER", back_populates="operation_logs") + emergency_subscriber_id = Column(Integer, ForeignKey('emergency_subscriber.emergency_subscriber_id')) + class CHARGING_RULE_OPERATION_LOG(OPERATION_LOG_BASE): __mapper_args__ = {'polymorphic_identity': 'charging_rule'} charging_rule = relationship("CHARGING_RULE", back_populates="operation_logs") @@ -333,7 +354,8 @@ def __init__(self, logTool, redisMessaging=None): db_string = 'postgresql+psycopg2://' + str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database']) else: db_string = 'mysql://' 
+ str(self.config['database']['username']) + ':' + str(self.config['database']['password']) + '@' + str(self.config['database']['server']) + '/' + str(self.config['database']['database'] + "?autocommit=true") - + + self.hostname = socket.gethostname() self.engine = create_engine( db_string, @@ -353,7 +375,7 @@ def __init__(self, logTool, redisMessaging=None): #Load IMEI TAC database into Redis if enabled if ('tac_database_csv' in self.config['eir']): self.load_IMEI_database_into_Redis() - self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase")) + self.tacData = json.loads(self.redisMessaging.getValue(key="tacDatabase", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database')) else: self.logTool.log(service='Database', level='info', message="Not loading EIR IMEI TAC Database as Redis not enabled or TAC CSV Database not set in config", redisClient=self.redisMessaging) self.tacData = {} @@ -390,7 +412,7 @@ def load_IMEI_database_into_Redis(self): if count == 0: self.logTool.log(service='Database', level='info', message="Checking to see if entries are already present...", redisClient=self.redisMessaging) - redis_imei_result = self.redisMessaging.getValue(key="tacDatabase") + redis_imei_result = self.redisMessaging.getValue(key="tacDatabase", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database') if redis_imei_result is not None: if len(redis_imei_result) > 0: self.logTool.log(service='Database', level='info', message="IMEI TAC Database already loaded into Redis - Skipping reading from file...", redisClient=self.redisMessaging) @@ -398,7 +420,7 @@ def load_IMEI_database_into_Redis(self): self.logTool.log(service='Database', level='info', message="No data loaded into Redis, proceeding to load...", redisClient=self.redisMessaging) tacList['tacList'].append({str(tacPrefix): {'name': name, 'model': model}}) count += 1 - self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList)) + 
self.redisMessaging.setValue(key="tacDatabase", value=json.dumps(tacList), usePrefix=True, prefixHostname=self.hostname, prefixServiceName='database') self.tacData = tacList self.logTool.log(service='Database', level='info', message="Loaded " + str(count) + " IMEI TAC entries into Redis", redisClient=self.redisMessaging) except Exception as E: @@ -921,14 +943,14 @@ def handleGeored(self, jsonData, operation: str="PATCH", asymmetric: bool=False, georedDict['body'] = jsonData georedDict['operation'] = operation georedDict['timestamp'] = time.time_ns() - self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'geored', message=json.dumps(georedDict), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored') if asymmetric: if len(asymmetricUrls) > 0: georedDict['body'] = jsonData georedDict['operation'] = operation georedDict['timestamp'] = time.time_ns() georedDict['urls'] = asymmetricUrls - self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'asymmetric-geored', message=json.dumps(georedDict), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored') return True except Exception as E: @@ -956,7 +978,7 @@ def handleWebhook(self, objectData, operation: str="PATCH"): webhook['headers'] = webhookHeaders webhook['operation'] = operation webhook['timestamp'] = time.time_ns() - self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120) + self.redisMessaging.sendMessage(queue=f'webhook', message=json.dumps(webhook), queueExpiry=120, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='webhook') return True def Sanitize_Datetime(self, result): @@ -1455,8 +1477,7 @@ def Get_Served_IMS_Subscribers(self, get_local_users_only=False): IMS_SUBSCRIBER.scscf.isnot(None)) for result in results: 
result = result.__dict__ - self.logTool.log(service='Database', level='debug', message="Result: " + str(result, redisClient=self.redisMessaging) + - " type: " + str(type(result))) + self.logTool.log(service='Database', level='debug', message="Result: " + str(result) + " type: " + str(type(result)), redisClient=self.redisMessaging) result = self.Sanitize_Datetime(result) result.pop('_sa_instance_state') if get_local_users_only == True: @@ -2128,6 +2149,186 @@ def Get_IMS_Subscriber_By_Session_Id(self, sessionId): result = self.Sanitize_Datetime(result) return result + def Get_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, **kwargs) -> dict: + self.logTool.log(service='Database', level='debug', message=f"Getting Emergency_Subscriber", redisClient=self.redisMessaging) + Session = sessionmaker(bind = self.engine) + session = Session() + + result = None + + try: + while not result: + if imsi and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging) + break + if emergencySubscriberId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging) + break + if subscriberIp and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging) + break + if gxSessionId 
and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Get_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging) + break + if rxSessionId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging) + break + break + + if not result: + return None + result = result.__dict__ + result.pop('_sa_instance_state') + self.safe_close(session) + return result + + except Exception as E: + self.logTool.log(service='Database', level='error', message=f"[database.py] [Get_Emergency_Subscriber] Error getting emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging) + self.safe_close(session) + return None + + def Update_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, subscriberData: dict={}, propagate: bool=True) -> dict: + """ + First, get at most one emergency subscriber. + Try and match on IMSI first (To detect an updated IP for an existing record), + If IMSI is None or no result was found, then try with a combination of all of the arguments. + Then update all data with the provided subscriberData, and push to geored. 
+ """ + Session = sessionmaker(bind = self.engine) + session = Session() + + result = None + + while not result: + if imsi and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging) + break + if emergencySubscriberId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging) + break + if subscriberIp and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging) + break + if gxSessionId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging) + break + if rxSessionId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging) + break + break + + + """ + If we havent matched in on any entries at this point, create a new emergency subscriber. 
+ """ + if not result: + result = EMERGENCY_SUBSCRIBER() + session.add(result) + + result.imsi = subscriberData.get('imsi') + result.serving_pgw = subscriberData.get('servingPgw') + result.serving_pgw_timestamp = subscriberData.get('requestTime') + result.serving_pcscf = subscriberData.get('servingPcscf') + result.serving_pcscf_timestamp = subscriberData.get('aarRequestTime') + result.gx_origin_realm = subscriberData.get('gxOriginRealm') + result.gx_origin_host = subscriberData.get('gxOriginHost') + result.rat_type = subscriberData.get('ratType') + result.ip = subscriberData.get('ip') + result.access_network_gateway_address = subscriberData.get('accessNetworkGatewayAddress') + result.access_network_charging_address = subscriberData.get('accessNetworkChargingAddress') + + try: + session.commit() + emergencySubscriberId = result.emergency_subscriber_id + if propagate: + self.handleGeored({ "emergency_subscriber_id": int(emergencySubscriberId), + "emergency_subscriber_imsi": subscriberData.get('imsi'), + "emergency_subscriber_serving_pgw": subscriberData.get('servingPgw'), + "emergency_subscriber_serving_pgw_timestamp": subscriberData.get('requestTime'), + "emergency_subscriber_serving_pcscf": subscriberData.get('servingPcscf'), + "emergency_subscriber_serving_pcscf_timestamp": subscriberData.get('aarRequestTime'), + "emergency_subscriber_gx_origin_realm": subscriberData.get('gxOriginRealm'), + "emergency_subscriber_gx_origin_host": subscriberData.get('gxOriginHost'), + "emergency_subscriber_rat_type": subscriberData.get('ratType'), + "emergency_subscriber_ip": subscriberData.get('ip'), + "emergency_subscriber_access_network_gateway_address": subscriberData.get('accessNetworkGatewayAddress'), + "emergency_subscriber_access_network_charging_address": subscriberData.get('accessNetworkChargingAddress'), + }) + + except Exception as E: + self.safe_close(session) + self.logTool.log(service='Database', level='error', message=f"[database.py] [Update_Emergency_Subscriber] 
Error updating emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging) + return None + result = result.__dict__ + result.pop('_sa_instance_state') + self.safe_close(session) + return result + + def Delete_Emergency_Subscriber(self, emergencySubscriberId: int=None, subscriberIp: str=None, gxSessionId: str=None, rxSessionId: str=None, imsi: str=None, subscriberData: dict={}, propagate: bool=True) -> bool: + """ + First, get at most one emergency subscriber matching the provided identifiers. + Then delete the emergency subscriber, and push to geored. + """ + Session = sessionmaker(bind = self.engine) + session = Session() + + result = None + + while not result: + if imsi and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(imsi=imsi).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IMSI: {imsi}", redisClient=self.redisMessaging) + break + if emergencySubscriberId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(emergency_subscriber_id=emergencySubscriberId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on emergency_subscriber_id: {emergencySubscriberId}", redisClient=self.redisMessaging) + break + if subscriberIp and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(ip=subscriberIp).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on IP: {subscriberIp}", redisClient=self.redisMessaging) + break + if gxSessionId and not result: + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Gx Session ID: {gxSessionId}", redisClient=self.redisMessaging) + result = 
session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pgw=gxSessionId).first() + break + if rxSessionId and not result: + result = session.query(EMERGENCY_SUBSCRIBER).filter_by(serving_pcscf=rxSessionId).first() + self.logTool.log(service='Database', level='debug', message=f"[database.py] [Update_Emergency_Subscriber] Matched emergency subscriber on Rx Session ID: {rxSessionId}", redisClient=self.redisMessaging) + break + break + + if not result: + return True + + try: + emergencySubscriberId = result.emergency_subscriber_id + session.delete(result) + session.commit() + result = result.__dict__ + if propagate: + self.handleGeored({ + "emergency_subscriber_imsi": result.get('imsi'), + "emergency_subscriber_ip": result.get('ip'), + "emergency_subscriber_delete": True, + }) + self.safe_close(session) + return True + except Exception as E: + self.safe_close(session) + self.logTool.log(service='Database', level='error', message=f"[database.py] [Delete_Emergency_Subscriber] Error deleting emergency subscriber: {traceback.format_exc()}", redisClient=self.redisMessaging) + return False + + def Store_IMSI_IMEI_Binding(self, imsi, imei, match_response_code, propagate=True): #IMSI 14-15 Digits #IMEI 15 Digits diff --git a/lib/diameter.py b/lib/diameter.py index 3df733c..8281eb5 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -9,10 +9,13 @@ import jinja2 from database import Database from messaging import RedisMessaging +from redis import Redis import yaml import json import time +import socket import traceback +import re class Diameter: @@ -31,10 +34,13 @@ def __init__(self, logTool, originHost: str="hss01", originRealm: str="epc.mnc99 self.redisUnixSocketPath = self.config.get('redis', {}).get('unixSocketPath', '/var/run/redis/redis-server.sock') self.redisHost = self.config.get('redis', {}).get('host', 'localhost') self.redisPort = self.config.get('redis', {}).get('port', 6379) + self.redisAdditionalPeers = self.config.get('redis', {}).get('additionalPeers', []) if 
redisMessaging: self.redisMessaging = redisMessaging else: self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + + self.hostname = socket.gethostname() self.database = Database(logTool=logTool) self.diameterRequestTimeout = int(self.config.get('hss', {}).get('diameter_request_timeout', 10)) @@ -150,24 +156,37 @@ def Reverse(self, str): return (slicedString) def DecodePLMN(self, plmn): - self.logTool.log(service='HSS', level='debug', message="Decoded PLMN: " + str(plmn), redisClient=self.redisMessaging) - mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') + self.logTool.log(service='HSS', level='debug', message="Decoding PLMN: " + str(plmn), redisClient=self.redisMessaging) + if "f" in plmn: + mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4]).replace('f', '') + mnc = self.Reverse(plmn[4:6]) + else: + mcc = self.Reverse(plmn[0:2]) + self.Reverse(plmn[2:4][1]) + mnc = self.Reverse(plmn[4:6]) + str(self.Reverse(plmn[2:4][0])) self.logTool.log(service='HSS', level='debug', message="Decoded MCC: " + mcc, redisClient=self.redisMessaging) - - mnc = self.Reverse(plmn[4:6]) self.logTool.log(service='HSS', level='debug', message="Decoded MNC: " + mnc, redisClient=self.redisMessaging) return mcc, mnc - + def EncodePLMN(self, mcc, mnc): plmn = list('XXXXXX') - plmn[0] = self.Reverse(mcc)[1] - plmn[1] = self.Reverse(mcc)[2] - plmn[2] = "f" - plmn[3] = self.Reverse(mcc)[0] - plmn[4] = self.Reverse(mnc)[0] - plmn[5] = self.Reverse(mnc)[1] - plmn_list = plmn - plmn = '' + if len(mnc) == 2: + plmn[0] = self.Reverse(mcc)[1] + plmn[1] = self.Reverse(mcc)[2] + plmn[2] = "f" + plmn[3] = self.Reverse(mcc)[0] + plmn[4] = self.Reverse(mnc)[0] + plmn[5] = self.Reverse(mnc)[1] + plmn_list = plmn + plmn = '' + else: + plmn[0] = self.Reverse(mcc)[1] + plmn[1] = self.Reverse(mcc)[2] + plmn[2] = self.Reverse(mnc)[0] + plmn[3] = self.Reverse(mcc)[0] + plmn[4] = 
self.Reverse(mnc)[1] + plmn[5] = self.Reverse(mnc)[2] + plmn_list = plmn + plmn = '' for bits in plmn_list: plmn = plmn + bits self.logTool.log(service='HSS', level='debug', message="Encoded PLMN: " + str(plmn), redisClient=self.redisMessaging) @@ -543,7 +562,7 @@ def getConnectedPeersByType(self, peerType: str) -> list: if peerType not in peerTypes: return [] filteredConnectedPeers = [] - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter').decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('peerType', '') == peerType and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': @@ -555,15 +574,17 @@ def getConnectedPeersByType(self, peerType: str) -> list: return [] def getPeerByHostname(self, hostname: str) -> dict: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getPeerByHostname] Looking for peer with hostname {hostname}", redisClient=self.redisMessaging) try: hostname = hostname.lower() - activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers").decode()) + activePeers = json.loads(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter').decode()) for key, value in activePeers.items(): if activePeers.get(key, {}).get('diameterHostname', '').lower() == hostname and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': return(activePeers.get(key, {})) except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getPeerByHostname] Failed to find peer with hostname {hostname}", redisClient=self.redisMessaging) return {} def getDiameterMessageType(self, binaryData: str) -> dict: @@ -608,12 +629,16 @@ def sendDiameterRequest(self, requestType: str, hostname: 
str, **kwargs) -> str: peerPort = connectedPeer['port'] except Exception as e: return '' - request = diameterApplication["requestMethod"](**kwargs) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [sendDiameterRequest] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return request except Exception as e: @@ -641,12 +666,17 @@ def broadcastDiameterRequest(self, requestType: str, peerType: str, **kwargs) -> peerPort = connectedPeer['port'] except Exception as e: return '' - request = diameterApplication["requestMethod"](**kwargs) + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated 
request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" sendTime = time.time_ns() outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [broadcastDiameterRequest] [{requestType}] Queueing for peer type: {peerType} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) return connectedPeerList except Exception as e: @@ -678,26 +708,34 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo except Exception as e: continue connectedPeer = self.getPeerByHostname(hostname=hostname) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Sending request via connected peer {connectedPeer} from hostname {hostname}", redisClient=self.redisMessaging) try: peerIp = connectedPeer['ipAddress'] peerPort = connectedPeer['port'] except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Could not get connection information for connectedPeer: {connectedPeer}", redisClient=self.redisMessaging) + 
return '' + + try: + request = diameterApplication["requestMethod"](**kwargs) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Error generating request: {traceback.format_exc()}", redisClient=self.redisMessaging) return '' - request = diameterApplication["requestMethod"](**kwargs) responseType = diameterApplication["responseAcronym"] sessionId = kwargs.get('sessionId', None) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Successfully generated request: {request}", redisClient=self.redisMessaging) sendTime = time.time_ns() outboundQueue = f"diameter-outbound-{peerIp}-{peerPort}" outboundMessage = json.dumps({"diameter-outbound": request, "inbound-received-timestamp": sendTime}) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] Queueing for host: {hostname} on {peerIp}-{peerPort}", redisClient=self.redisMessaging) startTimer = time.time() while True: try: if not time.time() >= startTimer + timeout: if sessionId is None: - queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound") + queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] 
[awaitDiameterRequestAndResponse] [{requestType}] queuedMessages(NoSessionId): {queuedMessages}", redisClient=self.redisMessaging) for queuedMessage in queuedMessages: queuedMessage = json.loads(queuedMessage) @@ -714,7 +752,7 @@ def awaitDiameterRequestAndResponse(self, requestType: str, hostname: str, timeo return messageHex time.sleep(0.02) else: - queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound") + queuedMessages = self.redisMessaging.getList(key=f"diameter-inbound", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [awaitDiameterRequestAndResponse] [{requestType}] queuedMessages({sessionId}): {queuedMessages} responseType: {responseType}", redisClient=self.redisMessaging) for queuedMessage in queuedMessages: queuedMessage = json.loads(queuedMessage) @@ -764,8 +802,12 @@ def generateDiameterResponse(self, binaryData: str) -> str: if 'flags' in diameterApplication: assert(str(packet_vars["flags"]) == str(diameterApplication["flags"])) self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Attempting to generate response", redisClient=self.redisMessaging) - response = diameterApplication["responseMethod"](packet_vars, avps) - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging) + try: + response = diameterApplication["responseMethod"](packet_vars, avps) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Successfully generated response: {response}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] 
[generateDiameterResponse] [{diameterApplication.get('requestAcronym', '')}] Error generating response: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' break except Exception as e: continue @@ -990,7 +1032,7 @@ def validateOutboundRoamingNetwork(self, assignedRoamingRules: str, mcc: str, mn return True else: return False - + def validateSubscriberRoaming(self, subscriber: dict, mcc: str, mnc: str) -> bool: """ Ensures that a given subscriber is allowed to roam to the provided PLMN. @@ -1009,7 +1051,6 @@ def validateSubscriberRoaming(self, subscriber: dict, mcc: str, mnc: str) -> boo return True - def AVP_278_Origin_State_Incriment(self, avps): #Capabilities Exchange Answer incriment AVP body for avp_dicts in avps: if avp_dicts['avp_code'] == 278: @@ -1018,6 +1059,23 @@ def AVP_278_Origin_State_Incriment(self, avps): origin_state_incriment_hex = format(origin_state_incriment_int,"x").zfill(8) return origin_state_incriment_hex + def Match_SDP(self, regexPattern, sdpBody): + """ + Matches a given regex in a given SDP body. + Returns the result, or and empty string if not found. 
+ """ + + try: + sdpMatch = re.search(regexPattern, sdpBody, re.MULTILINE) + if sdpMatch: + sdpResult = sdpMatch.group(1) + if sdpResult: + return str(sdpResult) + return '' + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Match_SDP] Error matching SDP: {traceback.format_exc()}", redisClient=self.redisMessaging) + return '' + def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleName=None, action="install"): self.logTool.log(service='HSS', level='debug', message=f"Called Charging_Rule_Generator with action: {action}", redisClient=self.redisMessaging) if action not in ['install', 'remove']: @@ -1045,7 +1103,7 @@ def Charging_Rule_Generator(self, ChargingRules=None, ue_ip=None, chargingRuleNa #Populate all Flow Information AVPs Flow_Information = '' for tft in ChargingRules['tft']: - self.logTool.log(service='HSS', level='debug', message=tft, redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message="Adding TFT: " + str(tft), redisClient=self.redisMessaging) #If {{ UE_IP }} in TFT splice in the real UE IP Value try: tft['tft_string'] = tft['tft_string'].replace('{{ UE_IP }}', str(ue_ip)) @@ -1197,7 +1255,6 @@ def Answer_257(self, packet_vars, avps): #Device Watchdog Answer def Answer_280(self, packet_vars, avps): - avp = '' #Initiate empty var AVP avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host @@ -1207,8 +1264,6 @@ def Answer_280(self, packet_vars, avps): avp += self.generate_avp(278, 40, self.AVP_278_Origin_State_Incriment(avps)) #Origin State (Has to be incrimented (Handled by AVP_278_Origin_State_Incriment)) response = self.generate_diameter_packet("01", "00", 280, 0, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet self.logTool.log(service='HSS', level='debug', message="Successfully 
Generated DWA", redisClient=self.redisMessaging) - orignHost = self.get_avp_data(avps, 264)[0] #Get OriginHost from AVP - orignHost = binascii.unhexlify(orignHost).decode('utf-8') #Format it return response #Disconnect Peer Answer @@ -1282,14 +1337,16 @@ def Answer_16777251_316(self, packet_vars, avps): decodedPlmn = self.DecodePLMN(plmn=plmn) mcc = decodedPlmn[0] mnc = decodedPlmn[1] + subscriberIsRoaming = False + subscriberRoamingAllowed = False if str(mcc) != str(self.MCC) and str(mnc) != str(self.MNC): subscriberIsRoaming = True if subscriberIsRoaming: - self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777251_318] [AIA] Subscriber {imsi} is roaming", redisClient=self.redisMessaging) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777251_318] [ULA] Subscriber {imsi} is roaming", redisClient=self.redisMessaging) subscriberRoamingAllowed = self.validateSubscriberRoaming(subscriber=subscriber_details, mcc=mcc, mnc=mnc) - if not subscriberRoamingAllowed: + if not subscriberRoamingAllowed and subscriberIsRoaming: avp = '' session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set @@ -1608,6 +1665,8 @@ def Answer_16777251_318(self, packet_vars, avps): decodedPlmn = self.DecodePLMN(plmn=plmn) mcc = decodedPlmn[0] mnc = decodedPlmn[1] + subscriberIsRoaming = False + subscriberRoamingAllowed = False if str(mcc) != str(self.MCC) and str(mnc) != str(self.MNC): subscriberIsRoaming = True @@ -1615,7 +1674,7 @@ def Answer_16777251_318(self, packet_vars, avps): self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777251_318] [AIA] Subscriber {imsi} is roaming", redisClient=self.redisMessaging) subscriberRoamingAllowed = self.validateSubscriberRoaming(subscriber=subscriber_details, mcc=mcc, mnc=mnc) - if not subscriberRoamingAllowed: + if not subscriberRoamingAllowed and subscriberIsRoaming: avp = '' session_id = 
self.get_avp_data(avps, 263)[0] #Get Session-ID avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set @@ -1801,10 +1860,171 @@ def Answer_16777238_272(self, packet_vars, avps): session_id = self.get_avp_data(avps, 263)[0] #Get Session-ID self.logTool.log(service='HSS', level='debug', message="[diameter.py] [Answer_16777238_272] [CCA] Session Id is " + str(binascii.unhexlify(session_id).decode()), redisClient=self.redisMessaging) avp += self.generate_avp(263, 40, session_id) #Session-ID AVP set + avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host + avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(258, 40, "01000016") #Auth-Application-Id (3GPP Gx 16777238) avp += self.generate_avp(416, 40, format(int(CC_Request_Type),"x").zfill(8)) #CC-Request-Type avp += self.generate_avp(415, 40, format(int(CC_Request_Number),"x").zfill(8)) #CC-Request-Number + """ + If Called-Station-ID contains 'sos', we're dealing with an emergency bearer request. + Authentication is bypassed and we'll return a basic QOS profile. + """ + try: + if apn.lower() == 'sos': + if int(CC_Request_Type) == 1: + """ + If we've recieved a CCR-Initial, create an emergency subscriber. + """ + # Use our defined SOS APN AMBR, if defined. + # Otherwise, use a default value of 128/128kbps. 
+ try: + sosApn = (self.database.Get_APN_by_Name(apn="sos")) + AMBR = '' #Initiate empty var AVP for AMBR + apn_ambr_ul = int(sosApn['apn_ambr_ul']) + apn_ambr_dl = int(sosApn['apn_ambr_dl']) + AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL + AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL + APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) + + AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(int(sosApn['arp_priority']), 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_capability']), 4)) + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(int(not sosApn['arp_preemption_vulnerability']), 4)) + AVP_ARP = self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) + AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(int(sosApn['qci']), 4)) + avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) + + except Exception as e: + AMBR = '' #Initiate empty var AVP for AMBR + apn_ambr_ul = 128000 + apn_ambr_dl = 128000 + AMBR += self.generate_vendor_avp(516, "c0", 10415, self.int_to_hex(apn_ambr_ul, 4)) #Max-Requested-Bandwidth-UL + AMBR += self.generate_vendor_avp(515, "c0", 10415, self.int_to_hex(apn_ambr_dl, 4)) #Max-Requested-Bandwidth-DL + APN_AMBR = self.generate_vendor_avp(1435, "c0", 10415, AMBR) + + AVP_Priority_Level = self.generate_vendor_avp(1046, "80", 10415, self.int_to_hex(1, 4)) + AVP_Preemption_Capability = self.generate_vendor_avp(1047, "80", 10415, self.int_to_hex(0, 4)) # Pre-Emption Capability Enabled + AVP_Preemption_Vulnerability = self.generate_vendor_avp(1048, "80", 10415, self.int_to_hex(1, 4)) # Pre-Emption Vulnerability Disabled + AVP_ARP = 
self.generate_vendor_avp(1034, "80", 10415, AVP_Priority_Level + AVP_Preemption_Capability + AVP_Preemption_Vulnerability) + AVP_QoS = self.generate_vendor_avp(1028, "c0", 10415, self.int_to_hex(5, 4)) # QCI 5 + avp += self.generate_vendor_avp(1049, "80", 10415, AVP_QoS + AVP_ARP) + + QoS_Information = self.generate_vendor_avp(1041, "80", 10415, self.int_to_hex(apn_ambr_ul, 4)) + QoS_Information += self.generate_vendor_avp(1040, "80", 10415, self.int_to_hex(apn_ambr_dl, 4)) + avp += self.generate_vendor_avp(1016, "80", 10415, QoS_Information) # QOS-Information + + #Supported-Features(628) (Gx feature list) + avp += self.generate_vendor_avp(628, "80", 10415, "0000010a4000000c000028af0000027580000010000028af000000010000027680000010000028af0000000b") + + """ + Store the Emergency Subscriber + """ + ueIp = self.get_avp_data(avps, 8)[0] + ueIp = str(self.hex_to_ip(ueIp)) + try: + #Get the IMSI + for SubscriptionIdentifier in self.get_avp_data(avps, 443): + for UniqueSubscriptionIdentifier in SubscriptionIdentifier: + if UniqueSubscriptionIdentifier['avp_code'] == 444: + imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') + except Exception as e: + imsi="Unknown" + + try: + ratType = self.get_avp_data(avps, 1032)[0] + ratType = int(ratType, 16) + except Exception as e: + ratType = None + + try: + accessNetworkGatewayAddress = self.get_avp_data(avps, 1050)[0] + accessNetworkGatewayAddress = str(self.hex_to_ip(accessNetworkGatewayAddress[4:])) + except Exception as e: + accessNetworkGatewayAddress = None + + try: + accessNetworkChargingAddress = self.get_avp_data(avps, 501)[0] + accessNetworkChargingAddress = str(self.hex_to_ip(accessNetworkChargingAddress[4:])) + except Exception as e: + accessNetworkChargingAddress = None + + emergencySubscriberData = { + "servingPgw": binascii.unhexlify(session_id).decode(), + "requestTime": int(time.time()), + "servingPcscf": None, + "aarRequestTime": None, + "gxOriginRealm": OriginRealm, + 
"gxOriginHost": OriginHost, + "imsi": imsi, + "ip": ueIp, + "ratType": ratType, + "accessNetworkGatewayAddress": accessNetworkGatewayAddress, + "accessNetworkChargingAddress": accessNetworkChargingAddress, + } + + self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw')) + + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + + elif int(CC_Request_Type) == 3: + """ + If we've recieved a CCR-Terminate, delete the emergency subscriber. + """ + try: + ueIp = self.get_avp_data(avps, 8)[0] + ueIp = str(self.hex_to_ip(ueIp)) + except Exception as e: + ueIp = None + try: + #Get the IMSI + for SubscriptionIdentifier in self.get_avp_data(avps, 443): + for UniqueSubscriptionIdentifier in SubscriptionIdentifier: + if UniqueSubscriptionIdentifier['avp_code'] == 444: + imsi = binascii.unhexlify(UniqueSubscriptionIdentifier['misc_data']).decode('utf-8') + except Exception as e: + imsi="Unknown" + + try: + ratType = self.get_avp_data(avps, 1032)[0] + ratType = int(ratType, 16) + except Exception as e: + ratType = None + + try: + accessNetworkGatewayAddress = self.get_avp_data(avps, 1050)[0] + accessNetworkGatewayAddress = str(self.hex_to_ip(accessNetworkGatewayAddress)) + except Exception as e: + accessNetworkGatewayAddress = None + + try: + accessNetworkChargingAddress = self.get_avp_data(avps, 501)[0] + accessNetworkChargingAddress = str(self.hex_to_ip(accessNetworkChargingAddress)) + except Exception as e: + accessNetworkChargingAddress = None + + emergencySubscriberData = { + "servingPgw": binascii.unhexlify(session_id).decode(), + "requestTime": int(time.time()), + "gxOriginRealm": OriginRealm, + "gxOriginHost": OriginHost, + 
"imsi": imsi, + "ip": ueIp, + "ratType": ratType, + "accessNetworkGatewayAddress": accessNetworkGatewayAddress, + "accessNetworkChargingAddress": accessNetworkChargingAddress, + } + + self.database.Delete_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, imsi=imsi, gxSessionId=binascii.unhexlify(session_id).decode()) + + avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) + response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet + return response + + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777238_272] [CCA] Error generating SOS CCA: {traceback.format_exc()}", redisClient=self.redisMessaging) + #Get Subscriber info from Subscription ID for SubscriptionIdentifier in self.get_avp_data(avps, 443): for UniqueSubscriptionIdentifier in SubscriptionIdentifier: @@ -1936,8 +2156,6 @@ def Answer_16777238_272(self, packet_vars, avps): except Exception as e: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777238_272] [CCA] Failed to clear apn state for {apn}: {traceback.format_exc()}", redisClient=self.redisMessaging) - avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet except Exception as e: #Get subscriber details @@ -1954,8 +2172,6 @@ def Answer_16777238_272(self, packet_vars, avps): "imsi_prefix": str(imsi[0:6])}, metricHelp='Diameter Authentication related Counters', metricExpiry=60) - avp += self.generate_avp(264, 40, self.OriginHost) 
#Origin Host - avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm avp += self.generate_avp(268, 40, self.int_to_hex(5030, 4)) #Result Code (DIAMETER ERROR - User Unknown) response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet return response @@ -2528,6 +2744,24 @@ def Answer_16777236_265(self, packet_vars, avps): imsi = None msisdn = None identifier = None + emergencySubscriber = False + + try: + ueIp = self.get_avp_data(avps, 8)[0] + ueIp = str(self.hex_to_ip(ueIp)) + except Exception as e: + ueIp = None + + """ + Determine if the AAR for the IP belongs to an inbound roaming emergency subscriber. + """ + try: + emergencySubscriberData = self.database.Get_Emergency_Subscriber(subscriberIp=ueIp) + if emergencySubscriberData: + emergencySubscriber = True + except Exception as e: + emergencySubscriberData = None + if '@' in subscriptionId: subscriberIdentifier = subscriptionId.split('@')[0] # Subscriber Identifier can be either an IMSI or an MSISDN @@ -2545,59 +2779,89 @@ def Answer_16777236_265(self, packet_vars, avps): msisdn = imsSubscriberDetails.get('msisdn', None) except Exception as e: pass + if identifier == None: + try: + ueIP = subscriptionId.split('@')[1].split(':')[0] + ue = self.database.Get_UE_by_IP(ueIP) + subscriberId = ue.get('subscriber_id', None) + subscriberDetails = self.database.Get_Subscriber(subscriber_id=subscriberId) + imsi = subscriberDetails.get('imsi', None) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Found IMSI {imsi} by IP: {ueIP}", redisClient=self.redisMessaging) + except Exception as e: + pass else: imsi = None msisdn = None + try: + ueIP = subscriptionId.split(':')[0] + ue = self.database.Get_UE_by_IP(ueIP) + subscriberId = ue.get('subscriber_id', None) + subscriberDetails = self.database.Get_Subscriber(subscriber_id=subscriberId) + imsi = 
subscriberDetails.get('imsi', None) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Found IMSI {imsi} by IP: {ueIP}", redisClient=self.redisMessaging) + except Exception as e: + pass + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] IMSI: {imsi}\nMSISDN: {msisdn}", redisClient=self.redisMessaging) imsEnabled = self.validateImsSubscriber(imsi=imsi, msisdn=msisdn) - if imsEnabled: + if imsEnabled or emergencySubscriber: """ Add the PCSCF to the IMS_Subscriber object, and set the result code to 2001. """ + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Request authorized", redisClient=self.redisMessaging) - if imsi is None: - imsi = subscriberDetails.get('imsi', None) + if imsEnabled and not emergencySubscriber: + + if imsi is None: + imsi = subscriberDetails.get('imsi', None) + + aarOriginHost = self.get_avp_data(avps, 264)[0] + aarOriginHost = bytes.fromhex(aarOriginHost).decode('ascii') + aarOriginRealm = self.get_avp_data(avps, 296)[0] + aarOriginRealm = bytes.fromhex(aarOriginRealm).decode('ascii') + #Check if we have a record-route set as that's where we'll need to send the response + try: + #Get first record-route header, then parse it + remotePeer = self.get_avp_data(avps, 282)[-1] + remotePeer = binascii.unhexlify(remotePeer).decode('utf-8') + except Exception as e: + #If we don't have a record-route set, we'll send the response to the OriginHost + remotePeer = aarOriginHost - aarOriginHost = self.get_avp_data(avps, 264)[0] - aarOriginHost = bytes.fromhex(aarOriginHost).decode('ascii') - aarOriginRealm = self.get_avp_data(avps, 296)[0] - aarOriginRealm = bytes.fromhex(aarOriginRealm).decode('ascii') - #Check if we have a record-route set as that's where we'll need to send the response - try: - #Get first record-route header, then parse it - remotePeer = self.get_avp_data(avps, 282)[-1] - remotePeer = 
binascii.unhexlify(remotePeer).decode('utf-8') - except Exception as e: - #If we don't have a record-route set, we'll send the response to the OriginHost - remotePeer = aarOriginHost - - remotePeer = f"{remotePeer};{self.config['hss']['OriginHost']}" + remotePeer = f"{remotePeer};{self.config['hss']['OriginHost']}" - self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=None) - """ - Check for AVP's 504 (AF-Application-Identifier) and 520 (Media-Type), which indicates the UE is making a call. - Media-Type: 0 = Audio, 4 = Control - """ + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=None) + """ + Check for AVP's 504 (AF-Application-Identifier) and 520 (Media-Type), which indicates the UE is making a call. + Media-Type: 0 = Audio, 4 = Control + """ try: afApplicationIdentifier = self.get_avp_data(avps, 504)[0] mediaType = self.get_avp_data(avps, 520)[0] assert(bytes.fromhex(afApplicationIdentifier).decode('ascii') == "IMS Services") assert(int(mediaType, 16) == 0) - # At this point, we know the AAR is indicating a call setup, so we'll send get the serving pgw information, then send a + # At this point, we know the AAR is indicating a call setup, so we'll get the serving pgw information, then send a # RAR to the PGW over Gx, asking it to setup the dedicated bearer. 
try: - subscriberId = subscriberDetails.get('subscriber_id', None) - apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) - servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) - servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0] - servingPgw = servingApn.get('serving_pgw', None) - servingPgwRealm = servingApn.get('serving_pgw_realm', None) - pcrfSessionId = servingApn.get('pcrf_session_id', None) - ueIp = servingApn.get('subscriber_routing', None) + if emergencySubscriber and not imsEnabled: + servingPgwPeer = emergencySubscriberData.get('serving_pgw', None).split(';')[0] + pcrfSessionId = emergencySubscriberData.get('serving_pgw', None) + servingPgwRealm = emergencySubscriberData.get('gx_origin_realm', None) + servingPgw = emergencySubscriberData.get('serving_pgw', None).split(';')[0] + else: + subscriberId = subscriberDetails.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) + servingPgwPeer = servingApn.get('serving_pgw_peer', None).split(';')[0] + servingPgw = servingApn.get('serving_pgw', None) + servingPgwRealm = servingApn.get('serving_pgw_realm', None) + pcrfSessionId = servingApn.get('pcrf_session_id', None) + if not ueIp: + ueIp = servingApn.get('subscriber_routing', None) ulBandwidth = 512000 dlBandwidth = 512000 @@ -2614,6 +2878,56 @@ def Answer_16777236_265(self, packet_vars, avps): except Exception as e: pass + # Extract the SDP for both Uplink and Downlink, to create TFTs. 
+ try: + sdpOffer = self.get_avp_data(avps, 524)[0] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Offer raw: {sdpOffer}", redisClient=self.redisMessaging) + sdpOffer = binascii.unhexlify(sdpOffer).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Offer decoded: {sdpOffer}", redisClient=self.redisMessaging) + sdpAnswer = self.get_avp_data(avps, 524)[1] + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Answer raw: {sdpAnswer}", redisClient=self.redisMessaging) + sdpAnswer = binascii.unhexlify(sdpAnswer).decode('utf-8') + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Got SDP Answer decoded: {sdpAnswer}", redisClient=self.redisMessaging) + + regexIpv4 = r"IN IP4 (\d*\.\d*\.\d*\.\d*)" + regexIpv6 = r"IN IP6 ([0-9a-fA-F:]{3,39})" + regexRtp = r"m=audio (\d*)" + regexRtcp = r"a=rtcp:(\d+)" + + sdpDownlink = None + sdpUplink = None + sdpDownlinkIpv4 = '' + sdpDownlinkRtpPort = '' + sdpUplinkRtpPort = '' + + # First, work out which side the SDP Downlink is, then do the same for the SDP Uplink. + if 'downlink' in sdpOffer.lower(): + sdpDownlink = sdpOffer + elif 'downlink' in sdpAnswer.lower(): + sdpDownlink = sdpAnswer + + if 'uplink' in sdpOffer.lower(): + sdpUplink = sdpOffer + elif 'uplink' in sdpAnswer.lower(): + sdpUplink = sdpAnswer + + # Grab the SDP Downlink IP + sdpDownlinkIpv4 = self.Match_SDP(regexPattern=regexIpv4, sdpBody=sdpDownlink) + sdpDownlinkIpv6 = self.Match_SDP(regexPattern=regexIpv6, sdpBody=sdpDownlink) + + # Get the RTP ports + sdpDownlinkRtpPort = self.Match_SDP(regexPattern=regexRtp, sdpBody=sdpDownlink) + sdpUplinkRtpPort = self.Match_SDP(regexPattern=regexRtp, sdpBody=sdpUplink) + + # The RTCP Port is always the RTP port + 1. Comma separated ports arent used due to lack of support in open source PGWs. 
+ # We take a blind approach by setting a range of +1 on both sides. + sdpDownlinkRtpPorts = f"{sdpDownlinkRtpPort}-{int(sdpDownlinkRtpPort)+1}" + sdpUplinkRtpPorts = f"{sdpUplinkRtpPort}-{int(sdpUplinkRtpPort)+1}" + + + except Exception as e: + self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [Answer_16777236_265] [AAA] Failed to extract SDP due to error: {traceback.format_exc()}", redisClient=self.redisMessaging) + """ The below logic is applied: 1. Grab the Flow Rules and bitrates from the PCSCF in the AAR, @@ -2626,14 +2940,14 @@ def Answer_16777236_265(self, packet_vars, avps): chargingRule = { "charging_rule_id": 1000, "qci": 1, - "arp_preemption_capability": True, + "arp_preemption_capability": False, "mbr_dl": dlBandwidth, "mbr_ul": ulBandwidth, "gbr_ul": ulBandwidth, - "precedence": 100, - "arp_priority": 2, + "precedence": 40, + "arp_priority": 15, "rule_name": "GBR-Voice", - "arp_preemption_vulnerability": False, + "arp_preemption_vulnerability": True, "gbr_dl": dlBandwidth, "tft_group_id": 1, "rating_group": None, @@ -2642,19 +2956,37 @@ def Answer_16777236_265(self, packet_vars, avps): "tft_group_id": 1, "direction": 1, "tft_id": 1, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" + "tft_string": f"permit out 17 from {sdpDownlinkIpv4}/32 {sdpDownlinkRtpPorts} to {ueIp}/32 {sdpUplinkRtpPorts}" }, { "tft_group_id": 1, "direction": 2, "tft_id": 2, - "tft_string": "permit out 17 from {{ UE_IP }}/32 1-65535 to any 1-65535" + "tft_string": f"permit out 17 from {sdpDownlinkIpv4}/32 {sdpDownlinkRtpPorts} to {ueIp}/32 {sdpUplinkRtpPorts}" } ] } - self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId) + if not emergencySubscriber: + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=aarOriginHost, pcscf_realm=aarOriginRealm, pcscf_peer=remotePeer, pcscf_active_session=sessionId) + else: + 
updatedEmergencySubscriberData = { + "servingPgw": emergencySubscriberData.get('serving_pgw'), + "requestTime": emergencySubscriberData.get('serving_pgw_timestamp'), + "servingPcscf": sessionId, + "aarRequestTime": int(time.time()), + "gxOriginRealm": emergencySubscriberData.get('gx_origin_realm'), + "gxOriginHost": emergencySubscriberData.get('gx_origin_host'), + "imsi": emergencySubscriberData.get('imsi'), + "ip": emergencySubscriberData.get('ip'), + "ratType": emergencySubscriberData.get('rat_type'), + "accessNetworkGatewayAddress": emergencySubscriberData.get('access_network_gateway_address'), + "accessNetworkChargingAddress": emergencySubscriberData.get('access_network_charging_address'), + } + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] Updating Emergency Subscriber: {updatedEmergencySubscriberData}", redisClient=self.redisMessaging) + self.database.Update_Emergency_Subscriber(subscriberIp=ueIp, subscriberData=updatedEmergencySubscriberData, imsi=imsi) + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [AAA] RAR Generated to be sent to serving PGW: {servingPgw} via peer {servingPgwPeer}", redisClient=self.redisMessaging) reAuthAnswer = self.awaitDiameterRequestAndResponse( requestType='RAR', hostname=servingPgwPeer, @@ -2778,22 +3110,47 @@ def Answer_16777236_275(self, packet_vars, avps): avp += self.generate_avp(263, 40, self.string_to_hex(sessionId)) #Set session ID to received session ID avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId) - imsi = imsSubscriber.get('imsi', None) - pcscf = imsSubscriber.get('pcscf', None) - pcscf_realm = imsSubscriber.get('pcscf_realm', None) - pcscf_peer = imsSubscriber.get('pcscf_peer', None) - subscriber = self.database.Get_Subscriber(imsi=imsi) - subscriberId = 
subscriber.get('subscriber_id', None) - apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) - servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) - self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None) + servingApn = None + try: + imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId) + imsi = imsSubscriber.get('imsi', None) + pcscf = imsSubscriber.get('pcscf', None) + pcscf_realm = imsSubscriber.get('pcscf_realm', None) + pcscf_peer = imsSubscriber.get('pcscf_peer', None) + subscriber = self.database.Get_Subscriber(imsi=imsi) + subscriberId = subscriber.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) + if servingApn is not None: + servingPgw = servingApn.get('serving_pgw', '') + servingPgwRealm = servingApn.get('serving_pgw_realm', '') + servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0] + pcrfSessionId = servingApn.get('pcrf_session_id', None) + else: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] No servingApn defined for IMS Subscriber", redisClient=self.redisMessaging) + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None) + except Exception as e: + pass + + """ + Determine if the Session-ID for the STR belongs to an inbound roaming emergency subscriber. 
+ """ + try: + emergencySubscriberData = self.database.Get_Emergency_Subscriber(rxSessionId=sessionId) + if emergencySubscriberData: + emergencySubscriber = True + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] Found emergency subscriber with Rx Session: {sessionId}", redisClient=self.redisMessaging) + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] Error getting Emergency Subscriber Data: {traceback.format_exc()}", redisClient=self.redisMessaging) + emergencySubscriberData = None - if servingApn is not None: - servingPgw = servingApn.get('serving_pgw', '') - servingPgwRealm = servingApn.get('serving_pgw_realm', '') - servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0] - pcrfSessionId = servingApn.get('pcrf_session_id', None) + if emergencySubscriberData: + servingPgwPeer = emergencySubscriberData.get('serving_pgw', None).split(';')[0] + pcrfSessionId = emergencySubscriberData.get('serving_pgw', None) + servingPgwRealm = emergencySubscriberData.get('gx_origin_realm', None) + servingPgw = emergencySubscriberData.get('serving_pgw', None).split(';')[0] + + if servingApn is not None or emergencySubscriberData: reAuthAnswer = self.awaitDiameterRequestAndResponse( requestType='RAR', hostname=servingPgwPeer, @@ -2803,7 +3160,7 @@ def Answer_16777236_275(self, packet_vars, avps): chargingRuleName='GBR-Voice', chargingRuleAction='remove' ) - + if not len(reAuthAnswer) > 0: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging) assert() diff --git a/lib/diameterAsync.py b/lib/diameterAsync.py index f8c5be9..11240f4 100644 --- a/lib/diameterAsync.py +++ b/lib/diameterAsync.py @@ -2,6 +2,7 @@ import math import asyncio import yaml +import socket from messagingAsync import RedisMessagingAsync @@ -41,7 +42,7 @@ def __init__(self, 
logTool): self.redisMessaging = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) self.logTool = logTool - + self.hostname = socket.gethostname() #Generates rounding for calculating padding async def myRound(self, n, base=4): @@ -246,7 +247,7 @@ async def getConnectedPeersByType(self, peerType: str) -> list: if peerType not in peerTypes: return [] filteredConnectedPeers = [] - activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers")) + activePeers = await(self.redisMessaging.getValue(key="ActiveDiameterPeers", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')) for key, value in activePeers.items(): if activePeers.get(key, {}).get('peerType', '') == 'pgw' and activePeers.get(key, {}).get('connectionStatus', '') == 'connected': diff --git a/lib/logtool.py b/lib/logtool.py index 8506113..b5877aa 100644 --- a/lib/logtool.py +++ b/lib/logtool.py @@ -1,6 +1,7 @@ import logging import logging.handlers as handlers import os, sys, time +import socket from datetime import datetime sys.path.append(os.path.realpath('../')) import asyncio @@ -41,6 +42,7 @@ def __init__(self, config: dict): self.redisMessagingAsync = RedisMessagingAsync(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) self.redisMessaging = RedisMessaging(host=self.redisHost, port=self.redisPort, useUnixSocket=self.redisUseUnixSocket, unixSocketPath=self.redisUnixSocketPath) + self.hostname = socket.gethostname() async def logAsync(self, service: str, level: str, message: str, redisClient=None) -> bool: """ @@ -55,7 +57,7 @@ async def logAsync(self, service: str, level: str, message: str, redisClient=Non timestamp = time.time() dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() print(f"[{dateTimeString}] [{level.upper()}] {message}") - 
await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60)) + await(redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='log')) return True def log(self, service: str, level: str, message: str, redisClient=None) -> bool: @@ -71,7 +73,7 @@ def log(self, service: str, level: str, message: str, redisClient=None) -> bool: timestamp = time.time() dateTimeString = datetime.fromtimestamp(timestamp).strftime("%m/%d/%Y %H:%M:%S %Z").strip() print(f"[{dateTimeString}] [{level.upper()}] {message}") - redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60) + redisClient.sendLogMessage(serviceName=service.lower(), logLevel=level, logTimestamp=timestamp, message=message, logExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') return True def setupFileLogger(self, loggerName: str, logFilePath: str): diff --git a/lib/messaging.py b/lib/messaging.py index 7b376a3..55ada5c 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -13,11 +13,22 @@ def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=Fa else: self.redisClient = Redis(host=host, port=port) - def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + def handlePrefix(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Adds a prefix to the Key or Queue name, if enabled. + Returns the same Key or Queue if not enabled. 
+ """ + if usePrefix: + return f"{prefixHostname}:{prefixServiceName}:{key}" + else: + return key + + def sendMessage(self, queue: str, message: str, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key). """ try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.rpush(queue, message) if queueExpiry is not None: self.redisClient.expire(queue, queueExpiry) @@ -25,7 +36,7 @@ def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: except Exception as e: return '' - def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: + def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a prometheus metric in a format readable by the metric service. 
""" @@ -44,35 +55,36 @@ def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricA } ]) - metricQueueName = f"metric" + queue = self.handlePrefix(key='metric', usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) try: - self.redisClient.rpush(metricQueueName, prometheusMetricBody) + self.redisClient.rpush(queue, prometheusMetricBody) if metricExpiry is not None: - self.redisClient.expire(metricQueueName, metricExpiry) + self.redisClient.expire(queue, metricExpiry) return f'Succesfully stored metric called: {metricName}, with value of: {metricType}' except Exception as e: return '' - def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key). """ try: - logQueueName = f"log" + queue = self.handlePrefix(key='log', usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) - self.redisClient.rpush(logQueueName, logMessage) + self.redisClient.rpush(queue, logMessage) if logExpiry is not None: - self.redisClient.expire(logQueueName, logExpiry) - return f'{message} stored in {logQueueName} successfully.' + self.redisClient.expire(queue, logExpiry) + return f'{message} stored in {queue} successfully.' except Exception as e: return '' - def getMessage(self, queue: str) -> str: + def getMessage(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the oldest message from a given Queue (Key), while removing it from the key as well. Deletes the key if the last message is being removed. 
""" try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.lpop(queue) if message is None: message = '' @@ -85,62 +97,68 @@ def getMessage(self, queue: str) -> str: except Exception as e: return '' - def getQueues(self, pattern: str='*') -> list: + def getQueues(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Returns all Queues (Keys) in the database. """ try: + pattern = self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) allQueues = self.redisClient.scan_iter(match=pattern) return [x.decode() for x in allQueues] except Exception as e: return f"{traceback.format_exc()}" - def getNextQueue(self, pattern: str='*') -> dict: + def getNextQueue(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> dict: """ Returns the next Queue (Key) in the list. """ try: + pattern = self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) for nextQueue in self.redisClient.scan_iter(match=pattern): return nextQueue.decode() except Exception as e: return {} - def awaitMessage(self, key: str): + def awaitMessage(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Blocks until a message is received at the given key, then returns the message. 
""" try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.blpop(key) return tuple(data.decode() for data in message) except Exception as e: return '' - def awaitBulkMessage(self, key: str, count: int=100): + def awaitBulkMessage(self, key: str, count: int=100, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Blocks until one or more messages are received at the given key, then returns the amount of messages specified by count. """ try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) message = self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count) return message except Exception as e: print(traceback.format_exc()) return '' - def deleteQueue(self, queue: str) -> bool: + def deleteQueue(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> bool: """ Deletes the given Queue (Key) """ try: + queue = self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.delete(queue) return True except Exception as e: return False - def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + def setValue(self, key: str, value: str, keyExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a value under a given key and sets an expiry (in seconds) if provided. 
""" try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) self.redisClient.set(key, value) if keyExpiry is not None: self.redisClient.expire(key, keyExpiry) @@ -148,38 +166,41 @@ def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: except Exception as e: return '' - def getValue(self, key: str) -> str: + def getValue(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the value stored under a given key. """ try: - message = self.redisClient.get(key) - if message is None: - message = '' - else: - return message + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) + message = self.redisClient.get(key) + if message is None: + message = '' + else: + return message except Exception as e: return '' - def getList(self, key: str) -> list: + def getList(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Gets the list stored under a given key. """ try: - allResults = self.redisClient.lrange(key, 0, -1) - if allResults is None: - result = [] - else: - return [result.decode() for result in allResults] + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) + allResults = self.redisClient.lrange(key, 0, -1) + if allResults is None: + result = [] + else: + return [result.decode() for result in allResults] except Exception as e: return [] - def RedisHGetAll(self, key: str): + def RedisHGetAll(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Wrapper for Redis HGETALL *Deprecated: will be removed upon completed database cleanup. 
""" try: + key = self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName) data = self.redisClient.hgetall(key) return data except Exception as e: diff --git a/lib/messagingAsync.py b/lib/messagingAsync.py index 6c33e0a..98bfc1b 100644 --- a/lib/messagingAsync.py +++ b/lib/messagingAsync.py @@ -1,4 +1,6 @@ import asyncio +import traceback +import socket import redis.asyncio as redis import time, json, uuid @@ -15,11 +17,22 @@ def __init__(self, host: str='localhost', port: int=6379, useUnixSocket: bool=Fa self.redisClient = redis.Redis(host=host, port=port) pass - async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> str: + async def handlePrefix(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Adds a prefix to the Key or Queue name, if enabled. + Returns the same Key or Queue if not enabled. + """ + if usePrefix: + return f"{prefixHostname}:{prefixServiceName}:{key}" + else: + return key + + async def sendMessage(self, queue: str, message: str, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. 
""" try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) await(self.redisClient.rpush(queue, message)) if queueExpiry is not None: await(self.redisClient.expire(queue, queueExpiry)) @@ -27,11 +40,12 @@ async def sendMessage(self, queue: str, message: str, queueExpiry: int=None) -> except Exception as e: return '' - async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None) -> str: + async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Empties a given asyncio queue into a redis pipeline, then sends to redis. """ try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) redisPipe = self.redisClient.pipeline() for message in messageList: @@ -46,7 +60,7 @@ async def sendBulkMessage(self, queue: str, messageList: list, queueExpiry: int= except Exception as e: return '' - async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None) -> str: + async def sendMetric(self, serviceName: str, metricName: str, metricType: str, metricAction: str, metricValue: float, metricHelp: str='', metricLabels: list=[], metricTimestamp: int=time.time_ns(), metricExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a prometheus metric in a format readable by the metric service, asynchronously. 
""" @@ -66,6 +80,7 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m ]) metricQueueName = f"metric" + metricQueueName = await(self.handlePrefix(key=metricQueueName, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) try: async with self.redisClient.pipeline(transaction=True) as redisPipe: @@ -77,12 +92,13 @@ async def sendMetric(self, serviceName: str, metricName: str, metricType: str, m except Exception as e: return '' - async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None) -> str: + async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: int, message: str, logExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a log message in a given Queue (Key) asynchronously and sets an expiry (in seconds) if provided. """ try: logQueueName = f"log" + logQueueName = await(self.handlePrefix(key=logQueueName, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) logMessage = json.dumps({"message": message, "service": serviceName, "level": logLevel, "timestamp": logTimestamp}) async with self.redisClient.pipeline(transaction=True) as redisPipe: await redisPipe.rpush(logQueueName, logMessage) @@ -93,42 +109,49 @@ async def sendLogMessage(self, serviceName: str, logLevel: str, logTimestamp: in except Exception as e: return '' - async def getMessage(self, queue: str) -> str: + async def getMessage(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the oldest message from a given Queue (Key) asynchronously, while removing it from the key as well. Deletes the key if the last message is being removed. 
""" try: - message = await(self.redisClient.lpop(queue)) - if message is None: - message = '' - else: - try: - if message[0] is None: - return '' - else: - message = message[0].decode() - except (UnicodeDecodeError, AttributeError): - pass - return message + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.lpop(queue)) + if message is None: + message = '' + else: + try: + if message[0] is None: + return '' + else: + message = message[0].decode() + except (UnicodeDecodeError, AttributeError): + pass + return message except Exception as e: return '' - async def getQueues(self, pattern: str='*') -> list: + async def getQueues(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> list: """ Returns all Queues (Keys) in the database, asynchronously. """ try: - allQueuesBinary = await(self.redisClient.scan_iter(match=pattern)) - allQueues = [x.decode() for x in allQueuesBinary] - return allQueues + pattern = await(self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + allQueuesBinary = [] + async for nextQueue in self.redisClient.scan_iter(match=pattern): + if nextQueue: + allQueuesBinary.append(nextQueue) + allQueues = [x.decode() for x in allQueuesBinary] + return allQueues except Exception as e: + print(traceback.format_exc()) return [] - async def getNextQueue(self, pattern: str='*') -> str: + async def getNextQueue(self, pattern: str='*', usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Returns the next Queue (Key) in the list, asynchronously. 
""" try: + pattern = await(self.handlePrefix(key=pattern, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) async for nextQueue in self.redisClient.scan_iter(match=pattern): if nextQueue is not None: return nextQueue.decode('utf-8') @@ -136,50 +159,66 @@ async def getNextQueue(self, pattern: str='*') -> str: print(e) return '' - async def awaitMessage(self, key: str): + async def awaitMessage(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): """ Asynchronously blocks until a message is received at the given key, then returns the message. """ try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) message = (await(self.redisClient.blpop(key))) return tuple(data.decode() for data in message) except Exception as e: return '' - async def deleteQueue(self, queue: str) -> bool: + async def awaitBulkMessage(self, key: str, count: int=100, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common'): + """ + Asynchronously blocks until one or more messages are received at the given key, then returns the amount of messages specified by count. + """ + try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.blmpop(0, 1, key, direction='RIGHT', count=count)) + return message + except Exception as e: + print(traceback.format_exc()) + return '' + + async def deleteQueue(self, queue: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> bool: """ Deletes the given Queue (Key) asynchronously. 
""" try: + queue = await(self.handlePrefix(key=queue, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) deleteQueueResult = await(self.redisClient.delete(queue)) return True except Exception as e: return False - async def setValue(self, key: str, value: str, keyExpiry: int=None) -> str: + async def setValue(self, key: str, value: str, keyExpiry: int=None, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Stores a value under a given key asynchronously and sets an expiry (in seconds) if provided. """ try: + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) async with self.redisClient.pipeline(transaction=True) as redisPipe: await redisPipe.set(key, value) if keyExpiry is not None: - await redisPipe.expire(key, value) - setValueResult, expireValueResult = await redisPipe.execute() + await redisPipe.expire(key, keyExpiry) + setValueResult = await redisPipe.execute() return f'{value} stored in {key} successfully.' except Exception as e: - return '' + return traceback.format_exc() - async def getValue(self, key: str) -> str: + async def getValue(self, key: str, usePrefix: bool=False, prefixHostname: str='unknown', prefixServiceName: str='common') -> str: """ Gets the value stored under a given key asynchronously. 
""" try: - message = await(self.redisClient.get(key)) - if message is None: - message = '' - else: - return message + key = await(self.handlePrefix(key=key, usePrefix=usePrefix, prefixHostname=prefixHostname, prefixServiceName=prefixServiceName)) + message = await(self.redisClient.get(key)) + if message is None: + message = '' + else: + return message except Exception as e: return '' diff --git a/services/apiService.py b/services/apiService.py index 2fba9d2..e371e84 100644 --- a/services/apiService.py +++ b/services/apiService.py @@ -69,6 +69,8 @@ SUBSCRIBER_ROUTING = database.SUBSCRIBER_ROUTING ROAMING_NETWORK = database.ROAMING_NETWORK ROAMING_RULE = database.ROAMING_RULE +EMERGENCY_SUBSCRIBER = database.EMERGENCY_SUBSCRIBER + apiService.wsgi_app = ProxyFix(apiService.wsgi_app) api = Api(apiService, version='1.0', title=f'{siteName + " - " if siteName else ""}{originHostname} - PyHSS OAM API', @@ -128,14 +130,15 @@ databaseClient.Generate_JSON_Model_for_Flask(ROAMING_RULE) ) +EMERGENCY_SUBSCRIBER_model = api.schema_model('EMERGENCY_SUBSCRIBER JSON', + databaseClient.Generate_JSON_Model_for_Flask(EMERGENCY_SUBSCRIBER) +) #Legacy support for sh_profile. sh_profile is deprecated as of v1.0.1. 
-imsSubscriberModel = databaseClient.Generate_JSON_Model_for_Flask(TFT) +imsSubscriberModel = databaseClient.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER) imsSubscriberModel['sh_profile'] = fields.String(required=False, description=IMS_SUBSCRIBER.sh_profile.doc), -IMS_SUBSCRIBER_model = api.schema_model('IMS_SUBSCRIBER JSON', - databaseClient.Generate_JSON_Model_for_Flask(TFT) -) +IMS_SUBSCRIBER_model = api.schema_model('IMS_SUBSCRIBER JSON', databaseClient.Generate_JSON_Model_for_Flask(IMS_SUBSCRIBER)) TFT_model = api.schema_model('TFT JSON', databaseClient.Generate_JSON_Model_for_Flask(TFT) @@ -204,26 +207,21 @@ 'scscf_timestamp' : fields.String(description=IMS_SUBSCRIBER.scscf_timestamp.doc), 'imei' : fields.String(description=EIR.imei.doc), 'match_response_code' : fields.String(description=EIR.match_response_code.doc), + 'emergency_subscriber_id': fields.String(description=EMERGENCY_SUBSCRIBER.emergency_subscriber_id.doc), + 'emergency_subscriber_imsi': fields.String(description=EMERGENCY_SUBSCRIBER.imsi.doc), + 'emergency_subscriber_serving_pgw': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pgw.doc), + 'emergency_subscriber_serving_pgw_timestamp': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pgw_timestamp.doc), + 'emergency_subscriber_serving_pcscf': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pcscf.doc), + 'emergency_subscriber_serving_pcscf_timestamp': fields.String(description=EMERGENCY_SUBSCRIBER.serving_pcscf_timestamp.doc), + 'emergency_subscriber_gx_origin_realm': fields.String(description=EMERGENCY_SUBSCRIBER.gx_origin_realm.doc), + 'emergency_subscriber_gx_origin_host': fields.String(description=EMERGENCY_SUBSCRIBER.gx_origin_host.doc), + 'emergency_subscriber_rat_type': fields.String(description=EMERGENCY_SUBSCRIBER.rat_type.doc), + 'emergency_subscriber_ip': fields.String(description=EMERGENCY_SUBSCRIBER.ip.doc), + 'emergency_subscriber_access_network_gateway_address': 
fields.String(description=EMERGENCY_SUBSCRIBER.access_network_gateway_address.doc), + 'emergency_subscriber_access_network_charging_address': fields.String(description=EMERGENCY_SUBSCRIBER.access_network_charging_address.doc), + 'emergency_subscriber_delete': fields.Boolean(description="Whether to delete the emergency subscriber on receipt"), }) -Geored_schema = { - 'serving_mme': "string", - 'serving_mme_realm': "string", - 'serving_mme_peer': "string", - 'serving_mme_timestamp': "string", - 'serving_apn' : "string", - 'pcrf_session_id' : "string", - 'subscriber_routing' : "string", - 'serving_pgw' : "string", - 'serving_pgw_timestamp' : "string", - 'scscf' : "string", - 'imei' : "string", - 'match_response_code' : "string", - 'auc_id': "int", - 'sqn': "int", -} - - def no_auth_required(f): f.no_auth_required = True return f @@ -1304,7 +1302,7 @@ class PyHSS_OAM_Peers(Resource): def get(self): '''Get active Diameter Peers''' try: - diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers")) + diameterPeers = json.loads(redisMessaging.getValue("ActiveDiameterPeers", usePrefix=True, prefixHostname=originHostname, prefixServiceName='diameter')) return diameterPeers, 200 except Exception as E: logTool.log(service='API', level='error', message=f"[API] An error occurred: {traceback.format_exc()}", redisClient=redisMessaging) @@ -1772,7 +1770,7 @@ def put(self): return result, 400 activeSubscribers = databaseClient.Get_Subscribers_By_Pcscf(pcscf=pcscf) - logTool.log(service='API', level='debug', message=f"[API] Active Subscribers for {pcscf}: {activeSubscribers}", redisClient=redisMessaging) + logTool.log(service='API', level='debug', message=f"[API] [pcscf_restoration] Active Subscribers for {pcscf}: {activeSubscribers}", redisClient=redisMessaging) if len(activeSubscribers) > 0: for imsSubscriber in activeSubscribers: @@ -1798,6 +1796,7 @@ def put(self): ) except Exception as e: + logTool.log(service='API', level='error', message=f"[API] 
[pcscf_restoration] Error sending CLR for subscriber: {traceback.format_exc()}", redisClient=redisMessaging)
                continue
 
         result = {"Result": f"Successfully sent PCSCF Restoration request for PCSCF: {pcscf}"}
@@ -1818,6 +1817,7 @@ def get(self, charging_rule_id):
             print(E)
             return handle_exception(E)
 
+
 @ns_pcrf.route('/subscriber_routing/<string:subscriber_routing>')
 class PyHSS_PCRF_SUBSCRIBER_ROUTING(Resource):
     def get(self, subscriber_routing):
@@ -1829,6 +1829,77 @@ def get(self, subscriber_routing):
             print(E)
             return handle_exception(E)
 
+@ns_pcrf.route('/emergency_subscriber/<string:emergency_subscriber_id>')
+class PyHSS_EMERGENCY_SUBSCRIBER_Get(Resource):
+    def get(self, emergency_subscriber_id):
+        '''Get all emergency_subscriber data for specified emergency_subscriber ID'''
+        try:
+            apn_data = databaseClient.GetObj(EMERGENCY_SUBSCRIBER, emergency_subscriber_id)
+            return apn_data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+    def delete(self, emergency_subscriber_id):
+        '''Delete all emergency_subscriber data for specified emergency_subscriber ID'''
+        try:
+            args = parser.parse_args()
+            operation_id = args.get('operation_id', None)
+            data = databaseClient.DeleteObj(EMERGENCY_SUBSCRIBER, emergency_subscriber_id, False, operation_id)
+            return data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+    @ns_pcrf.doc('Update EMERGENCY_SUBSCRIBER Object')
+    @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model)
+    def patch(self, emergency_subscriber_id):
+        '''Update emergency_subscriber data for specified emergency_subscriber ID'''
+        try:
+            json_data = request.get_json(force=True)
+            print("JSON Data sent: " + str(json_data))
+            args = parser.parse_args()
+            operation_id = args.get('operation_id', None)
+            apn_data = databaseClient.UpdateObj(EMERGENCY_SUBSCRIBER, json_data, emergency_subscriber_id, False, operation_id)
+
+            print("Updated object")
+            print(apn_data)
+            return apn_data, 200
+        except Exception as E:
+            print(E)
+            return handle_exception(E)
+
+@ns_pcrf.route('/emergency_subscriber/')
+class PyHSS_EMERGENCY_SUBSCRIBER(Resource): + @ns_pcrf.doc('Create EMERGENCY_SUBSCRIBER Object') + @ns_pcrf.expect(EMERGENCY_SUBSCRIBER_model) + def put(self): + '''Create new EMERGENCY_SUBSCRIBER''' + try: + json_data = request.get_json(force=True) + print("JSON Data sent: " + str(json_data)) + args = parser.parse_args() + operation_id = args.get('operation_id', None) + emergency_subscriber_id = databaseClient.CreateObj(EMERGENCY_SUBSCRIBER, json_data, False, operation_id) + + return emergency_subscriber_id, 200 + except Exception as E: + print(E) + return handle_exception(E) + +@ns_pcrf.route('/emergency_subscriber/list') +class PyHSS_ALL_EMERGENCY_SUBSCRIBER(Resource): + @ns_apn.expect(paginatorParser) + def get(self): + '''Get all Emergency Subscribers''' + try: + args = paginatorParser.parse_args() + data = databaseClient.getAllPaginated(EMERGENCY_SUBSCRIBER, args['page'], args['page_size']) + return (data), 200 + except Exception as E: + print(E) + return handle_exception(E) + @ns_geored.route('/') class PyHSS_Geored(Resource): @ns_geored.doc('Receive GeoRed data') @@ -1935,6 +2006,50 @@ def patch(self): "geored_host": request.remote_addr, }, metricExpiry=60) + if 'emergency_subscriber_ip' in json_data: + """ + If we receive a geored payload containing emergency_subscriber_id, create or update the matching emergency_subscriber_id. + If emergency_subscriber_id exists as None, then remove the emergency subscriber. 
+ """ + print("Updating Emergency Subscriber") + subscriberData = { + "imsi": json_data.get('emergency_subscriber_imsi'), + "servingPgw": json_data.get('emergency_subscriber_serving_pgw'), + "requestTime": json_data.get('emergency_subscriber_serving_pgw_timestamp'), + "servingPcscf": json_data.get('emergency_subscriber_serving_pcscf'), + "aarRequestTime": json_data.get('emergency_subscriber_serving_pcscf_timestamp'), + "gxOriginRealm": json_data.get('emergency_subscriber_gx_origin_realm'), + "gxOriginHost": json_data.get('emergency_subscriber_gx_origin_host'), + "ratType": json_data.get('emergency_subscriber_rat_type'), + "ip": json_data.get('emergency_subscriber_ip'), + "accessNetworkGatewayAddress": json_data.get('emergency_subscriber_access_network_gateway_address'), + "accessNetworkChargingAddress": json_data.get('emergency_subscriber_access_network_charging_address'), + } + + if not json_data.get('emergency_subscriber_ip', None): + logTool.log(service='API', level='error', message=f"[API] emergency_subscriber_ip missing from geored request. 
No changes to emergency_subscriber made.", redisClient=redisMessaging) + return {'result': 'Failed', 'Reason' : "emergency_subscriber_ip missing from geored request"} + + if 'emergency_subscriber_delete' in json_data: + if json_data.get('emergency_subscriber_delete', False): + databaseClient.Delete_Emergency_Subscriber(subscriberIp=subscriberData.get('ip'), imsi=subscriberData.get('imsi'), propagate=False) + return {}, 200 + + response_data.append(databaseClient.Update_Emergency_Subscriber(emergencySubscriberId=json_data['emergency_subscriber_id'], + subscriberData=subscriberData, + imsi=subscriberData.get('imsi'), + subscriberIp=subscriberData.get('ip'), + gxSessionId=subscriberData.get('servingPgw'), + propagate=False)) + + redisMessaging.sendMetric(serviceName='api', metricName='prom_flask_http_geored_endpoints', + metricType='counter', metricAction='inc', + metricValue=1.0, metricHelp='Number of Geored Pushes Received', + metricLabels={ + "endpoint": "EMERGENCY_SUBSCRIBER", + "geored_host": request.remote_addr, + }, + metricExpiry=60) return response_data, 200 except Exception as E: print("Exception when updating: " + str(E)) diff --git a/services/diameterService.py b/services/diameterService.py index 04629b3..0555097 100644 --- a/services/diameterService.py +++ b/services/diameterService.py @@ -2,6 +2,7 @@ import sys, os, json import time, yaml, uuid from datetime import datetime +import sctp, socket sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from diameterAsync import DiameterAsync @@ -41,6 +42,7 @@ def __init__(self): self.diameterRequests = 0 self.diameterResponses = 0 self.workerPoolSize = int(self.config.get('hss', {}).get('diameter_service_workers', 10)) + self.hostname = socket.gethostname() async def validateDiameterInbound(self, clientAddress: str, clientPort: str, inboundData) -> bool: """ @@ -86,7 +88,7 @@ async def handleActiveDiameterPeers(self): del self.activePeers[key] await(self.logActivePeers()) 
- await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400)) + await(self.redisPeerMessaging.setValue(key='ActiveDiameterPeers', value=json.dumps(self.activePeers), keyExpiry=86400, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')) await(asyncio.sleep(1)) except Exception as e: @@ -174,7 +176,7 @@ async def inboundDataWorker(self, coroutineUuid: str) -> bool: break if messageList: - await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout) + await self.redisReaderMessaging.sendBulkMessage(queue=inboundQueueName, messageList=messageList, queueExpiry=self.diameterRequestTimeout, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') messageList = [] except Exception as e: @@ -189,7 +191,7 @@ async def writeOutboundData(self, writer, clientAddress: str, clientPort: str, s while not writer.transport.is_closing(): try: await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Waiting for messages for host {clientAddress} on port {clientPort}")) - pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}")))[1]) + pendingOutboundMessage = json.loads((await(self.redisWriterMessaging.awaitMessage(key=f"diameter-outbound-{clientAddress}-{clientPort}", usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter')))[1]) diameterOutboundBinary = bytes.fromhex(pendingOutboundMessage.get('diameter-outbound', '')) await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [writeOutboundData] [{coroutineUuid}] Sending: {diameterOutboundBinary.hex()} to to {clientAddress} on {clientPort}.")) @@ -286,6 +288,28 @@ async def startServer(self, host: str=None, port: int=None, type: str=None): if type.upper() == 'TCP': server 
= await(asyncio.start_server(self.handleConnection, host, port)) + elif type.upper() == 'SCTP': + self.sctpSocket = sctp.sctpsocket_tcp(socket.AF_INET) + self.sctpSocket.setblocking(False) + self.sctpSocket.events.clear() + self.sctpSocket.bind((host, port)) + self.sctpRtoInfo = self.sctpSocket.get_rtoinfo() + self.sctpRtoMin = self.config.get('hss', {}).get('sctp', {}).get('rtoMin', 500) + self.sctpRtoMax = self.config.get('hss', {}).get('sctp', {}).get('rtoMax', 5000) + self.sctpRtoInitial = self.config.get('hss', {}).get('sctp', {}).get('rtoInitial', 1000) + self.sctpRtoInfo.initial = int(self.sctpRtoInitial) + self.sctpRtoInfo.max = int(self.sctpRtoMax) + self.sctpRtoInfo.min = int(self.sctpRtoMin) + self.sctpSocket.set_rtoinfo(self.sctpRtoInfo) + self.sctpAssociatedParameters = self.sctpSocket.get_assocparams() + sctpInitParameters = { "initialRto": self.sctpRtoInfo.initial, + "rtoMin": self.sctpRtoInfo.min, + "rtoMax": self.sctpRtoInfo.max + } + self.sctpSocket.listen() + await(self.logTool.logAsync(service='Diameter', level='debug', message=f"[Diameter] [startServer] SCTP Parameters: {sctpInitParameters}")) + + server = await(asyncio.start_server(self.handleConnection, sock=self.sctpSocket)) else: return False servingAddresses = ', '.join(str(sock.getsockname()) for sock in server.sockets) diff --git a/services/georedService.py b/services/georedService.py index 12f1680..14d3117 100644 --- a/services/georedService.py +++ b/services/georedService.py @@ -1,6 +1,7 @@ import os, sys, json, yaml import uuid, time import asyncio, aiohttp +import socket sys.path.append(os.path.realpath('../lib')) from messagingAsync import RedisMessagingAsync from banners import Banners @@ -32,6 +33,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.georedPeers = self.config.get('geored', {}).get('endpoints', []) self.webhookPeers = self.config.get('webhooks', {}).get('endpoints', []) self.benchmarking = self.config.get('hss').get('enable_benchmarking', 
False) + self.hostname = socket.gethostname() if not self.config.get('geored', {}).get('enabled'): self.logger.error("[Geored] Fatal Error - geored not enabled under geored.enabled, exiting.") @@ -56,7 +58,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr if operation not in requestOperations: return False - headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId)} + headers = {"Content-Type": "application/json", "Transaction-Id": str(transactionId), "User-Agent": f"PyHSS/1.0.1 (Geored)"} for attempt in range(retryCount): try: @@ -145,7 +147,7 @@ async def sendGeored(self, asyncSession, url: str, operation: str, body: str, tr return True - async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: str, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: + async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, headers: dict, transactionId: str=uuid.uuid4(), retryCount: int=3) -> bool: """ Sends a Webhook HTTP request to a given endpoint. 
""" @@ -159,6 +161,9 @@ async def sendWebhook(self, asyncSession, url: str, operation: str, body: str, h if operation not in requestOperations: return False + + if 'User-Agent' not in headers: + headers['User-Agent'] = f"PyHSS/1.0.1 (Webhook)" for attempt in range(retryCount): try: @@ -255,7 +260,7 @@ async def handleAsymmetricGeoredQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored')))[1]) + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='asymmetric-geored', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleAsymmetricGeoredQueue] Message: {georedMessage}")) georedOperation = georedMessage['operation'] @@ -286,7 +291,7 @@ async def handleGeoredQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored')))[1]) + georedMessage = json.loads((await(self.redisGeoredMessaging.awaitMessage(key='geored', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='geored')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleGeoredQueue] Message: {georedMessage}")) georedOperation = georedMessage['operation'] @@ -316,7 +321,7 @@ async def handleWebhookQueue(self): try: if self.benchmarking: startTime = time.perf_counter() - webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook')))[1]) + webhookMessage = json.loads((await(self.redisWebhookMessaging.awaitMessage(key='webhook', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='webhook')))[1]) await(self.logTool.logAsync(service='Geored', level='debug', message=f"[Geored] [handleWebhookQueue] Message: {webhookMessage}")) diff --git a/services/hssService.py 
b/services/hssService.py index 46abcbc..cb74bdb 100644 --- a/services/hssService.py +++ b/services/hssService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml, time, traceback +import os, sys, json, yaml, time, traceback, socket sys.path.append(os.path.realpath('../lib')) from messaging import RedisMessaging from diameter import Diameter @@ -30,6 +30,7 @@ def __init__(self): self.logTool.log(service='HSS', level='info', message=f"{self.banners.hssService()}", redisClient=self.redisMessaging) self.diameterLibrary = Diameter(logTool=self.logTool, originHost=self.originHost, originRealm=self.originRealm, productName=self.productName, mcc=self.mcc, mnc=self.mnc) self.benchmarking = self.config.get('hss').get('enable_benchmarking', False) + self.hostname = socket.gethostname() def handleQueue(self): """ @@ -40,7 +41,7 @@ def handleQueue(self): if self.benchmarking: startTime = time.perf_counter() - inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound') + inboundMessageList = self.redisMessaging.awaitBulkMessage(key='diameter-inbound', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') if inboundMessageList == None: continue @@ -86,7 +87,7 @@ def handleQueue(self): self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound Queue: {outboundQueue}", redisClient=self.redisMessaging) self.logTool.log(service='HSS', level='debug', message=f"[HSS] [handleQueue] [{diameterMessageTypeOutbound}] Outbound Diameter Outbound: {outboundMessage}", redisClient=self.redisMessaging) - self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60) + self.redisMessaging.sendMessage(queue=outboundQueue, message=outboundMessage, queueExpiry=60, usePrefix=True, prefixHostname=self.hostname, prefixServiceName='diameter') if self.benchmarking: self.logTool.log(service='HSS', level='info', message=f"[HSS] [handleQueue] 
[{diameterMessageTypeInbound}] Time taken to process request: {round(((time.perf_counter() - startTime)*1000), 3)} ms", redisClient=self.redisMessaging) diff --git a/services/logService.py b/services/logService.py index 34e7ae0..827662a 100644 --- a/services/logService.py +++ b/services/logService.py @@ -1,4 +1,4 @@ -import os, sys, json, yaml +import os, sys, json, yaml, socket from datetime import datetime import time import logging @@ -37,6 +37,8 @@ def __init__(self): 'DEBUG': {'verbosity': 5, 'logging': logging.DEBUG}, 'NOTSET': {'verbosity': 6, 'logging': logging.NOTSET}, } + self.hostname = socket.gethostname() + print(f"{self.banners.logService()}") def handleLogs(self): @@ -46,7 +48,7 @@ def handleLogs(self): activeLoggers = {} while True: try: - logMessage = json.loads(self.redisMessaging.awaitMessage(key='log')[1]) + logMessage = json.loads(self.redisMessaging.awaitMessage(key='log', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='log')[1]) print(f"[Log] Message: {logMessage}") diff --git a/services/metricService.py b/services/metricService.py index 12d51c1..968ddc5 100644 --- a/services/metricService.py +++ b/services/metricService.py @@ -1,6 +1,7 @@ import asyncio import sys, os, json import time, json, yaml +import socket from prometheus_client import make_wsgi_app, start_http_server, Counter, Gauge, Summary, Histogram, CollectorRegistry from werkzeug.middleware.dispatcher import DispatcherMiddleware from flask import Flask @@ -25,6 +26,7 @@ def __init__(self, redisHost: str='127.0.0.1', redisPort: int=6379): self.logTool = LogTool(config=self.config) self.registry = CollectorRegistry(auto_describe=True) self.logTool.log(service='Metric', level='info', message=f"{self.banners.metricService()}", redisClient=self.redisMessaging) + self.hostname = socket.gethostname() def handleMetrics(self): """ @@ -34,7 +36,7 @@ def handleMetrics(self): actions = {'inc': 'inc', 'dec': 'dec', 'set':'set'} prometheusTypes = {'counter': Counter, 'gauge': 
Gauge, 'histogram': Histogram, 'summary': Summary} - metric = self.redisMessaging.awaitMessage(key='metric')[1] + metric = self.redisMessaging.awaitMessage(key='metric', usePrefix=True, prefixHostname=self.hostname, prefixServiceName='metric')[1] self.logTool.log(service='Metric', level='debug', message=f"[Metric] [handleMetrics] Received Metric: {metric}", redisClient=self.redisMessaging) prometheusJsonList = json.loads(metric)