Skip to content

Commit

Permalink
Merge branch 'master' into edwarddowling/msteams-amr
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardDowling authored Oct 22, 2024
2 parents 46ed3f0 + 3672bc6 commit 6784da3
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 44 deletions.
10 changes: 10 additions & 0 deletions integrations/access/accessmonitoring/access_monitoring_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type RuleHandler struct {
pluginName string

fetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
onCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// RuleMap is a concurrent map for access monitoring rules.
Expand All @@ -65,6 +66,8 @@ type RuleHandlerConfig struct {

// FetchRecipientCallback is a callback that maps recipient strings to plugin Recipients.
FetchRecipientCallback func(ctx context.Context, recipient string) (*common.Recipient, error)
// OnCacheUpdateCallback is a callback that is called when a rule in the cache is created or updated.
OnCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

// NewRuleHandler returns a new RuleHandler.
Expand All @@ -77,6 +80,7 @@ func NewRuleHandler(conf RuleHandlerConfig) *RuleHandler {
pluginType: conf.PluginType,
pluginName: conf.PluginName,
fetchRecipientCallback: conf.FetchRecipientCallback,
onCacheUpdateCallback: conf.OnCacheUpdateCallback,
}
}

Expand All @@ -93,6 +97,9 @@ func (amrh *RuleHandler) InitAccessMonitoringRulesCache(ctx context.Context) err
continue
}
amrh.accessMonitoringRules.rules[amr.GetMetadata().Name] = amr
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, amr.GetMetadata().Name, amr)
}
}
return nil
}
Expand Down Expand Up @@ -123,6 +130,9 @@ func (amrh *RuleHandler) HandleAccessMonitoringRule(ctx context.Context, event t
return nil
}
amrh.accessMonitoringRules.rules[req.Metadata.Name] = req
if amrh.onCacheUpdateCallback != nil {
amrh.onCacheUpdateCallback(types.OpPut, req.GetMetadata().Name, req)
}
return nil
case types.OpDelete:
delete(amrh.accessMonitoringRules.rules, event.Resource.GetName())
Expand Down
9 changes: 6 additions & 3 deletions integrations/access/pagerduty/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ func NewApp(conf Config) (*App, error) {
teleport: conf.Client,
statusSink: conf.StatusSink,
}

app.mainJob = lib.NewServiceJob(app.run)

return app, nil
Expand Down Expand Up @@ -173,7 +172,7 @@ func (a *App) init(ctx context.Context) error {
}
}

a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{
amrhConf := accessmonitoring.RuleHandlerConfig{
Client: a.teleport,
PluginType: types.PluginTypePagerDuty,
PluginName: pluginName,
Expand All @@ -184,7 +183,11 @@ func (a *App) init(ctx context.Context) error {
Kind: common.RecipientKindSchedule,
}, nil
},
})
}
if a.conf.OnAccessMonitoringRuleCacheUpdateCallback != nil {
amrhConf.OnCacheUpdateCallback = a.conf.OnAccessMonitoringRuleCacheUpdateCallback
}
a.accessMonitoringRules = accessmonitoring.NewRuleHandler(amrhConf)

if pong, err = a.checkTeleportVersion(ctx); err != nil {
return trace.Wrap(err)
Expand Down
6 changes: 6 additions & 0 deletions integrations/access/pagerduty/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/gravitational/trace"
"github.com/pelletier/go-toml"

accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integrations/access/common"
"github.com/gravitational/teleport/integrations/access/common/teleport"
"github.com/gravitational/teleport/integrations/lib"
Expand All @@ -47,6 +49,10 @@ type Config struct {
// TeleportUser is the name of the Teleport user that will act
// as the access request approver
TeleportUser string

// OnAccessMonitoringRuleCacheUpdateCallback is used for checking when
// the Rule cache is updated in tests
OnAccessMonitoringRuleCacheUpdateCallback func(Operation types.OpType, name string, rule *accessmonitoringrulesv1.AccessMonitoringRule) error
}

type PagerdutyConfig struct {
Expand Down
27 changes: 22 additions & 5 deletions integrations/access/pagerduty/testlib/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,15 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
t.Cleanup(cancel)

const ruleName = "test-pagerduty-amr"
var collectedNames []string
var mu sync.Mutex
s.appConfig.OnAccessMonitoringRuleCacheUpdateCallback = func(_ types.OpType, name string, _ *accessmonitoringrulesv1.AccessMonitoringRule) error {
mu.Lock()
collectedNames = append(collectedNames, name)
mu.Unlock()
return nil
}
s.startApp()

_, err := s.ClientByName(integration.RulerUserName).
Expand All @@ -438,7 +447,7 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
Kind: types.KindAccessMonitoringRule,
Version: types.V1,
Metadata: &v1.Metadata{
Name: "test-pagerduty-amr",
Name: ruleName,
},
Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{
Subjects: []string{types.KindAccessRequest},
Expand All @@ -453,6 +462,14 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
})
assert.NoError(t, err)

// Incident creation may happen before plugins Access Monitoring Rule cache
// has been updated with new rule. Retry until the new cache picks up the rule.
require.EventuallyWithT(t, func(t *assert.CollectT) {
mu.Lock()
require.Contains(t, collectedNames, ruleName)
mu.Unlock()
}, 3*time.Second, time.Millisecond*100, "new access monitoring rule did not begin applying")

// Test execution: create an access request
req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil)

