diff --git a/lib/automaticupgrades/maintenance/mock.go b/lib/automaticupgrades/maintenance/mock.go
index f46b990ee7930..f705bcee71f8b 100644
--- a/lib/automaticupgrades/maintenance/mock.go
+++ b/lib/automaticupgrades/maintenance/mock.go
@@ -29,6 +29,7 @@ import (
type StaticTrigger struct {
name string
canStart bool
+ err error
}
// Name returns the StaticTrigger name.
@@ -38,7 +39,7 @@ func (m StaticTrigger) Name() string {
// CanStart returns the statically defined maintenance approval result.
func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) {
- return m.canStart, nil
+ return m.canStart, m.err
}
// Default returns the default behavior if the trigger fails. This cannot
diff --git a/lib/automaticupgrades/maintenance/proxy.go b/lib/automaticupgrades/maintenance/proxy.go
new file mode 100644
index 0000000000000..ceb2495e5c17a
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/proxy.go
@@ -0,0 +1,85 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package maintenance
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+
+ "github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/lib/automaticupgrades/cache"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+type proxyMaintenanceClient struct {
+ client *webclient.ReusableClient
+}
+
+// Get does the HTTPS call to the Teleport Proxy sevrice to check if the update should happen now.
+// If the proxy response does not contain the auto_update.agent_version field,
+// this means the proxy does not support autoupdates. In this case we return trace.NotImplementedErr.
+func (b *proxyMaintenanceClient) Get(ctx context.Context) (bool, error) {
+ resp, err := b.client.Find()
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ // We check if a version is advertised to know if the proxy implements RFD-184 or not.
+ if resp.AutoUpdate.AgentVersion == "" {
+ return false, trace.NotImplemented("proxy does not seem to implement RFD-184")
+ }
+ return resp.AutoUpdate.AgentAutoUpdate, nil
+}
+
+// ProxyMaintenanceTrigger checks if the maintenance should be triggered from the Teleport Proxy service /find endpoint,
+// as specified in the RFD-184: https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md
+// The Trigger returns trace.NotImplementedErr when running against a proxy that does not seem to
+// expose automatic update instructions over the /find endpoint (proxy too old).
+type ProxyMaintenanceTrigger struct {
+ name string
+ cachedGetter func(context.Context) (bool, error)
+}
+
+// Name implements maintenance.Trigger returns the trigger name for logging
+// and debugging purposes.
+func (g ProxyMaintenanceTrigger) Name() string {
+ return g.name
+}
+
+// Default implements maintenance.Trigger and returns what to do if the trigger can't be evaluated.
+// ProxyMaintenanceTrigger should fail open, so the function returns true.
+func (g ProxyMaintenanceTrigger) Default() bool {
+ return false
+}
+
+// CanStart implements maintenance.Trigger.
+func (g ProxyMaintenanceTrigger) CanStart(ctx context.Context, _ client.Object) (bool, error) {
+ result, err := g.cachedGetter(ctx)
+ return result, trace.Wrap(err)
+}
+
+// NewProxyMaintenanceTrigger builds and return a Trigger checking a public HTTP endpoint.
+func NewProxyMaintenanceTrigger(name string, clt *webclient.ReusableClient) Trigger {
+ maintenanceClient := &proxyMaintenanceClient{
+ client: clt,
+ }
+
+ return ProxyMaintenanceTrigger{name, cache.NewTimedMemoize[bool](maintenanceClient.Get, constants.CacheDuration).Get}
+}
diff --git a/lib/automaticupgrades/maintenance/trigger.go b/lib/automaticupgrades/maintenance/trigger.go
index 53e12b26cdd4a..5d9449d7ad864 100644
--- a/lib/automaticupgrades/maintenance/trigger.go
+++ b/lib/automaticupgrades/maintenance/trigger.go
@@ -20,7 +20,9 @@ package maintenance
import (
"context"
+ "strings"
+ "github.com/gravitational/trace"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
)
@@ -51,7 +53,10 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
start, err := trigger.CanStart(ctx, object)
if err != nil {
start = trigger.Default()
- log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start)
+ log.Error(
+ err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue",
+ start,
+ )
} else {
log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start)
}
@@ -62,3 +67,48 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
}
return false
}
+
+// FailoverTrigger wraps multiple Triggers and tries them sequentially.
+// Any error is considered fatal, except for the trace.NotImplementedErr
+// which indicates the trigger is not supported yet and we should
+// failover to the next trigger.
+type FailoverTrigger []Trigger
+
+// Name implements Trigger
+func (f FailoverTrigger) Name() string {
+ names := make([]string, len(f))
+ for i, t := range f {
+ names[i] = t.Name()
+ }
+
+ return strings.Join(names, ", failover ")
+}
+
+// CanStart implements Trigger
+// Triggers are evaluated sequentially, the result of the first trigger not returning
+// trace.NotImplementedErr is used.
+func (f FailoverTrigger) CanStart(ctx context.Context, object client.Object) (bool, error) {
+ for _, trigger := range f {
+ canStart, err := trigger.CanStart(ctx, object)
+ switch {
+ case err == nil:
+ return canStart, nil
+ case trace.IsNotImplemented(err):
+ continue
+ default:
+ return false, trace.Wrap(err)
+ }
+ }
+ return false, trace.NotFound("every trigger returned NotImplemented")
+}
+
+// Default implements Trigger.
+// The default is the logical OR of every Trigger.Default.
+func (f FailoverTrigger) Default() bool {
+ for _, trigger := range f {
+ if trigger.Default() {
+ return true
+ }
+ }
+ return false
+}
diff --git a/lib/automaticupgrades/maintenance/trigger_test.go b/lib/automaticupgrades/maintenance/trigger_test.go
new file mode 100644
index 0000000000000..435b73f0f9bc4
--- /dev/null
+++ b/lib/automaticupgrades/maintenance/trigger_test.go
@@ -0,0 +1,169 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package maintenance
+
+import (
+ "context"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+)
+
+// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc
+func checkTraceError(check func(error) bool) require.ErrorAssertionFunc {
+ return func(t require.TestingT, err error, i ...interface{}) {
+ require.True(t, check(err), i...)
+ }
+}
+
+func TestFailoverTrigger_CanStart(t *testing.T) {
+ t.Parallel()
+
+ // Test setup
+ ctx := context.Background()
+ tests := []struct {
+ name string
+ triggers []Trigger
+ expectResult bool
+ expectErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "nil",
+ triggers: nil,
+ expectResult: false,
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ {
+ name: "empty",
+ triggers: []Trigger{},
+ expectResult: false,
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ {
+ name: "first trigger success firing",
+ triggers: []Trigger{
+ StaticTrigger{canStart: true},
+ StaticTrigger{canStart: false},
+ },
+ expectResult: true,
+ expectErr: require.NoError,
+ },
+ {
+ name: "first trigger success not firing",
+ triggers: []Trigger{
+ StaticTrigger{canStart: false},
+ StaticTrigger{canStart: true},
+ },
+ expectResult: false,
+ expectErr: require.NoError,
+ },
+ {
+ name: "first trigger failure",
+ triggers: []Trigger{
+ StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
+ StaticTrigger{canStart: true},
+ },
+ expectResult: false,
+ expectErr: checkTraceError(trace.IsLimitExceeded),
+ },
+ {
+ name: "first trigger skipped, second getter success",
+ triggers: []Trigger{
+ StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticTrigger{canStart: true},
+ },
+ expectResult: true,
+ expectErr: require.NoError,
+ },
+ {
+ name: "first trigger skipped, second getter failure",
+ triggers: []Trigger{
+ StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
+ },
+ expectResult: false,
+ expectErr: checkTraceError(trace.IsLimitExceeded),
+ },
+ {
+ name: "first trigger skipped, second getter skipped",
+ triggers: []Trigger{
+ StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ },
+ expectResult: false,
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(
+ tt.name, func(t *testing.T) {
+ // Test execution
+ trigger := FailoverTrigger(tt.triggers)
+ result, err := trigger.CanStart(ctx, nil)
+ require.Equal(t, tt.expectResult, result)
+ tt.expectErr(t, err)
+ },
+ )
+ }
+}
+
+func TestFailoverTrigger_Name(t *testing.T) {
+ tests := []struct {
+ name string
+ triggers []Trigger
+ expectResult string
+ }{
+ {
+ name: "nil",
+ triggers: nil,
+ expectResult: "",
+ },
+ {
+ name: "empty",
+ triggers: []Trigger{},
+ expectResult: "",
+ },
+ {
+ name: "one trigger",
+ triggers: []Trigger{
+ StaticTrigger{name: "proxy"},
+ },
+ expectResult: "proxy",
+ },
+ {
+ name: "two triggers",
+ triggers: []Trigger{
+ StaticTrigger{name: "proxy"},
+ StaticTrigger{name: "version-server"},
+ },
+ expectResult: "proxy, failover version-server",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(
+ tt.name, func(t *testing.T) {
+ // Test execution
+ trigger := FailoverTrigger(tt.triggers)
+ result := trigger.Name()
+ require.Equal(t, tt.expectResult, result)
+ },
+ )
+ }
+}
diff --git a/lib/automaticupgrades/version/proxy.go b/lib/automaticupgrades/version/proxy.go
new file mode 100644
index 0000000000000..db55123dd529e
--- /dev/null
+++ b/lib/automaticupgrades/version/proxy.go
@@ -0,0 +1,75 @@
+/*
+ * Teleport
+ * Copyright (C) 2023 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package version
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/lib/automaticupgrades/cache"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+)
+
+type proxyVersionClient struct {
+ client *webclient.ReusableClient
+}
+
+func (b *proxyVersionClient) Get(ctx context.Context) (string, error) {
+ resp, err := b.client.Find()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ // We check if a version is advertised to know if the proxy implements RFD-184 or not.
+ if resp.AutoUpdate.AgentVersion == "" {
+ return "", trace.NotImplemented("proxy does not seem to implement RFD-184")
+ }
+ return resp.AutoUpdate.AgentVersion, nil
+}
+
+// ProxyVersionGetter gets the target version from the Teleport Proxy Service /find endpoint, as
+// specified in the RFD-184: https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md
+// The Getter returns trace.NotImplementedErr when running against a proxy that does not seem to
+// expose automatic update instructions over the /find endpoint (proxy too old).
+type ProxyVersionGetter struct {
+ name string
+ cachedGetter func(context.Context) (string, error)
+}
+
+// Name implements Getter
+func (g ProxyVersionGetter) Name() string {
+ return g.name
+}
+
+// GetVersion implements Getter
+func (g ProxyVersionGetter) GetVersion(ctx context.Context) (string, error) {
+ result, err := g.cachedGetter(ctx)
+ return result, trace.Wrap(err)
+}
+
+// NewProxyVersionGetter creates a ProxyVersionGetter from a webclient.
+// The answer is cached for a minute.
+func NewProxyVersionGetter(name string, clt *webclient.ReusableClient) Getter {
+ versionClient := &proxyVersionClient{
+ client: clt,
+ }
+
+ return ProxyVersionGetter{name, cache.NewTimedMemoize[string](versionClient.Get, constants.CacheDuration).Get}
+}
diff --git a/lib/automaticupgrades/version/versionget.go b/lib/automaticupgrades/version/versionget.go
index f1e7723a9a320..e2a1a893e5270 100644
--- a/lib/automaticupgrades/version/versionget.go
+++ b/lib/automaticupgrades/version/versionget.go
@@ -1,6 +1,6 @@
/*
* Teleport
- * Copyright (C) 2023 Gravitational, Inc.
+ * Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
@@ -36,13 +36,42 @@ type Getter interface {
GetVersion(context.Context) (string, error)
}
+// FailoverGetter wraps multiple Getters and tries them sequentially.
+// Any error is considered fatal, except for the trace.NotImplementedErr
+// which indicates the version getter is not supported yet and we should
+// failover to the next version getter.
+type FailoverGetter []Getter
+
+// GetVersion implements Getter
+// Getters are evaluated sequentially, the result of the first getter not returning
+// trace.NotImplementedErr is used.
+func (f FailoverGetter) GetVersion(ctx context.Context) (string, error) {
+ for _, getter := range f {
+ version, err := getter.GetVersion(ctx)
+ switch {
+ case err == nil:
+ return version, nil
+ case trace.IsNotImplemented(err):
+ continue
+ default:
+ return "", trace.Wrap(err)
+ }
+ }
+ return "", trace.NotFound("every versionGetter returned NotImplemented")
+}
+
// ValidVersionChange receives the current version and the candidate next version
// and evaluates if the version transition is valid.
func ValidVersionChange(ctx context.Context, current, next string) bool {
log := ctrllog.FromContext(ctx).V(1)
// Cannot upgrade to a non-valid version
if !semver.IsValid(next) {
- log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next)
+ log.Error(
+ trace.BadParameter("next version is not following semver"),
+ "version change is invalid",
+ "current_version", current,
+ "next_version", next,
+ )
return false
}
switch semver.Compare(next, current) {
diff --git a/lib/automaticupgrades/version/versionget_test.go b/lib/automaticupgrades/version/versionget_test.go
index 80c2ec767b8fb..78f4940db229a 100644
--- a/lib/automaticupgrades/version/versionget_test.go
+++ b/lib/automaticupgrades/version/versionget_test.go
@@ -22,6 +22,7 @@ import (
"context"
"testing"
+ "github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
@@ -66,8 +67,99 @@ func TestValidVersionChange(t *testing.T) {
},
}
for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next))
- })
+ t.Run(
+ tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next))
+ },
+ )
+ }
+}
+
+// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc
+func checkTraceError(check func(error) bool) require.ErrorAssertionFunc {
+ return func(t require.TestingT, err error, i ...interface{}) {
+ require.True(t, check(err), i...)
+ }
+}
+
+func TestFailoverGetter_GetVersion(t *testing.T) {
+ t.Parallel()
+
+ // Test setup
+ ctx := context.Background()
+ tests := []struct {
+ name string
+ getters []Getter
+ expectResult string
+ expectErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "nil",
+ getters: nil,
+ expectResult: "",
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ {
+ name: "empty",
+ getters: []Getter{},
+ expectResult: "",
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ {
+ name: "first getter success",
+ getters: []Getter{
+ StaticGetter{version: semverMid},
+ StaticGetter{version: semverHigh},
+ },
+ expectResult: semverMid,
+ expectErr: require.NoError,
+ },
+ {
+ name: "first getter failure",
+ getters: []Getter{
+ StaticGetter{err: trace.LimitExceeded("got rate-limited")},
+ StaticGetter{version: semverHigh},
+ },
+ expectResult: "",
+ expectErr: checkTraceError(trace.IsLimitExceeded),
+ },
+ {
+ name: "first getter skipped, second getter success",
+ getters: []Getter{
+ StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticGetter{version: semverHigh},
+ },
+ expectResult: semverHigh,
+ expectErr: require.NoError,
+ },
+ {
+ name: "first getter skipped, second getter failure",
+ getters: []Getter{
+ StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticGetter{err: trace.LimitExceeded("got rate-limited")},
+ },
+ expectResult: "",
+ expectErr: checkTraceError(trace.IsLimitExceeded),
+ },
+ {
+ name: "first getter skipped, second getter skipped",
+ getters: []Getter{
+ StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
+ },
+ expectResult: "",
+ expectErr: checkTraceError(trace.IsNotFound),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(
+ tt.name, func(t *testing.T) {
+ // Test execution
+ getter := FailoverGetter(tt.getters)
+ result, err := getter.GetVersion(ctx)
+ require.Equal(t, tt.expectResult, result)
+ tt.expectErr(t, err)
+ },
+ )
}
}