diff --git a/backend/user/notifications.py b/backend/user/notifications.py index e3a004b4..0241db3e 100644 --- a/backend/user/notifications.py +++ b/backend/user/notifications.py @@ -27,8 +27,8 @@ class NotificationWrapper(ABC): - def send_notification(self, tokens, title, body): - self.send_payload(tokens, self.create_payload(title, body)) + def send_notification(self, tokens, title, body, urgent): + self.send_payload(tokens, self.create_payload(title, body, urgent)) def send_shadow_notification(self, tokens, body): self.send_payload(tokens, self.create_shadow_payload(body)) @@ -42,7 +42,7 @@ def send_payload(self, tokens, payload): self.send_one_notification(tokens[0], payload) @abstractmethod - def create_payload(self, title, body): + def create_payload(self, title, body, urgent): raise NotImplementedError # pragma: no cover @abstractmethod @@ -67,7 +67,8 @@ def __init__(self): except Exception as e: print(f"Notifications Error: Failed to initialize Firebase client: {e}") - def create_payload(self, title, body): + def create_payload(self, title, body, urgent): + # TODO: do something with urgent return {"notification": messaging.Notification(title=title, body=body)} def create_shadow_payload(self, body): @@ -84,6 +85,17 @@ def send_one_notification(self, token, payload): class IOSNotificationWrapper(NotificationWrapper): + class CustomPayload(Payload): + # Custom payload to support interruption_level + def __init__(self, urgent, **kwargs): + super().__init__(**kwargs) + self.urgent = urgent + + def dict(self): + result = super().dict() + if self.urgent: + result["aps"]["interruption-level"] = "time-sensitive" + @staticmethod def get_client(is_dev): auth_key_path = ( @@ -98,10 +110,14 @@ def __init__(self, is_dev=False): except Exception as e: print(f"Notifications Error: Failed to initialize APNs client: {e}") - def create_payload(self, title, body): + def create_payload(self, title, body, urgent): # TODO: we might want to add category here, but there is no use on iOS side for now - return Payload( - alert={"title": title, "body": body}, sound="default", badge=0, mutable_content=True + return IOSNotificationWrapper.CustomPayload( + alert={"title": title, "body": body}, + sound="default", + badge=0, + mutable_content=True, + urgent=urgent, ) def create_shadow_payload(self, body): @@ -121,7 +137,7 @@ def send_one_notification(self, token, payload): @shared_task(name="notifications.ios_send_notification") -def ios_send_notification(tokens, title, body): +def ios_send_notification(tokens, title, body, urgent): IOSNotificationSender.send_notification(tokens, title, body) @@ -131,7 +147,7 @@ def ios_send_shadow_notification(tokens, body): @shared_task(name="notifications.android_send_notification") -def android_send_notification(tokens, title, body): +def android_send_notification(tokens, title, body, urgent): AndroidNotificationSender.send_notification(tokens, title, body) @@ -141,7 +157,7 @@ def android_send_shadow_notification(tokens, body): @shared_task(name="notifications.ios_send_dev_notification") -def ios_send_dev_notification(tokens, title, body): +def ios_send_dev_notification(tokens, title, body, urgent): IOSNotificationDevSender.send_notification(tokens, title, body) diff --git a/backend/user/views.py b/backend/user/views.py index d94d4f6b..cacf5be8 100644 --- a/backend/user/views.py +++ b/backend/user/views.py @@ -139,6 +139,7 @@ def post(self, request): title = request.data.get("title") body = request.data.get("body") delay = max(request.data.get("delay", 0), 0) + urgent = request.data.get("urgent", False) if None in [service, title, body]: return Response({"detail": "Missing required parameters."}, status=400) @@ -160,7 +161,7 @@ def post(self, request): (android_tokens, android_send_notification), ]: if tokens_list := list(tokens.values_list("token", flat=True)): - send.apply_async(args=(tokens_list, title, body), countdown=delay) + send.apply_async(args=(tokens_list, title, body, urgent), countdown=delay) users_with_service_usernames = users_with_service.values_list("username", flat=True) users_not_reached_usernames = list(set(usernames) - set(users_with_service_usernames))