Skip to content

Commit

Permalink
Fix for STA / Emergency SOS
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkneipp committed Jan 12, 2024
1 parent 3072fe9 commit 6c41b0b
Showing 1 changed file with 62 additions and 25 deletions.
87 changes: 62 additions & 25 deletions lib/diameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,29 +1021,29 @@ def validateSubscriberRoaming(self, subscriber: dict, mcc: str, mnc: str) -> boo
return True


def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, authExpiry: int=3600, subscriberImsi: str="Unknown") -> bool:
def storeEmergencySubscriber(self, subscriberIp: str, subscriberData: dict, gxSessionId: str, authExpiry: int=3600, subscriberImsi: str="Unknown") -> bool:
"""
Store a given Emergency Subscriber in redis.
If there's an existing entry for the same IMSI, then update the record with the new IP and details.
The subscriber entry will expire per authExpiry in seconds.
"""
try:
emergencySubscriberKey = f"emergencySubscriber-{subscriberIp}-{subscriberImsi}"
emergencySubscriberKey = f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:{gxSessionId}"
# Check if our subscriber exists
if subscriberImsi and subscriberImsi != "Unknown":
existingEmergencySubscriber = self.getEmergencySubscriber(subscriberImsi=subscriberImsi)
if existingEmergencySubscriber:
self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [getEmergencySubscriber] Found existing emergency subscriber to overwrite: {existingEmergencySubscriber}", redisClient=self.redisMessaging)
for key, value in existingEmergencySubscriber.items():
self.redisMessaging.multiDeleteQueue(queue=f"emergencySubscriber-{value.get('ip')}-{value.get('imsi')}", redisPeerConnections=self.redisPeerConnections)
self.redisMessaging.multiDeleteQueue(queue=f"emergencySubscriber:{value.get('ip')}:{value.get('imsi')}:{value.get('servingPgw')}", redisPeerConnections=self.redisPeerConnections)
result = self.redisMessaging.multiSetValue(key=emergencySubscriberKey, value=json.dumps(subscriberData), keyExpiry=authExpiry, redisPeerConnections=self.redisPeerConnections)
return True
except Exception as e:
self.logTool.log(service='HSS', level='error', message=f"[diameter.py] [getEmergencySubscriber] Error storing emergency subscriber in redis: {traceback.format_exc()}", redisClient=self.redisMessaging)
return False


