diff --git a/lib/diameter.py b/lib/diameter.py index adb7c61..e62e748 100644 --- a/lib/diameter.py +++ b/lib/diameter.py @@ -1021,21 +1021,21 @@ def validateSubscriberRoaming(self, subscriber: dict, mcc: str, mnc: str) -> boo return True - def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, authExpiry: int=3600, subscriberImsi: str="Unknown") -> bool: + def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, gxSessionId: str, authExpiry: int=3600, subscriberImsi: str="Unknown") -> bool: """ Store a given Emergency Subscriber in redis. If there's an existing entry for the same IMSI, then update the record with the new IP and details. The subscriber entry will expire per authExpiry in seconds. """ try: - emergencySubscriberKey = f"emergencySubscriber-{subscriberIp}-{subscriberImsi}" + emergencySubscriberKey = f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:{gxSessionId}" # Check if our subscriber exists if subscriberImsi and subscriberImsi != "Unknown": existingEmergencySubscriber = self.getEmergencySubscriber(subscriberImsi=subscriberImsi) if existingEmergencySubscriber: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getEmergencySubscriber] Found existing emergency subscriber to overwrite: {existingEmergencySubscriber}", redisClient=self.redisMessaging) for key, value in existingEmergencySubscriber.items(): - self.redisMessaging.multiDeleteQueue(queue=f"emergencySubscriber-{value.get('ip')}-{value.get('imsi')}", redisPeerConnections=self.redisPeerConnections) + self.redisMessaging.multiDeleteQueue(queue=f"emergencySubscriber:{value.get('ip')}:{value.get('imsi')}:{value.get('servingPgw')}", redisPeerConnections=self.redisPeerConnections) result = self.redisMessaging.multiSetValue(key=emergencySubscriberKey, value=json.dumps(subscriberData), keyExpiry=authExpiry, redisPeerConnections=self.redisPeerConnections) return True except Exception as e: @@ -1043,7 +1043,7 @@ def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, auth return False - def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=None) -> dict: + def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=None, gxSessionId: str=None) -> dict: """ Retrieves a provided Emergency Subscriber from redis, if it exists. The first match from any defined redis instance is used. @@ -1055,7 +1055,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non return None if subscriberIp and subscriberImsi: - emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-{subscriberIp}-{subscriberImsi}") + emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:*") if emergencySubscriberKeyList: for matchedKey in emergencySubscriberKeyList: for peerName, keyName in matchedKey.items(): @@ -1069,7 +1069,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non return emergencySubscriber if subscriberIp and not subscriberImsi: - emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-{subscriberIp}-*") + emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:*") if emergencySubscriberKeyList: for matchedKey in emergencySubscriberKeyList: for peerName, keyName in matchedKey.items(): @@ -1083,7 +1083,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non return emergencySubscriber if subscriberImsi and not subscriberIp: - emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-*-{subscriberImsi}") + emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:{subscriberImsi}:*") if emergencySubscriberKeyList: for matchedKey in emergencySubscriberKeyList: for peerName, keyName in matchedKey.items(): @@ -1095,7 +1095,21 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non emergencySubscriberData = json.loads(emergencySubscriberData) emergencySubscriber = {peerName: emergencySubscriberData} return emergencySubscriber - + + if gxSessionId: + emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:*:{gxSessionId}") + if emergencySubscriberKeyList: + for matchedKey in emergencySubscriberKeyList: + for peerName, keyName in matchedKey.items(): + if isinstance(keyName, list): + keyName = keyName[0] if len(keyName) > 0 else '' + emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName)) + if not emergencySubscriberData: + return None + emergencySubscriberData = json.loads(emergencySubscriberData) + emergencySubscriber = {peerName: emergencySubscriberData} + return emergencySubscriber + return None except Exception as e: @@ -2010,7 +2024,7 @@ def Answer_16777238_272(self, packet_vars, avps): "accessNetworkChargingAddress": accessNetworkChargingAddress, } - self.storeEmergencySubscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, subscriberImsi=imsi) + self.storeEmergencySubscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, subscriberImsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw')) avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001)) response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet @@ -3018,22 +3032,45 @@ def Answer_16777236_275(self, packet_vars, avps): avp += self.generate_avp(263, 40, self.string_to_hex(sessionId)) #Set session ID to received session ID avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm - imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId) - imsi = imsSubscriber.get('imsi', None) - pcscf = imsSubscriber.get('pcscf', None) - pcscf_realm = imsSubscriber.get('pcscf_realm', None) - pcscf_peer = imsSubscriber.get('pcscf_peer', None) - subscriber = self.database.Get_Subscriber(imsi=imsi) - subscriberId = subscriber.get('subscriber_id', None) - apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) - servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) - self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None) + servingApn = None + try: + imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId) + imsi = imsSubscriber.get('imsi', None) + pcscf = imsSubscriber.get('pcscf', None) + pcscf_realm = imsSubscriber.get('pcscf_realm', None) + pcscf_peer = imsSubscriber.get('pcscf_peer', None) + subscriber = self.database.Get_Subscriber(imsi=imsi) + subscriberId = subscriber.get('subscriber_id', None) + apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None) + servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId) + if servingApn is not None: + servingPgw = servingApn.get('serving_pgw', '') + servingPgwRealm = servingApn.get('serving_pgw_realm', '') + servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0] + pcrfSessionId = servingApn.get('pcrf_session_id', None) + self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None) + except Exception as e: + pass + + """ + Determine if the Session-ID for the STR belongs to an inbound roaming emergency subscriber. + """ + try: + emergencySubscriberData = self.getEmergencySubscriber(gxSessionId=sessionId) + if emergencySubscriberData: + emergencySubscriber = True + except Exception as e: + self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] Error getting Emergency Subscriber Data: {traceback.format_exc()}", redisClient=self.redisMessaging) + emergencySubscriberData = None - if servingApn is not None: - servingPgw = servingApn.get('serving_pgw', '') - servingPgwRealm = servingApn.get('serving_pgw_realm', '') - servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0] - pcrfSessionId = servingApn.get('pcrf_session_id', None) + if emergencySubscriberData: + for key, value in emergencySubscriberData.items(): + servingPgwPeer = emergencySubscriberData[key].get('servingPgw', None).split(';')[0] + pcrfSessionId = emergencySubscriberData[key].get('servingPgw', None) + servingPgwRealm = emergencySubscriberData[key].get('gxOriginRealm', None) + servingPgw = emergencySubscriberData[key].get('servingPgw', None).split(';')[0] + + if servingApn is not None or emergencySubscriberData: reAuthAnswer = self.awaitDiameterRequestAndResponse( requestType='RAR', hostname=servingPgwPeer, @@ -3043,7 +3080,7 @@ def Answer_16777236_275(self, packet_vars, avps): chargingRuleName='GBR-Voice', chargingRuleAction='remove' ) - + if not len(reAuthAnswer) > 0: self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging) assert()