diff --git a/lib/automaticupgrades/maintenance/mock.go b/lib/automaticupgrades/maintenance/mock.go index f46b990ee7930..f705bcee71f8b 100644 --- a/lib/automaticupgrades/maintenance/mock.go +++ b/lib/automaticupgrades/maintenance/mock.go @@ -29,6 +29,7 @@ import ( type StaticTrigger struct { name string canStart bool + err error } // Name returns the StaticTrigger name. @@ -38,7 +39,7 @@ func (m StaticTrigger) Name() string { // CanStart returns the statically defined maintenance approval result. func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) { - return m.canStart, nil + return m.canStart, m.err } // Default returns the default behavior if the trigger fails. This cannot diff --git a/lib/automaticupgrades/maintenance/proxy.go b/lib/automaticupgrades/maintenance/proxy.go new file mode 100644 index 0000000000000..ceb2495e5c17a --- /dev/null +++ b/lib/automaticupgrades/maintenance/proxy.go @@ -0,0 +1,85 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package maintenance + +import ( + "context" + + "github.com/gravitational/trace" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/automaticupgrades/cache" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +type proxyMaintenanceClient struct { + client *webclient.ReusableClient +} + +// Get does the HTTPS call to the Teleport Proxy sevrice to check if the update should happen now. +// If the proxy response does not contain the auto_update.agent_version field, +// this means the proxy does not support autoupdates. In this case we return trace.NotImplementedErr. +func (b *proxyMaintenanceClient) Get(ctx context.Context) (bool, error) { + resp, err := b.client.Find() + if err != nil { + return false, trace.Wrap(err) + } + // We check if a version is advertised to know if the proxy implements RFD-184 or not. + if resp.AutoUpdate.AgentVersion == "" { + return false, trace.NotImplemented("proxy does not seem to implement RFD-184") + } + return resp.AutoUpdate.AgentAutoUpdate, nil +} + +// ProxyMaintenanceTrigger checks if the maintenance should be triggered from the Teleport Proxy service /find endpoint, +// as specified in the RFD-184: https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md +// The Trigger returns trace.NotImplementedErr when running against a proxy that does not seem to +// expose automatic update instructions over the /find endpoint (proxy too old). +type ProxyMaintenanceTrigger struct { + name string + cachedGetter func(context.Context) (bool, error) +} + +// Name implements maintenance.Trigger returns the trigger name for logging +// and debugging purposes. +func (g ProxyMaintenanceTrigger) Name() string { + return g.name +} + +// Default implements maintenance.Trigger and returns what to do if the trigger can't be evaluated. +// ProxyMaintenanceTrigger should fail open, so the function returns true. +func (g ProxyMaintenanceTrigger) Default() bool { + return false +} + +// CanStart implements maintenance.Trigger. +func (g ProxyMaintenanceTrigger) CanStart(ctx context.Context, _ client.Object) (bool, error) { + result, err := g.cachedGetter(ctx) + return result, trace.Wrap(err) +} + +// NewProxyMaintenanceTrigger builds and return a Trigger checking a public HTTP endpoint. +func NewProxyMaintenanceTrigger(name string, clt *webclient.ReusableClient) Trigger { + maintenanceClient := &proxyMaintenanceClient{ + client: clt, + } + + return ProxyMaintenanceTrigger{name, cache.NewTimedMemoize[bool](maintenanceClient.Get, constants.CacheDuration).Get} +} diff --git a/lib/automaticupgrades/maintenance/trigger.go b/lib/automaticupgrades/maintenance/trigger.go index 53e12b26cdd4a..5d9449d7ad864 100644 --- a/lib/automaticupgrades/maintenance/trigger.go +++ b/lib/automaticupgrades/maintenance/trigger.go @@ -20,7 +20,9 @@ package maintenance import ( "context" + "strings" + "github.com/gravitational/trace" "sigs.k8s.io/controller-runtime/pkg/client" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -51,7 +53,10 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool { start, err := trigger.CanStart(ctx, object) if err != nil { start = trigger.Default() - log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start) + log.Error( + err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", + start, + ) } else { log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start) } @@ -62,3 +67,48 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool { } return false } + +// FailoverTrigger wraps multiple Triggers and tries them sequentially. +// Any error is considered fatal, except for the trace.NotImplementedErr +// which indicates the trigger is not supported yet and we should +// failover to the next trigger. +type FailoverTrigger []Trigger + +// Name implements Trigger +func (f FailoverTrigger) Name() string { + names := make([]string, len(f)) + for i, t := range f { + names[i] = t.Name() + } + + return strings.Join(names, ", failover ") +} + +// CanStart implements Trigger +// Triggers are evaluated sequentially, the result of the first trigger not returning +// trace.NotImplementedErr is used. +func (f FailoverTrigger) CanStart(ctx context.Context, object client.Object) (bool, error) { + for _, trigger := range f { + canStart, err := trigger.CanStart(ctx, object) + switch { + case err == nil: + return canStart, nil + case trace.IsNotImplemented(err): + continue + default: + return false, trace.Wrap(err) + } + } + return false, trace.NotFound("every trigger returned NotImplemented") +} + +// Default implements Trigger. +// The default is the logical OR of every Trigger.Default. +func (f FailoverTrigger) Default() bool { + for _, trigger := range f { + if trigger.Default() { + return true + } + } + return false +} diff --git a/lib/automaticupgrades/maintenance/trigger_test.go b/lib/automaticupgrades/maintenance/trigger_test.go new file mode 100644 index 0000000000000..435b73f0f9bc4 --- /dev/null +++ b/lib/automaticupgrades/maintenance/trigger_test.go @@ -0,0 +1,169 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package maintenance + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc +func checkTraceError(check func(error) bool) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, i ...interface{}) { + require.True(t, check(err), i...) + } +} + +func TestFailoverTrigger_CanStart(t *testing.T) { + t.Parallel() + + // Test setup + ctx := context.Background() + tests := []struct { + name string + triggers []Trigger + expectResult bool + expectErr require.ErrorAssertionFunc + }{ + { + name: "nil", + triggers: nil, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "empty", + triggers: []Trigger{}, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "first trigger success firing", + triggers: []Trigger{ + StaticTrigger{canStart: true}, + StaticTrigger{canStart: false}, + }, + expectResult: true, + expectErr: require.NoError, + }, + { + name: "first trigger success not firing", + triggers: []Trigger{ + StaticTrigger{canStart: false}, + StaticTrigger{canStart: true}, + }, + expectResult: false, + expectErr: require.NoError, + }, + { + name: "first trigger failure", + triggers: []Trigger{ + StaticTrigger{err: trace.LimitExceeded("got rate-limited")}, + StaticTrigger{canStart: true}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first trigger skipped, second getter success", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{canStart: true}, + }, + expectResult: true, + expectErr: require.NoError, + }, + { + name: "first trigger skipped, second getter failure", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{err: trace.LimitExceeded("got rate-limited")}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first trigger skipped, second getter skipped", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + trigger := FailoverTrigger(tt.triggers) + result, err := trigger.CanStart(ctx, nil) + require.Equal(t, tt.expectResult, result) + tt.expectErr(t, err) + }, + ) + } +} + +func TestFailoverTrigger_Name(t *testing.T) { + tests := []struct { + name string + triggers []Trigger + expectResult string + }{ + { + name: "nil", + triggers: nil, + expectResult: "", + }, + { + name: "empty", + triggers: []Trigger{}, + expectResult: "", + }, + { + name: "one trigger", + triggers: []Trigger{ + StaticTrigger{name: "proxy"}, + }, + expectResult: "proxy", + }, + { + name: "two triggers", + triggers: []Trigger{ + StaticTrigger{name: "proxy"}, + StaticTrigger{name: "version-server"}, + }, + expectResult: "proxy, failover version-server", + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + trigger := FailoverTrigger(tt.triggers) + result := trigger.Name() + require.Equal(t, tt.expectResult, result) + }, + ) + } +} diff --git a/lib/automaticupgrades/version/proxy.go b/lib/automaticupgrades/version/proxy.go new file mode 100644 index 0000000000000..db55123dd529e --- /dev/null +++ b/lib/automaticupgrades/version/proxy.go @@ -0,0 +1,75 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package version + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/automaticupgrades/cache" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" +) + +type proxyVersionClient struct { + client *webclient.ReusableClient +} + +func (b *proxyVersionClient) Get(ctx context.Context) (string, error) { + resp, err := b.client.Find() + if err != nil { + return "", trace.Wrap(err) + } + // We check if a version is advertised to know if the proxy implements RFD-184 or not. + if resp.AutoUpdate.AgentVersion == "" { + return "", trace.NotImplemented("proxy does not seem to implement RFD-184") + } + return resp.AutoUpdate.AgentVersion, nil +} + +// ProxyVersionGetter gets the target version from the Teleport Proxy Service /find endpoint, as +// specified in the RFD-184: https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md +// The Getter returns trace.NotImplementedErr when running against a proxy that does not seem to +// expose automatic update instructions over the /find endpoint (proxy too old). +type ProxyVersionGetter struct { + name string + cachedGetter func(context.Context) (string, error) +} + +// Name implements Getter +func (g ProxyVersionGetter) Name() string { + return g.name +} + +// GetVersion implements Getter +func (g ProxyVersionGetter) GetVersion(ctx context.Context) (string, error) { + result, err := g.cachedGetter(ctx) + return result, trace.Wrap(err) +} + +// NewProxyVersionGetter creates a ProxyVersionGetter from a webclient. +// The answer is cached for a minute. +func NewProxyVersionGetter(name string, clt *webclient.ReusableClient) Getter { + versionClient := &proxyVersionClient{ + client: clt, + } + + return ProxyVersionGetter{name, cache.NewTimedMemoize[string](versionClient.Get, constants.CacheDuration).Get} +} diff --git a/lib/automaticupgrades/version/versionget.go b/lib/automaticupgrades/version/versionget.go index f1e7723a9a320..e2a1a893e5270 100644 --- a/lib/automaticupgrades/version/versionget.go +++ b/lib/automaticupgrades/version/versionget.go @@ -1,6 +1,6 @@ /* * Teleport - * Copyright (C) 2023 Gravitational, Inc. + * Copyright (C) 2024 Gravitational, Inc. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -36,13 +36,42 @@ type Getter interface { GetVersion(context.Context) (string, error) } +// FailoverGetter wraps multiple Getters and tries them sequentially. +// Any error is considered fatal, except for the trace.NotImplementedErr +// which indicates the version getter is not supported yet and we should +// failover to the next version getter. +type FailoverGetter []Getter + +// GetVersion implements Getter +// Getters are evaluated sequentially, the result of the first getter not returning +// trace.NotImplementedErr is used. +func (f FailoverGetter) GetVersion(ctx context.Context) (string, error) { + for _, getter := range f { + version, err := getter.GetVersion(ctx) + switch { + case err == nil: + return version, nil + case trace.IsNotImplemented(err): + continue + default: + return "", trace.Wrap(err) + } + } + return "", trace.NotFound("every versionGetter returned NotImplemented") +} + // ValidVersionChange receives the current version and the candidate next version // and evaluates if the version transition is valid. func ValidVersionChange(ctx context.Context, current, next string) bool { log := ctrllog.FromContext(ctx).V(1) // Cannot upgrade to a non-valid version if !semver.IsValid(next) { - log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next) + log.Error( + trace.BadParameter("next version is not following semver"), + "version change is invalid", + "current_version", current, + "next_version", next, + ) return false } switch semver.Compare(next, current) { diff --git a/lib/automaticupgrades/version/versionget_test.go b/lib/automaticupgrades/version/versionget_test.go index 80c2ec767b8fb..78f4940db229a 100644 --- a/lib/automaticupgrades/version/versionget_test.go +++ b/lib/automaticupgrades/version/versionget_test.go @@ -22,6 +22,7 @@ import ( "context" "testing" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) @@ -66,8 +67,99 @@ func TestValidVersionChange(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next)) - }) + t.Run( + tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next)) + }, + ) + } +} + +// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc +func checkTraceError(check func(error) bool) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, i ...interface{}) { + require.True(t, check(err), i...) + } +} + +func TestFailoverGetter_GetVersion(t *testing.T) { + t.Parallel() + + // Test setup + ctx := context.Background() + tests := []struct { + name string + getters []Getter + expectResult string + expectErr require.ErrorAssertionFunc + }{ + { + name: "nil", + getters: nil, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "empty", + getters: []Getter{}, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "first getter success", + getters: []Getter{ + StaticGetter{version: semverMid}, + StaticGetter{version: semverHigh}, + }, + expectResult: semverMid, + expectErr: require.NoError, + }, + { + name: "first getter failure", + getters: []Getter{ + StaticGetter{err: trace.LimitExceeded("got rate-limited")}, + StaticGetter{version: semverHigh}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first getter skipped, second getter success", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{version: semverHigh}, + }, + expectResult: semverHigh, + expectErr: require.NoError, + }, + { + name: "first getter skipped, second getter failure", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{err: trace.LimitExceeded("got rate-limited")}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first getter skipped, second getter skipped", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + getter := FailoverGetter(tt.getters) + result, err := getter.GetVersion(ctx) + require.Equal(t, tt.expectResult, result) + tt.expectErr(t, err) + }, + ) } }