Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kube-agent-updater: add RFD-184 trigger and version getter #49297

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/automaticupgrades/maintenance/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
type StaticTrigger struct {
name string
canStart bool
err error
}

// Name returns the StaticTrigger name.
Expand All @@ -38,7 +39,7 @@ func (m StaticTrigger) Name() string {

// CanStart returns the statically defined maintenance approval result.
func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) {
return m.canStart, nil
return m.canStart, m.err
}

// Default returns the default behavior if the trigger fails. This cannot
Expand Down
85 changes: 85 additions & 0 deletions lib/automaticupgrades/maintenance/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package maintenance

import (
"context"

"github.com/gravitational/trace"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/lib/automaticupgrades/cache"
"github.com/gravitational/teleport/lib/automaticupgrades/constants"
)

type proxyMaintenanceClient struct {
client *webclient.ReusableClient
}

// Get does the HTTPS call to the Teleport Proxy sevrice to check if the update should happen now.
// If the proxy response does not contain the auto_update.agent_version field,
// this means the proxy does not support autoupdates. In this case we return trace.NotImplementedErr.
func (b *proxyMaintenanceClient) Get(ctx context.Context) (bool, error) {
hugoShaka marked this conversation as resolved.
Show resolved Hide resolved
resp, err := b.client.Find()
if err != nil {
return false, trace.Wrap(err)
}
// We check if a version is advertised to know if the proxy implements RFD-184 or not.
if resp.AutoUpdate.AgentVersion == "" {
return false, trace.NotImplemented("proxy does not seem to implement RFD-184")
}
return resp.AutoUpdate.AgentAutoUpdate, nil
}

// ProxyMaintenanceTrigger checks if the maintenance should be triggered from the Teleport Proxy service /find endpoint,
// as specified in the RFD-184: https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md
// The Trigger returns trace.NotImplementedErr when running against a proxy that does not seem to
// expose automatic update instructions over the /find endpoint (proxy too old).
type ProxyMaintenanceTrigger struct {
name string
cachedGetter func(context.Context) (bool, error)
}

// Name implements maintenance.Trigger returns the trigger name for logging
// and debugging purposes.
func (g ProxyMaintenanceTrigger) Name() string {
return g.name
}

// Default implements maintenance.Trigger and returns what to do if the trigger can't be evaluated.
// ProxyMaintenanceTrigger should fail open, so the function returns true.
func (g ProxyMaintenanceTrigger) Default() bool {
return false
}

// CanStart implements maintenance.Trigger.
func (g ProxyMaintenanceTrigger) CanStart(ctx context.Context, _ client.Object) (bool, error) {
result, err := g.cachedGetter(ctx)
return result, trace.Wrap(err)
}

// NewProxyMaintenanceTrigger builds and return a Trigger checking a public HTTP endpoint.
func NewProxyMaintenanceTrigger(name string, clt *webclient.ReusableClient) Trigger {
maintenanceClient := &proxyMaintenanceClient{
client: clt,
}

return ProxyMaintenanceTrigger{name, cache.NewTimedMemoize[bool](maintenanceClient.Get, constants.CacheDuration).Get}
}
52 changes: 51 additions & 1 deletion lib/automaticupgrades/maintenance/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package maintenance

import (
"context"
"strings"

"github.com/gravitational/trace"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
)
Expand Down Expand Up @@ -51,7 +53,10 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
start, err := trigger.CanStart(ctx, object)
if err != nil {
start = trigger.Default()
log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start)
log.Error(
err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue",
start,
)
} else {
log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start)
}
Expand All @@ -62,3 +67,48 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
}
return false
}

// FailoverTrigger wraps multiple Triggers and tries them sequentially.
// Any error is considered fatal, except for the trace.NotImplementedErr
// which indicates the trigger is not supported yet and we should
// failover to the next trigger.
type FailoverTrigger []Trigger

// Name implements Trigger
func (f FailoverTrigger) Name() string {
names := make([]string, len(f))
for i, t := range f {
names[i] = t.Name()
}

return strings.Join(names, ", failover ")
}

// CanStart implements Trigger
// Triggers are evaluated sequentially, the result of the first trigger not returning
// trace.NotImplementedErr is used.
func (f FailoverTrigger) CanStart(ctx context.Context, object client.Object) (bool, error) {
for _, trigger := range f {
canStart, err := trigger.CanStart(ctx, object)
switch {
case err == nil:
return canStart, nil
case trace.IsNotImplemented(err):
continue
default:
return false, trace.Wrap(err)
}
}
return false, trace.NotFound("every trigger returned NotImplemented")
}

// Default implements Trigger.
// The default is the logical OR of every Trigger.Default.
func (f FailoverTrigger) Default() bool {
for _, trigger := range f {
if trigger.Default() {
return true
}
}
return false
}
169 changes: 169 additions & 0 deletions lib/automaticupgrades/maintenance/trigger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package maintenance

import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc
func checkTraceError(check func(error) bool) require.ErrorAssertionFunc {
return func(t require.TestingT, err error, i ...interface{}) {
require.True(t, check(err), i...)
}
}

func TestFailoverTrigger_CanStart(t *testing.T) {
t.Parallel()

// Test setup
ctx := context.Background()
tests := []struct {
name string
triggers []Trigger
expectResult bool
expectErr require.ErrorAssertionFunc
}{
{
name: "nil",
triggers: nil,
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "empty",
triggers: []Trigger{},
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "first trigger success firing",
triggers: []Trigger{
StaticTrigger{canStart: true},
StaticTrigger{canStart: false},
},
expectResult: true,
expectErr: require.NoError,
},
{
name: "first trigger success not firing",
triggers: []Trigger{
StaticTrigger{canStart: false},
StaticTrigger{canStart: true},
},
expectResult: false,
expectErr: require.NoError,
},
{
name: "first trigger failure",
triggers: []Trigger{
StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
StaticTrigger{canStart: true},
},
expectResult: false,
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first trigger skipped, second getter success",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{canStart: true},
},
expectResult: true,
expectErr: require.NoError,
},
{
name: "first trigger skipped, second getter failure",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
},
expectResult: false,
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first trigger skipped, second getter skipped",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
},
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
// Test execution
trigger := FailoverTrigger(tt.triggers)
result, err := trigger.CanStart(ctx, nil)
require.Equal(t, tt.expectResult, result)
tt.expectErr(t, err)
},
)
}
}

func TestFailoverTrigger_Name(t *testing.T) {
tests := []struct {
name string
triggers []Trigger
expectResult string
}{
{
name: "nil",
triggers: nil,
expectResult: "",
},
{
name: "empty",
triggers: []Trigger{},
expectResult: "",
},
{
name: "one trigger",
triggers: []Trigger{
StaticTrigger{name: "proxy"},
},
expectResult: "proxy",
},
{
name: "two triggers",
triggers: []Trigger{
StaticTrigger{name: "proxy"},
StaticTrigger{name: "version-server"},
},
expectResult: "proxy, failover version-server",
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
// Test execution
trigger := FailoverTrigger(tt.triggers)
result := trigger.Name()
require.Equal(t, tt.expectResult, result)
},
)
}
}
Loading
Loading