Skip to content

Commit

Permalink
[v17] Support AMR notification routing with email plugins (#49234)
Browse files Browse the repository at this point in the history
* Support AMR routing with email plugins

* Update unit tests
  • Loading branch information
bernardjkim authored Nov 20, 2024
1 parent e2d93d5 commit d635ef7
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 29 deletions.
4 changes: 2 additions & 2 deletions integrations/access/accessrequest/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
case types.KindAccessMonitoringRule:
return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event))
case types.KindAccessRequest:
return trace.Wrap(a.handleAcessRequest(ctx, event))
return trace.Wrap(a.handleAccessRequest(ctx, event))
}
return trace.BadParameter("unexpected kind %s", event.Resource.GetKind())
}

func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error {
func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error {
op := event.Type
reqID := event.Resource.GetName()
ctx, _ = logger.WithField(ctx, "request_id", reqID)
Expand Down
110 changes: 87 additions & 23 deletions integrations/access/email/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ package email

import (
"context"
"slices"
"time"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/accessmonitoring"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/common/teleport"
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/logger"
Expand All @@ -48,9 +51,10 @@ const (
type App struct {
conf Config

apiClient teleport.Client
client Client
mainJob lib.ServiceJob
apiClient teleport.Client
client Client
mainJob lib.ServiceJob
accessMonitoringRules *accessmonitoring.RuleHandler

*lib.Process
}
Expand Down Expand Up @@ -91,13 +95,24 @@ func (a *App) run(ctx context.Context) error {
if err = a.init(ctx); err != nil {
return trace.Wrap(err)
}
watcherJob, err := watcherjob.NewJob(

watchKinds := []types.WatchKind{
{Kind: types.KindAccessRequest},
{Kind: types.KindAccessMonitoringRule},
}
acceptedWatchKinds := make([]string, 0, len(watchKinds))
watcherJob, err := watcherjob.NewJobWithConfirmedWatchKinds(
a.apiClient,
watcherjob.Config{
Watch: types.Watch{Kinds: []types.WatchKind{{Kind: types.KindAccessRequest}}},
Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true},
EventFuncTimeout: handlerTimeout,
},
a.onWatcherEvent,
func(ws types.WatchStatus) {
for _, watchKind := range ws.GetKinds() {
acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind)
}
},
)
if err != nil {
return trace.Wrap(err)
Expand All @@ -107,6 +122,18 @@ func (a *App) run(ctx context.Context) error {
if err != nil {
return trace.Wrap(err)
}
if len(acceptedWatchKinds) == 0 {
return trace.BadParameter("failed to initialize watcher for all the required resources: %+v",
watchKinds)
}

// Check if KindAccessMonitoringRule resources are being watched,
// the role the plugin is running as may not have access.
if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) {
if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil {
return trace.Wrap(err, "initializing Access Monitoring Rule cache")
}
}

a.mainJob.SetReady(ok)
if ok {
Expand Down Expand Up @@ -146,6 +173,19 @@ func (a *App) init(ctx context.Context) error {
return trace.Wrap(err)
}

a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
Client: a.apiClient,
PluginType: types.PluginTypeEmail,
PluginName: pluginName,
FetchRecipientCallback: func(_ context.Context, recipient string) (*common.Recipient, error) {
return &common.Recipient{
Name: recipient,
ID: recipient,
Kind: common.RecipientKindEmail,
}, nil
},
})

log.Debug("Starting client connection health check...")
if err = a.client.CheckHealth(ctx); err != nil {
return trace.Wrap(err, "client connection health check failed")
Expand All @@ -170,8 +210,20 @@ func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, err
return pong, trace.Wrap(err)
}

// onWatcherEvent processes new incoming access request
// onWatcherEvent is called for every cluster Event. It will filter out non-access-request events and
// call onPendingRequest, onResolvedRequest and on DeletedRequest depending on the event.
func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error {
switch event.Resource.GetKind() {
case types.KindAccessMonitoringRule:
return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event))
case types.KindAccessRequest:
return trace.Wrap(a.handleAccessRequest(ctx, event))
}
return trace.BadParameter("unexpected kind %s", event.Resource.GetKind())
}

// handleAccessRequest processes new incoming access request
func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error {
if kind := event.Resource.GetKind(); kind != types.KindAccessRequest {
return trace.Errorf("unexpected kind %s", kind)
}
Expand Down Expand Up @@ -238,14 +290,15 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err
}

if isNew {
if recipients := a.getEmailRecipients(ctx, req.GetRoles(), req.GetSuggestedReviewers()); len(recipients) > 0 {
if err := a.sendNewThreads(ctx, recipients, reqID, reqData); err != nil {
return trace.Wrap(err)
}
} else {
recipients := a.getRecipients(ctx, req)
if len(recipients) == 0 {
log.Warning("No recipients to send")
return nil
}

if err := a.sendNewThreads(ctx, recipients, reqID, reqData); err != nil {
return trace.Wrap(err)
}
}

if reqReviews := req.GetReviews(); len(reqReviews) > 0 {
Expand Down Expand Up @@ -288,27 +341,38 @@ func (a *App) onDeletedRequest(ctx context.Context, reqID string) error {
return a.sendResolution(ctx, reqID, Resolution{Tag: ResolvedExpired})
}

// getEmailRecipients converts suggested reviewers to email recipients
func (a *App) getEmailRecipients(ctx context.Context, roles, suggestedReviewers []string) []string {
func (a *App) getRecipients(ctx context.Context, req types.AccessRequest) []common.Recipient {
log := logger.Get(ctx)
validEmailRecipients := []string{}

recipients := a.conf.RoleToRecipients.GetRawRecipientsFor(roles, suggestedReviewers)
recipientSet := common.NewRecipientSet()
recipients := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req)
recipients.ForEach(func(r common.Recipient) {
recipientSet.Add(r)
})

// Return the set of recipients if it is not empty.
// Otherwise, use the legacy role to recipients map to search for recipients.
if recipientSet.Len() != 0 {
return recipientSet.ToSlice()
}

for _, recipient := range recipients {
if !lib.IsEmail(recipient) {
log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", recipient)
rawRecipients := a.conf.RoleToRecipients.GetRawRecipientsFor(req.GetRoles(), req.GetSuggestedReviewers())
for _, rawRecipient := range rawRecipients {
if !lib.IsEmail(rawRecipient) {
log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", rawRecipient)
continue
}

validEmailRecipients = append(validEmailRecipients, recipient)
recipientSet.Add(common.Recipient{
ID: rawRecipient,
Name: rawRecipient,
Kind: common.RecipientKindEmail,
})
}

return validEmailRecipients
return recipientSet.ToSlice()
}

// broadcastNewThreads sends notifications on a new request
func (a *App) sendNewThreads(ctx context.Context, recipients []string, reqID string, reqData RequestData) error {
func (a *App) sendNewThreads(ctx context.Context, recipients []common.Recipient, reqID string, reqData RequestData) error {
threadsSent, err := a.client.SendNewThreads(ctx, recipients, reqID, reqData)

if len(threadsSent) == 0 && err != nil {
Expand Down
9 changes: 5 additions & 4 deletions integrations/access/email/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/logger"
)
Expand Down Expand Up @@ -85,20 +86,20 @@ func (c *Client) CheckHealth(ctx context.Context) error {
}

// SendNewThreads sends emails on new requests. Returns EmailData.
func (c *Client) SendNewThreads(ctx context.Context, recipients []string, reqID string, reqData RequestData) ([]EmailThread, error) {
func (c *Client) SendNewThreads(ctx context.Context, recipients []common.Recipient, reqID string, reqData RequestData) ([]EmailThread, error) {
var threads []EmailThread
var errors []error

body := c.buildBody(reqID, reqData, "You have a new Role Request")

for _, email := range recipients {
id, err := c.mailer.Send(ctx, reqID, email, body, "")
for _, recipient := range recipients {
id, err := c.mailer.Send(ctx, reqID, recipient.ID, body, "")
if err != nil {
errors = append(errors, err)
continue
}

threads = append(threads, EmailThread{Email: email, Timestamp: time.Now().String(), MessageID: id})
threads = append(threads, EmailThread{Email: recipient.ID, Timestamp: time.Now().String(), MessageID: id})
}

return threads, trace.NewAggregate(errors...)
Expand Down
129 changes: 129 additions & 0 deletions integrations/access/email/testlib/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/email"
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/logger"
Expand Down Expand Up @@ -231,6 +234,132 @@ func (s *EmailSuiteOSS) TestDenial() {
require.Contains(t, messages[0].Body, "Status: ❌ DENIED (not okay)")
}

// TestRecipientsFromAccessMonitoringRule tests access monitoring rules are
// applied to the recipient selection process.
func (s *EmailSuiteOSS) TestRecipientsFromAccessMonitoringRule() {
t := s.T()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

_, err := s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().
CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &headerv1.Metadata{
Name: "test-email-amr",
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "!is_empty(access_request.spec.roles)",
Notification: &accessmonitoringrulesv1.Notification{
Name: "email",
Recipients: []string{
integration.Reviewer1UserName,
},
},
},
})
require.NoError(t, err)

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)
pluginData := s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool {
return len(data.EmailThreads) > 0
})
require.Len(t, pluginData.EmailThreads, 1)

messages := s.getMessages(ctx, t, 1)
require.Len(t, messages, 1)
require.Equal(t, integration.Reviewer1UserName, messages[0].Recipient)

require.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-email-amr"))
}

// TestRecipientsFromAccessMonitoringRuleAfterUpdate tests access monitoring
// rules are respected after the rule is updated.
func (s *EmailSuiteOSS) TestRecipientsFromAccessMonitoringRuleAfterUpdate() {
t := s.T()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

// Setup base config to ensure access monitoring rule recipient take precidence
s.appConfig.RoleToRecipients = common.RawRecipientsMap{
types.Wildcard: []string{
integration.Reviewer2UserName,
},
}

_, err := s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().
CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &headerv1.Metadata{
Name: "test-email-amr-2",
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Condition: "!is_empty(access_request.spec.roles)",
Notification: &accessmonitoringrulesv1.Notification{
Name: "email",
Recipients: []string{
integration.Reviewer1UserName,
},
},
},
})
require.NoError(t, err)

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)
pluginData := s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool {
return len(data.EmailThreads) > 0
})
require.Len(t, pluginData.EmailThreads, 1)

messages := s.getMessages(ctx, t, 1)
require.Len(t, messages, 1)
require.Equal(t, integration.Reviewer1UserName, messages[0].Recipient)

// Update the Access Monitoring Rule so it is no longer applied
_, err = s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().
UpdateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &headerv1.Metadata{
Name: "test-email-amr-2",
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{"someOtherKind"},
Condition: "!is_empty(access_request.spec.roles)",
Notification: &accessmonitoringrulesv1.Notification{
Name: "email",
Recipients: []string{
integration.Reviewer1UserName,
},
},
},
})
require.NoError(t, err)

// Test execution: create an access request
req = s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)
pluginData = s.checkPluginData(ctx, req.GetName(), func(data email.PluginData) bool {
return len(data.EmailThreads) > 0
})
require.Len(t, pluginData.EmailThreads, 1)

messages = s.getMessages(ctx, t, 1)
require.Len(t, messages, 1)
require.Equal(t, allRecipient, messages[0].Recipient)

require.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-email-amr-2"))
}

// TestReviewReplies tests that a followup email is sent after the access request
// is reviewed.
func (s *EmailSuiteEnterprise) TestReviewReplies() {
Expand Down

0 comments on commit d635ef7

Please sign in to comment.