Expand All @@ -463,16 +480,16 @@ func (s *PagerdutySuiteOSS) TestRecipientsFromAccessMonitoringRule() {
})

incident, err := s.fakePagerduty.CheckNewIncident(ctx)
require.NoError(t, err, "no new incidents stored")

assert.NoError(t, err, "no new incidents stored")
assert.Equal(t, incident.ID, pluginData.IncidentID)
assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID)

assert.Equal(t, pagerduty.PdIncidentKeyPrefix+"/"+req.GetName(), incident.IncidentKey)
assert.Equal(t, "triggered", incident.Status)

assert.Equal(t, s.pdNotifyService2.ID, pluginData.ServiceID)

assert.NoError(t, s.ClientByName(integration.RulerUserName).
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-pagerduty-amr"))
AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, ruleName))
}

func (s *PagerdutyBaseSuite) assertNewEvent(ctx context.Context, watcher types.Watcher, opType types.OpType, resourceKind, resourceName string) types.Event {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -990,10 +990,10 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err
installerScript: req.InstallerScriptName(),
},
&usertasksv1.DiscoverEC2Instance{
// TODO(marco): add instance name
DiscoveryConfig: instances.DiscoveryConfig,
DiscoveryGroup: s.DiscoveryGroup,
InstanceId: instance.InstanceID,
Name: instance.InstanceName,
SyncTime: timestamppb.New(s.clock.Now()),
},
)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,12 @@ func (s *Server) ReportEC2SSMInstallationResult(ctx context.Context, result *ser
installerScript: result.InstallerScript,
},
&usertasksv1.DiscoverEC2Instance{
// TODO(marco): add instance name
InvocationUrl: result.SSMRunEvent.InvocationURL,
DiscoveryConfig: result.DiscoveryConfig,
DiscoveryGroup: s.DiscoveryGroup,
SyncTime: timestamppb.New(result.SSMRunEvent.Time),
InstanceId: result.SSMRunEvent.InstanceID,
Name: result.InstanceName,
},
)

Expand Down
4 changes: 4 additions & 0 deletions lib/srv/server/ec2_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type EC2Instances struct {
// discovered.
type EC2Instance struct {
InstanceID string
InstanceName string
Tags map[string]string
OriginalInstance ec2.Instance
}
Expand All @@ -92,6 +93,9 @@ func toEC2Instance(originalInst *ec2.Instance) EC2Instance {
for _, tag := range originalInst.Tags {
if key := aws.StringValue(tag.Key); key != "" {
inst.Tags[key] = aws.StringValue(tag.Value)
if key == "Name" {
inst.InstanceName = aws.StringValue(tag.Value)
}
}
}
return inst
Expand Down
82 changes: 78 additions & 4 deletions lib/srv/server/ec2_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,16 @@ func TestEC2Watcher(t *testing.T) {

present := ec2.Instance{
InstanceId: aws.String("instance-present"),
Tags: []*ec2.Tag{{
Key: aws.String("teleport"),
Value: aws.String("yes"),
}},
Tags: []*ec2.Tag{
{
Key: aws.String("teleport"),
Value: aws.String("yes"),
},
{
Key: aws.String("Name"),
Value: aws.String("Present"),
},
},
State: &ec2.InstanceState{
Name: aws.String(ec2.InstanceStateNameRunning),
},
Expand Down Expand Up @@ -360,3 +366,71 @@ func TestMakeEvents(t *testing.T) {
})
}
}

func TestToEC2Instances(t *testing.T) {
sampleInstance := &ec2.Instance{
InstanceId: aws.String("instance-001"),
Tags: []*ec2.Tag{
{
Key: aws.String("teleport"),
Value: aws.String("yes"),
},
{
Key: aws.String("Name"),
Value: aws.String("MyInstanceName"),
},
},
State: &ec2.InstanceState{
Name: aws.String(ec2.InstanceStateNameRunning),
},
}

sampleInstanceWithoutName := &ec2.Instance{
InstanceId: aws.String("instance-001"),
Tags: []*ec2.Tag{
{
Key: aws.String("teleport"),
Value: aws.String("yes"),
},
},
State: &ec2.InstanceState{
Name: aws.String(ec2.InstanceStateNameRunning),
},
}

for _, tt := range []struct {
name string
input []*ec2.Instance
expected []EC2Instance
}{
{
name: "with name",
input: []*ec2.Instance{sampleInstance},
expected: []EC2Instance{{
InstanceID: "instance-001",
Tags: map[string]string{
"Name": "MyInstanceName",
"teleport": "yes",
},
InstanceName: "MyInstanceName",
OriginalInstance: *sampleInstance,
}},
},
{
name: "without name",
input: []*ec2.Instance{sampleInstanceWithoutName},
expected: []EC2Instance{{
InstanceID: "instance-001",
Tags: map[string]string{
"teleport": "yes",
},
OriginalInstance: *sampleInstanceWithoutName,
}},
},
} {
t.Run(tt.name, func(t *testing.T) {
got := ToEC2Instances(tt.input)
require.Equal(t, tt.expected, got)
})
}
}
Loading

0 comments on commit 6784da3

Please sign in to comment.