def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=None) -> dict:
def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=None, gxSessionId: str=None) -> dict:
"""
Retrieves a provided Emergency Subscriber from redis, if it exists.
The first match from any defined redis instance is used.
Expand All @@ -1055,7 +1055,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
return None

if subscriberIp and subscriberImsi:
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-{subscriberIp}-{subscriberImsi}")
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:{subscriberImsi}:*")
if emergencySubscriberKeyList:
for matchedKey in emergencySubscriberKeyList:
for peerName, keyName in matchedKey.items():
Expand All @@ -1069,7 +1069,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
return emergencySubscriber

if subscriberIp and not subscriberImsi:
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-{subscriberIp}-*")
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:{subscriberIp}:*")
if emergencySubscriberKeyList:
for matchedKey in emergencySubscriberKeyList:
for peerName, keyName in matchedKey.items():
Expand All @@ -1083,7 +1083,7 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
return emergencySubscriber

if subscriberImsi and not subscriberIp:
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber-*-{subscriberImsi}")
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:{subscriberImsi}:*")
if emergencySubscriberKeyList:
for matchedKey in emergencySubscriberKeyList:
for peerName, keyName in matchedKey.items():
Expand All @@ -1095,7 +1095,21 @@ def getEmergencySubscriber(self, subscriberIp: str=None, subscriberImsi: str=Non
emergencySubscriberData = json.loads(emergencySubscriberData)
emergencySubscriber = {peerName: emergencySubscriberData}
return emergencySubscriber


if gxSessionId:
emergencySubscriberKeyList = self.redisMessaging.multiGetQueues(pattern=f"emergencySubscriber:*:*:{gxSessionId}")
if emergencySubscriberKeyList:
for matchedKey in emergencySubscriberKeyList:
for peerName, keyName in matchedKey.items():
if isinstance(keyName, list):
keyName = keyName[0] if len(keyName) > 0 else ''
emergencySubscriberData = self.redisMessaging.getValue(key=keyName, redisClient=self.getRedisPeerConnection(peerName=peerName))
if not emergencySubscriberData:
return None
emergencySubscriberData = json.loads(emergencySubscriberData)
emergencySubscriber = {peerName: emergencySubscriberData}
return emergencySubscriber

return None

except Exception as e:
Expand Down Expand Up @@ -2010,7 +2024,7 @@ def Answer_16777238_272(self, packet_vars, avps):
"accessNetworkChargingAddress": accessNetworkChargingAddress,
}

self.storeEmergencySubscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, subscriberImsi=imsi)
self.storeEmergencySubscriber(subscriberIp=ueIp, subscriberData=emergencySubscriberData, subscriberImsi=imsi, gxSessionId=emergencySubscriberData.get('servingPgw'))

avp += self.generate_avp(268, 40, self.int_to_hex(2001, 4)) #Result Code (DIAMETER_SUCCESS (2001))
response = self.generate_diameter_packet("01", "40", 272, 16777238, packet_vars['hop-by-hop-identifier'], packet_vars['end-to-end-identifier'], avp) #Generate Diameter packet
Expand Down Expand Up @@ -3018,22 +3032,45 @@ def Answer_16777236_275(self, packet_vars, avps):
avp += self.generate_avp(263, 40, self.string_to_hex(sessionId)) #Set session ID to received session ID
avp += self.generate_avp(264, 40, self.OriginHost) #Origin Host
avp += self.generate_avp(296, 40, self.OriginRealm) #Origin Realm
imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId)
imsi = imsSubscriber.get('imsi', None)
pcscf = imsSubscriber.get('pcscf', None)
pcscf_realm = imsSubscriber.get('pcscf_realm', None)
pcscf_peer = imsSubscriber.get('pcscf_peer', None)
subscriber = self.database.Get_Subscriber(imsi=imsi)
subscriberId = subscriber.get('subscriber_id', None)
apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None)
servingApn = None
try:
imsSubscriber = self.database.Get_IMS_Subscriber_By_Session_Id(sessionId=sessionId)
imsi = imsSubscriber.get('imsi', None)
pcscf = imsSubscriber.get('pcscf', None)
pcscf_realm = imsSubscriber.get('pcscf_realm', None)
pcscf_peer = imsSubscriber.get('pcscf_peer', None)
subscriber = self.database.Get_Subscriber(imsi=imsi)
subscriberId = subscriber.get('subscriber_id', None)
apnId = (self.database.Get_APN_by_Name(apn="ims")).get('apn_id', None)
servingApn = self.database.Get_Serving_APN(subscriber_id=subscriberId, apn_id=apnId)
if servingApn is not None:
servingPgw = servingApn.get('serving_pgw', '')
servingPgwRealm = servingApn.get('serving_pgw_realm', '')
servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0]
pcrfSessionId = servingApn.get('pcrf_session_id', None)
self.database.Update_Proxy_CSCF(imsi=imsi, proxy_cscf=pcscf, pcscf_realm=pcscf_realm, pcscf_peer=pcscf_peer, pcscf_active_session=None)
except Exception as e:
pass

"""
Determine if the Session-ID for the STR belongs to an inbound roaming emergency subscriber.
"""
try:
emergencySubscriberData = self.getEmergencySubscriber(gxSessionId=sessionId)
if emergencySubscriberData:
emergencySubscriber = True
except Exception as e:
self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] Error getting Emergency Subscriber Data: {traceback.format_exc()}", redisClient=self.redisMessaging)
emergencySubscriberData = None

if servingApn is not None:
servingPgw = servingApn.get('serving_pgw', '')
servingPgwRealm = servingApn.get('serving_pgw_realm', '')
servingPgwPeer = servingApn.get('serving_pgw_peer', '').split(';')[0]
pcrfSessionId = servingApn.get('pcrf_session_id', None)
if emergencySubscriberData:
for key, value in emergencySubscriberData.items():
servingPgwPeer = emergencySubscriberData[key].get('servingPgw', None).split(';')[0]
pcrfSessionId = emergencySubscriberData[key].get('servingPgw', None)
servingPgwRealm = emergencySubscriberData[key].get('gxOriginRealm', None)
servingPgw = emergencySubscriberData[key].get('servingPgw', None).split(';')[0]

if servingApn is not None or emergencySubscriberData:
reAuthAnswer = self.awaitDiameterRequestAndResponse(
requestType='RAR',
hostname=servingPgwPeer,
Expand All @@ -3043,7 +3080,7 @@ def Answer_16777236_275(self, packet_vars, avps):
chargingRuleName='GBR-Voice',
chargingRuleAction='remove'
)

if not len(reAuthAnswer) > 0:
self.logTool.log(service='HSS', level='debug', message=f"[diameter.py] [Answer_16777236_265] [STA] RAA Timeout: {reAuthAnswer}", redisClient=self.redisMessaging)
assert()
Expand Down

0 comments on commit 6c41b0b

Please sign in to comment.