From d90aa70078f91a2b8dce5b9b615a49fddc102408 Mon Sep 17 00:00:00 2001 From: Yassine Bounekhla Date: Wed, 5 Jun 2024 19:08:57 -0400 Subject: [PATCH] make access request notifications expire alongside the request --- lib/auth/auth.go | 7 +- lib/auth/grpcserver.go | 1 + lib/auth/notification_test.go | 67 +++++++++++++++++-- .../notifications/notificationsv1/service.go | 23 +++++++ 4 files changed, 92 insertions(+), 6 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 727373a8b9014..4fca58cc5599b 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -58,6 +58,7 @@ import ( "golang.org/x/exp/maps" "golang.org/x/time/rate" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" @@ -4971,7 +4972,8 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ Spec: ¬ificationsv1.NotificationSpec{}, SubKind: types.NotificationAccessRequestPendingSubKind, Metadata: &headerv1.Metadata{ - Labels: map[string]string{types.NotificationTitleLabel: notificationText, "request-id": req.GetName()}, + Labels: map[string]string{types.NotificationTitleLabel: notificationText, "request-id": req.GetName()}, + Expires: timestamppb.New(req.Expiry()), }, }, }, @@ -5238,7 +5240,8 @@ func generateAccessRequestReviewedNotification(req types.AccessRequest, params t "request-id": params.RequestID, "roles": strings.Join(req.GetRoles(), ","), "assumable-time": assumableTime, - }}, + }, + Expires: timestamppb.New(req.Expiry())}, } } diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 0931c116b9b26..88016750733d9 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5380,6 +5380,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { notificationsServer, err := notifications.NewService(notifications.ServiceConfig{ Authorizer: cfg.Authorizer, Backend: cfg.AuthServer.Services, + Clock: cfg.AuthServer.GetClock(), UserNotificationCache: cfg.AuthServer.UserNotificationCache, GlobalNotificationCache: cfg.AuthServer.GlobalNotificationCache, }) diff --git a/lib/auth/notification_test.go b/lib/auth/notification_test.go index 02493ddafed1b..74d353f1bbde2 100644 --- a/lib/auth/notification_test.go +++ b/lib/auth/notification_test.go @@ -222,6 +222,49 @@ func TestNotifications(t *testing.T) { }, }), }, + { + userNotification: ¬ificationsv1.Notification{ + SubKind: "test-subkind", + Spec: ¬ificationsv1.NotificationSpec{ + Username: managerUsername, + }, + Metadata: &headerv1.Metadata{ + Labels: map[string]string{ + types.NotificationTitleLabel: "manager-7-expires", + }, + // Expires in 15 minutes. + Expires: timestamppb.New(fakeClock.Now().Add(15 * time.Minute)), + }, + }, + }, + { + globalNotification: ¬ificationsv1.GlobalNotification{ + Spec: ¬ificationsv1.GlobalNotificationSpec{ + Matcher: ¬ificationsv1.GlobalNotificationSpec_ByPermissions{ + ByPermissions: ¬ificationsv1.ByPermissions{ + RoleConditions: []*types.RoleConditions{ + { + ReviewRequests: &types.AccessReviewConditions{ + Roles: []string{"intern"}, + }, + }, + }, + }, + }, + Notification: ¬ificationsv1.Notification{ + SubKind: "test-subkind", + Spec: ¬ificationsv1.NotificationSpec{}, + Metadata: &headerv1.Metadata{ + Labels: map[string]string{ + types.NotificationTitleLabel: "manager-8-expires", + }, + // Expires in 10 minutes. + Expires: timestamppb.New(fakeClock.Now().Add(10 * time.Minute)), + }, + }, + }, + }, + }, } notificationIdMap := map[string]string{} @@ -348,7 +391,7 @@ func TestNotifications(t *testing.T) { require.NoError(t, err) defer managerClient.Close() - managerExpectedNotifications := []string{"auditor-8,manager-6", "manager-5", "manager-4", "manager-3", "auditor-5,manager-2", "manager-1"} + managerExpectedNotifications := []string{"manager-8-expires", "manager-7-expires", "auditor-8,manager-6", "manager-5", "manager-4", "manager-3", "auditor-5,manager-2", "manager-1"} resp, err = managerClient.ListNotifications(ctx, ¬ificationsv1.ListNotificationsRequest{ PageSize: 10, @@ -359,10 +402,10 @@ func TestNotifications(t *testing.T) { // Verify that we've reached the end of both lists. require.Equal(t, "", resp.NextPageToken) - // Mark "auditor-8,manager-6" as clicked. + // Mark "manager-8-expires" as clicked. _, err = managerClient.UpsertUserNotificationState(ctx, managerUsername, ¬ificationsv1.UserNotificationState{ Spec: ¬ificationsv1.UserNotificationStateSpec{ - NotificationId: notificationIdMap["auditor-8,manager-6"], + NotificationId: notificationIdMap["manager-8-expires"], }, Status: ¬ificationsv1.UserNotificationStateStatus{ NotificationState: notificationsv1.NotificationState_NOTIFICATION_STATE_CLICKED, @@ -376,10 +419,26 @@ func TestNotifications(t *testing.T) { }) require.NoError(t, err) - clickedNotification := resp.Notifications[0] // "auditor-8,manager-6" is the first item in the list + clickedNotification := resp.Notifications[0] // "manager-8-expires" is the first item in the list clickedLabelValue := clickedNotification.GetMetadata().GetLabels()[types.NotificationClickedLabel] require.Equal(t, "true", clickedLabelValue) + // Advance 11 minutes. + fakeClock.Advance(11 * time.Minute) + + // Verify that notification "manager-8-expires" is now no longer returned. + resp, err = managerClient.ListNotifications(ctx, ¬ificationsv1.ListNotificationsRequest{}) + require.NoError(t, err) + require.Equal(t, managerExpectedNotifications[1:], notificationsToTitlesList(t, resp.Notifications)) + + // Advance 16 minutes. + fakeClock.Advance(16 * time.Minute) + + // Verify that notification "manager-7-expires" is now no longer returned either. + resp, err = managerClient.ListNotifications(ctx, ¬ificationsv1.ListNotificationsRequest{}) + require.NoError(t, err) + require.Equal(t, managerExpectedNotifications[2:], notificationsToTitlesList(t, resp.Notifications)) + // Verify that manager can't upsert a notification state for auditor _, err = managerClient.UpsertUserNotificationState(ctx, auditorUsername, ¬ificationsv1.UserNotificationState{ Spec: ¬ificationsv1.UserNotificationStateSpec{ diff --git a/lib/auth/notifications/notificationsv1/service.go b/lib/auth/notifications/notificationsv1/service.go index 353932e5b3afa..b6b16b39be896 100644 --- a/lib/auth/notifications/notificationsv1/service.go +++ b/lib/auth/notifications/notificationsv1/service.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/client" apidefaults "github.com/gravitational/teleport/api/defaults" @@ -47,6 +48,8 @@ type ServiceConfig struct { // GlobalNotificationCache is a custom cache for user-specific notifications, // this is to allow fetching notifications by date in descending order. GlobalNotificationCache *services.GlobalNotificationCache + + Clock clockwork.Clock } // Backend contains the getters required for notification states and user last seen notifications, @@ -70,6 +73,7 @@ type Service struct { backend Backend userNotificationCache *services.UserNotificationCache globalNotificationCache *services.GlobalNotificationCache + clock clockwork.Clock } // NewService returns a new notifications gRPC service. @@ -83,6 +87,8 @@ func NewService(cfg ServiceConfig) (*Service, error) { return nil, trace.BadParameter("user notification cache is required") case cfg.GlobalNotificationCache == nil: return nil, trace.BadParameter("global notification cache is required") + case cfg.Clock == nil: + cfg.Clock = clockwork.NewRealClock() } return &Service{ @@ -90,6 +96,7 @@ func NewService(cfg ServiceConfig) (*Service, error) { backend: cfg.Backend, userNotificationCache: cfg.UserNotificationCache, globalNotificationCache: cfg.GlobalNotificationCache, + clock: cfg.Clock, }, nil } @@ -120,6 +127,12 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li startKey = nextKey } + currentTime := s.clock.Now() + var hasNotificationExpired = func(n *notificationsv1.Notification) bool { + notificationExpiryTime := n.GetMetadata().GetExpires().AsTime() + return currentTime.After(notificationExpiryTime) + } + var userNotifMatchFn = func(n *notificationsv1.Notification) bool { // Return true if the user hasn't dismissed this notification return notificationStatesMap[n.GetMetadata().GetName()] != notificationsv1.NotificationState_NOTIFICATION_STATE_DISMISSED @@ -134,6 +147,11 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li userNotifsStream = stream.FilterMap( s.userNotificationCache.StreamUserNotifications(ctx, username, userKey), func(n *notificationsv1.Notification) (*notificationsv1.Notification, bool) { + // If the notification is expired, return false right away. + if hasNotificationExpired(n) { + return nil, false + } + if !userNotifMatchFn(n) { return nil, false } @@ -145,6 +163,11 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li globalNotifsStream = stream.FilterMap( s.globalNotificationCache.StreamGlobalNotifications(ctx, globalKey), func(gn *notificationsv1.GlobalNotification) (*notificationsv1.GlobalNotification, bool) { + // If the notification is expired, return false right away. + if hasNotificationExpired(gn.GetSpec().GetNotification()) { + return nil, false + } + if !s.matchGlobalNotification(ctx, authCtx, gn, notificationStatesMap) { return nil, false }