From dad2d5982e8d9324a1552d3d989f06b1d5497160 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 14 Nov 2024 16:48:22 -0800 Subject: [PATCH 1/2] Support AMR routing with email plugins --- integrations/access/accessrequest/app.go | 4 +- integrations/access/email/app.go | 110 ++++++++++++++++++----- integrations/access/email/client.go | 9 +- 3 files changed, 94 insertions(+), 29 deletions(-) diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go index 893d3db9caba8..8a5effc73dabd 100644 --- a/integrations/access/accessrequest/app.go +++ b/integrations/access/accessrequest/app.go @@ -181,12 +181,12 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { case types.KindAccessMonitoringRule: return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event)) case types.KindAccessRequest: - return trace.Wrap(a.handleAcessRequest(ctx, event)) + return trace.Wrap(a.handleAccessRequest(ctx, event)) } return trace.BadParameter("unexpected kind %s", event.Resource.GetKind()) } -func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { +func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error { op := event.Type reqID := event.Resource.GetName() ctx, _ = logger.WithField(ctx, "request_id", reqID) diff --git a/integrations/access/email/app.go b/integrations/access/email/app.go index 8e38af504a56a..07bb3b558080e 100644 --- a/integrations/access/email/app.go +++ b/integrations/access/email/app.go @@ -18,12 +18,15 @@ package email import ( "context" + "slices" "time" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/accessmonitoring" + "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" @@ -48,9 +51,10 @@ const ( type App struct { conf Config - apiClient teleport.Client - client Client - mainJob lib.ServiceJob + apiClient teleport.Client + client Client + mainJob lib.ServiceJob + accessMonitoringRules *accessmonitoring.RuleHandler *lib.Process } @@ -91,13 +95,24 @@ func (a *App) run(ctx context.Context) error { if err = a.init(ctx); err != nil { return trace.Wrap(err) } - watcherJob, err := watcherjob.NewJob( + + watchKinds := []types.WatchKind{ + {Kind: types.KindAccessRequest}, + {Kind: types.KindAccessMonitoringRule}, + } + acceptedWatchKinds := make([]string, 0, len(watchKinds)) + watcherJob, err := watcherjob.NewJobWithConfirmedWatchKinds( a.apiClient, watcherjob.Config{ - Watch: types.Watch{Kinds: []types.WatchKind{{Kind: types.KindAccessRequest}}}, + Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true}, EventFuncTimeout: handlerTimeout, }, a.onWatcherEvent, + func(ws types.WatchStatus) { + for _, watchKind := range ws.GetKinds() { + acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind) + } + }, ) if err != nil { return trace.Wrap(err) @@ -107,6 +122,18 @@ func (a *App) run(ctx context.Context) error { if err != nil { return trace.Wrap(err) } + if len(acceptedWatchKinds) == 0 { + return trace.BadParameter("failed to initialize watcher for all the required resources: %+v", + watchKinds) + } + + // Check if KindAccessMonitoringRule resources are being watched, + // the role the plugin is running as may not have access. + if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) { + if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil { + return trace.Wrap(err, "initializing Access Monitoring Rule cache") + } + } a.mainJob.SetReady(ok) if ok { @@ -146,6 +173,19 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } + a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{ + Client: a.apiClient, + PluginType: types.PluginTypeEmail, + PluginName: pluginName, + FetchRecipientCallback: func(_ context.Context, recipient string) (*common.Recipient, error) { + return &common.Recipient{ + Name: recipient, + ID: recipient, + Kind: common.RecipientKindEmail, + }, nil + }, + }) + log.Debug("Starting client connection health check...") if err = a.client.CheckHealth(ctx); err != nil { return trace.Wrap(err, "client connection health check failed") @@ -170,8 +210,20 @@ func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, err return pong, trace.Wrap(err) } -// onWatcherEvent processes new incoming access request +// onWatcherEvent is called for every cluster Event. It will filter out non-access-request events and +// call onPendingRequest, onResolvedRequest and on DeletedRequest depending on the event. func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { + switch event.Resource.GetKind() { + case types.KindAccessMonitoringRule: + return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event)) + case types.KindAccessRequest: + return trace.Wrap(a.handleAccessRequest(ctx, event)) + } + return trace.BadParameter("unexpected kind %s", event.Resource.GetKind()) +} + +// handleAccessRequest processes new incoming access request +func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error { if kind := event.Resource.GetKind(); kind != types.KindAccessRequest { return trace.Errorf("unexpected kind %s", kind) } @@ -238,14 +290,15 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err } if isNew { - if recipients := a.getEmailRecipients(ctx, req.GetRoles(), req.GetSuggestedReviewers()); len(recipients) > 0 { - if err := a.sendNewThreads(ctx, recipients, reqID, reqData); err != nil { - return trace.Wrap(err) - } - } else { + recipients := a.getRecipients(ctx, req) + if len(recipients) == 0 { log.Warning("No recipients to send") return nil } + + if err := a.sendNewThreads(ctx, recipients, reqID, reqData); err != nil { + return trace.Wrap(err) + } } if reqReviews := req.GetReviews(); len(reqReviews) > 0 { @@ -288,27 +341,38 @@ func (a *App) onDeletedRequest(ctx context.Context, reqID string) error { return a.sendResolution(ctx, reqID, Resolution{Tag: ResolvedExpired}) } -// getEmailRecipients converts suggested reviewers to email recipients -func (a *App) getEmailRecipients(ctx context.Context, roles, suggestedReviewers []string) []string { +func (a *App) getRecipients(ctx context.Context, req types.AccessRequest) []common.Recipient { log := logger.Get(ctx) - validEmailRecipients := []string{} - recipients := a.conf.RoleToRecipients.GetRawRecipientsFor(roles, suggestedReviewers) + recipientSet := common.NewRecipientSet() + recipients := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) + recipients.ForEach(func(r common.Recipient) { + recipientSet.Add(r) + }) + + // Return the set of recipients if it is not empty. + // Otherwise, use the legacy role to recipients map to search for recipients. + if recipientSet.Len() != 0 { + return recipientSet.ToSlice() + } - for _, recipient := range recipients { - if !lib.IsEmail(recipient) { - log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", recipient) + rawRecipients := a.conf.RoleToRecipients.GetRawRecipientsFor(req.GetRoles(), req.GetSuggestedReviewers()) + for _, rawRecipient := range rawRecipients { + if !lib.IsEmail(rawRecipient) { + log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", rawRecipient) continue } - - validEmailRecipients = append(validEmailRecipients, recipient) + recipientSet.Add(common.Recipient{ + ID: rawRecipient, + Name: rawRecipient, + Kind: common.RecipientKindEmail, + }) } - - return validEmailRecipients + return recipientSet.ToSlice() } // broadcastNewThreads sends notifications on a new request -func (a *App) sendNewThreads(ctx context.Context, recipients []string, reqID string, reqData RequestData) error { +func (a *App) sendNewThreads(ctx context.Context, recipients []common.Recipient, reqID string, reqData RequestData) error { threadsSent, err := a.client.SendNewThreads(ctx, recipients, reqID, reqData) if len(threadsSent) == 0 && err != nil { diff --git a/integrations/access/email/client.go b/integrations/access/email/client.go index b65516962d8c4..6ef1d2f04144e 100644 --- a/integrations/access/email/client.go +++ b/integrations/access/email/client.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" ) @@ -85,20 +86,20 @@ func (c *Client) CheckHealth(ctx context.Context) error { } // SendNewThreads sends emails on new requests. Returns EmailData. -func (c *Client) SendNewThreads(ctx context.Context, recipients []string, reqID string, reqData RequestData) ([]EmailThread, error) { +func (c *Client) SendNewThreads(ctx context.Context, recipients []common.Recipient, reqID string, reqData RequestData) ([]EmailThread, error) { var threads []EmailThread var errors []error body := c.buildBody(reqID, reqData, "You have a new Role Request") - for _, email := range recipients { - id, err := c.mailer.Send(ctx, reqID, email, body, "") + for _, recipient := range recipients { + id, err := c.mailer.Send(ctx, reqID, recipient.ID, body, "") if err != nil { errors = append(errors, err) continue } - threads = append(threads, EmailThread{Email: email, Timestamp: time.Now().String(), MessageID: id}) + threads = append(threads, EmailThread{Email: recipient.ID, Timestamp: time.Now().String(), MessageID: id}) } return threads, trace.NewAggregate(errors...) From 86bc729e6e9767f6fd0101540fd24b43e8183fbe Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 14 Nov 2024 16:51:59 -0800 Subject: [PATCH 2/2] Update unit tests --- integrations/access/email/testlib/suite.go | 129 +++++++++++++++++++++ 1 file changed, 129 insertions(+) diff --git a/integrations/access/email/testlib/suite.go b/integrations/access/email/testlib/suite.go index 8d99411193776..1b1dd1494e1c2 100644 --- a/integrations/access/email/testlib/suite.go +++ b/integrations/access/email/testlib/suite.go @@ -31,7 +31,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/email" "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" @@ -231,6 +234,132 @@ func (s *EmailSuiteOSS) TestDenial() { require.Contains(t, messages[0].Body, "Status: ❌ DENIED (not okay)") } +// TestRecipientsFromAccessMonitoringRule tests access monitoring rules are +// applied to the recipient selection process. +func (s *EmailSuiteOSS) TestRecipientsFromAccessMonitoringRule() { + t := s.T() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + t.Cleanup(cancel) + + _, err := s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient(). + CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test-email-amr", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "!is_empty(access_request.spec.roles)", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "email", + Recipients: []string{ + integration.Reviewer1UserName, + }, + }, + }, + }) + require.NoError(t, err) + + // Test execution: create an access request + req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + pluginData := s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool { + return len(data.EmailThreads) > 0 + }) + require.Len(t, pluginData.EmailThreads, 1) + + messages := s.getMessages(ctx, t, 1) + require.Len(t, messages, 1) + require.Equal(t, integration.Reviewer1UserName, messages[0].Recipient) + + require.NoError(t, s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-email-amr")) +} + +// TestRecipientsFromAccessMonitoringRuleAfterUpdate tests access monitoring +// rules are respected after the rule is updated. +func (s *EmailSuiteOSS) TestRecipientsFromAccessMonitoringRuleAfterUpdate() { + t := s.T() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + t.Cleanup(cancel) + + // Setup base config to ensure access monitoring rule recipient take precidence + s.appConfig.RoleToRecipients = common.RawRecipientsMap{ + types.Wildcard: []string{ + integration.Reviewer2UserName, + }, + } + + _, err := s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient(). + CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test-email-amr-2", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "!is_empty(access_request.spec.roles)", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "email", + Recipients: []string{ + integration.Reviewer1UserName, + }, + }, + }, + }) + require.NoError(t, err) + + // Test execution: create an access request + req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + pluginData := s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool { + return len(data.EmailThreads) > 0 + }) + require.Len(t, pluginData.EmailThreads, 1) + + messages := s.getMessages(ctx, t, 1) + require.Len(t, messages, 1) + require.Equal(t, integration.Reviewer1UserName, messages[0].Recipient) + + // Update the Access Monitoring Rule so it is no longer applied + _, err = s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient(). + UpdateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test-email-amr-2", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{"someOtherKind"}, + Condition: "!is_empty(access_request.spec.roles)", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "email", + Recipients: []string{ + integration.Reviewer1UserName, + }, + }, + }, + }) + require.NoError(t, err) + + // Test execution: create an access request + req = s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + pluginData = s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool { + return len(data.EmailThreads) > 0 + }) + require.Len(t, pluginData.EmailThreads, 1) + + messages = s.getMessages(ctx, t, 1) + require.Len(t, messages, 1) + require.Equal(t, allRecipient, messages[0].Recipient) + + require.NoError(t, s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-email-amr-2")) +} + // TestReviewReplies tests that a followup email is sent after the access request // is reviewed. func (s *EmailSuiteEnterprise) TestReviewReplies() {