diff --git a/lib/backend/backend.go b/lib/backend/backend.go index a47a23d4eef7b..55c386944c7e2 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -20,12 +20,10 @@ package backend import ( - "bytes" "context" "fmt" "io" "sort" - "strings" "time" "github.com/google/uuid" @@ -219,7 +217,7 @@ type Watch struct { // String returns a user-friendly description // of the watcher func (w *Watch) String() string { - return fmt.Sprintf("Watcher(name=%v, prefixes=%v)", w.Name, string(bytes.Join(w.Prefixes, []byte(", ")))) + return fmt.Sprintf("Watcher(name=%v, prefixes=%v)", w.Name, w.Prefixes) } // Watcher returns watcher @@ -380,7 +378,7 @@ func (it Items) Swap(i, j int) { // Less is part of sort.Interface. func (it Items) Less(i, j int) bool { - return bytes.Compare(it[i].Key, it[j].Key) < 0 + return it[i].Key.Compare(it[j].Key) < 0 } // TTL returns TTL in duration units, rounds up to one second @@ -431,27 +429,6 @@ func (p earliest) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -// Separator is used as a separator between key parts -const Separator = '/' - -// NewKey joins parts into path separated by Separator, -// makes sure path always starts with Separator ("/") -func NewKey(parts ...string) Key { - return internalKey("", parts...) -} - -// ExactKey is like Key, except a Separator is appended to the result -// path of Key. This is to ensure range matching of a path will only -// math child paths and not other paths that have the resulting path -// as a prefix. -func ExactKey(parts ...string) Key { - return append(NewKey(parts...), Separator) -} - -func internalKey(internalPrefix string, parts ...string) Key { - return Key(strings.Join(append([]string{internalPrefix}, parts...), string(Separator))) -} - // CreateRevision generates a new identifier to be used // as a resource revision. Backend implementations that provide // their own mechanism for versioning resources should be diff --git a/lib/backend/buffer.go b/lib/backend/buffer.go index 24211e0b2b1ea..9509a4b038138 100644 --- a/lib/backend/buffer.go +++ b/lib/backend/buffer.go @@ -322,7 +322,7 @@ type BufferWatcher struct { // String returns user-friendly representation // of the buffer watcher func (w *BufferWatcher) String() string { - return fmt.Sprintf("Watcher(name=%v, prefixes=%v, capacity=%v, size=%v)", w.Name, string(bytes.Join(w.Prefixes, []byte(", "))), w.capacity, len(w.eventsC)) + return fmt.Sprintf("Watcher(name=%v, prefixes=%v, capacity=%v, size=%v)", w.Name, w.Prefixes, w.capacity, len(w.eventsC)) } // Events returns events channel. This method performs internal work and should be re-called after each event diff --git a/lib/backend/dynamo/dynamodbbk.go b/lib/backend/dynamo/dynamodbbk.go index fb3232625f863..8a21cce32c256 100644 --- a/lib/backend/dynamo/dynamodbbk.go +++ b/lib/backend/dynamo/dynamodbbk.go @@ -19,7 +19,6 @@ package dynamo import ( - "bytes" "context" "errors" "net/http" @@ -638,7 +637,7 @@ func (b *Backend) CompareAndSwap(ctx context.Context, expected backend.Item, rep if len(replaceWith.Key) == 0 { return nil, trace.BadParameter("missing parameter Key") } - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } diff --git a/lib/backend/etcdbk/etcd.go b/lib/backend/etcdbk/etcd.go index 64902e7c89f3d..0da60f497b6f4 100644 --- a/lib/backend/etcdbk/etcd.go +++ b/lib/backend/etcdbk/etcd.go @@ -20,7 +20,6 @@ package etcdbk import ( - "bytes" "context" "crypto/tls" "crypto/x509" @@ -797,7 +796,7 @@ func (b *EtcdBackend) CompareAndSwap(ctx context.Context, expected backend.Item, if len(replaceWith.Key) == 0 { return nil, trace.BadParameter("missing parameter Key") } - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } var opts []clientv3.OpOption @@ -1108,7 +1107,7 @@ func fromType(eventType mvccpb.Event_EventType) types.OpType { } func (b *EtcdBackend) trimPrefix(in backend.Key) backend.Key { - return bytes.TrimPrefix(in, backend.Key(b.cfg.Key)) + return in.TrimPrefix(backend.Key(b.cfg.Key)) } func (b *EtcdBackend) prependPrefix(in backend.Key) string { diff --git a/lib/backend/firestore/firestorebk.go b/lib/backend/firestore/firestorebk.go index 509731211d8e2..8785fc5eee9dc 100644 --- a/lib/backend/firestore/firestorebk.go +++ b/lib/backend/firestore/firestorebk.go @@ -564,7 +564,7 @@ func (b *Backend) CompareAndSwap(ctx context.Context, expected backend.Item, rep if len(replaceWith.Key) == 0 { return nil, trace.BadParameter("missing parameter Key") } - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } diff --git a/lib/backend/key.go b/lib/backend/key.go index 2ff85090ef9fd..4c7f25c604edc 100644 --- a/lib/backend/key.go +++ b/lib/backend/key.go @@ -16,5 +16,93 @@ package backend +import ( + "bytes" + "fmt" + "strings" +) + // Key is the unique identifier for an [Item]. -type Key = []byte +type Key []byte + +// Separator is used as a separator between key parts +const Separator = '/' + +// NewKey joins parts into path separated by Separator, +// makes sure path always starts with Separator ("/") +func NewKey(parts ...string) Key { + return internalKey("", parts...) +} + +// ExactKey is like Key, except a Separator is appended to the result +// path of Key. This is to ensure range matching of a path will only +// math child paths and not other paths that have the resulting path +// as a prefix. +func ExactKey(parts ...string) Key { + return append(NewKey(parts...), Separator) +} + +func internalKey(internalPrefix string, parts ...string) Key { + return Key(strings.Join(append([]string{internalPrefix}, parts...), string(Separator))) +} + +// String returns the textual representation of the key with +// each component concatenated together via the [Separator]. +func (k Key) String() string { + return string(k) +} + +// HasPrefix reports whether the key begins with prefix. +func (k Key) HasPrefix(prefix Key) bool { + return bytes.HasPrefix(k, prefix) +} + +// TrimPrefix returns the key without the provided leading prefix string. +// If the key doesn't start with prefix, it is returned unchanged. +func (k Key) TrimPrefix(prefix Key) Key { + return bytes.TrimPrefix(k, prefix) +} + +func (k Key) PrependPrefix(p Key) Key { + return append(p, k...) +} + +// HasSuffix reports whether the key ends with suffix. +func (k Key) HasSuffix(suffix Key) bool { + return bytes.HasSuffix(k, suffix) +} + +// TrimSuffix returns the key without the provided trailing suffix string. +// If the key doesn't end with suffix, it is returned unchanged. +func (k Key) TrimSuffix(suffix Key) Key { + return bytes.TrimSuffix(k, suffix) +} + +func (k Key) Components() [][]byte { + if len(k) == 0 { + return nil + } + + sep := []byte{Separator} + return bytes.Split(bytes.TrimPrefix(k, sep), sep) +} + +func (k Key) Compare(o Key) int { + return bytes.Compare(k, o) +} + +// Scan implement sql.Scanner, allowing a [Key] to +// be directly retrieved from sql backends without +// an intermediary object. +func (k *Key) Scan(scan any) error { + switch key := scan.(type) { + case []byte: + *k = bytes.Clone(key) + case string: + *k = []byte(strings.Clone(key)) + default: + return fmt.Errorf("invalid Key type %T", scan) + } + + return nil +} diff --git a/lib/backend/key_test.go b/lib/backend/key_test.go new file mode 100644 index 0000000000000..d554fb6922357 --- /dev/null +++ b/lib/backend/key_test.go @@ -0,0 +1,475 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package backend_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/gravitational/teleport/lib/backend" +) + +func TestKey(t *testing.T) { + + k1 := backend.Key("test") + k2 := backend.NewKey("test") + k3 := backend.ExactKey("test") + + assert.NotEqual(t, k1, k2) + assert.NotEqual(t, k2, k3) + assert.NotEqual(t, k1, k3) + + assert.Equal(t, "test", k1.String()) + assert.Equal(t, "/test", k2.String()) + assert.Equal(t, "/test/", k3.String()) +} + +func TestKeyString(t *testing.T) { + tests := []struct { + name string + expected string + key backend.Key + }{ + { + name: "empty key produces empty string", + }, + { + name: "empty new key produces empty string", + key: backend.NewKey(), + expected: "", + }, + { + name: "key with only empty string produces separator", + key: backend.NewKey(""), + expected: "/", + }, + { + name: "key with contents are separated", + key: backend.NewKey("foo", "bar", "baz", "quux"), + expected: "/foo/bar/baz/quux", + }, + { + name: "empty exact key produces separator", + key: backend.ExactKey(), + expected: "/", + }, + { + name: "empty string exact key produces double separator", + key: backend.ExactKey(""), + expected: "//", + }, + { + name: "exact key adds trailing separator", + key: backend.ExactKey("foo", "bar", "baz", "quux"), + expected: "/foo/bar/baz/quux/", + }, + { + name: "noend key", + key: backend.Key{0}, + expected: "\x00", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, test.key.String()) + + }) + } +} + +func TestKeyComponents(t *testing.T) { + tests := []struct { + name string + key backend.Key + expected [][]byte + }{ + { + name: "default value has zero components", + }, + { + name: "empty key has zero components", + key: backend.NewKey(), + }, + { + name: "empty exact key has empty component", + key: backend.ExactKey(), + expected: [][]byte{{}}, + }, + { + name: "single value key has a component", + key: backend.NewKey("alpha"), + expected: [][]byte{[]byte("alpha")}, + }, + { + name: "multiple components", + key: backend.NewKey("foo", "bar", "baz"), + expected: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, + }, + { + name: "key without separator", + key: backend.Key("testing"), + expected: [][]byte{[]byte("testing")}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, test.key.Components()) + }) + } +} + +func TestKeyScan(t *testing.T) { + tests := []struct { + name string + scan any + expectedError string + expectedKey backend.Key + }{ + { + name: "invalid type int", + scan: 123, + expectedError: "invalid Key type int", + }, + { + name: "invalid type bool", + scan: false, + expectedError: "invalid Key type bool", + }, + { + name: "empty string key", + scan: "", + expectedKey: backend.Key{}, + }, + { + name: "empty byte slice key", + scan: []byte{}, + expectedKey: backend.Key{}, + }, + { + name: "populated string key", + scan: backend.NewKey("foo", "bar", "baz").String(), + expectedKey: backend.NewKey("foo", "bar", "baz"), + }, + { + name: "populated byte slice key", + scan: []byte(backend.NewKey("foo", "bar", "baz").String()), + expectedKey: backend.NewKey("foo", "bar", "baz"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + k := new(backend.Key) + err := k.Scan(test.scan) + if test.expectedError == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, test.expectedError) + } + assert.Equal(t, test.expectedKey, *k) + }) + } +} + +func TestKeyHasSuffix(t *testing.T) { + tests := []struct { + name string + key backend.Key + suffix backend.Key + assertion assert.BoolAssertionFunc + }{ + { + name: "default key has no suffixes", + suffix: backend.NewKey("test"), + assertion: assert.False, + }, + { + name: "default key is suffix", + assertion: assert.True, + }, + { + name: "prefix is not a suffix", + key: backend.NewKey("a", "b", "c"), + suffix: backend.NewKey("a", "b"), + assertion: assert.False, + }, + { + name: "empty suffix", + key: backend.NewKey("a", "b", "c"), + assertion: assert.True, + }, + { + name: "valid multi component suffix", + key: backend.NewKey("a", "b", "c"), + suffix: backend.NewKey("b", "c"), + assertion: assert.True, + }, + { + name: "valid single component suffix", + key: backend.NewKey("a", "b", "c"), + suffix: backend.NewKey("c"), + assertion: assert.True, + }, + { + name: "equivalent keys are suffix", + key: backend.NewKey("a", "b", "c"), + suffix: backend.NewKey("a", "b", "c"), + assertion: assert.True, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.assertion(t, test.key.HasSuffix(test.suffix)) + }) + } +} +func TestKeyHasPrefix(t *testing.T) { + tests := []struct { + name string + key backend.Key + prefix backend.Key + assertion assert.BoolAssertionFunc + }{ + { + name: "default key has no prexies", + prefix: backend.NewKey("test"), + assertion: assert.False, + }, + { + name: "default key is prefix", + assertion: assert.True, + }, + { + name: "suffix is not a prefix", + key: backend.NewKey("a", "b", "c"), + prefix: backend.NewKey("b", "c"), + assertion: assert.False, + }, + { + name: "empty prefix", + key: backend.NewKey("a", "b", "c"), + assertion: assert.True, + }, + { + name: "valid multi component prefix", + key: backend.NewKey("a", "b", "c"), + prefix: backend.NewKey("a", "b"), + assertion: assert.True, + }, + { + name: "valid single component prefix", + key: backend.NewKey("a", "b", "c"), + prefix: backend.NewKey("a"), + assertion: assert.True, + }, + { + name: "equivalent keys are prefix", + key: backend.NewKey("a", "b", "c"), + prefix: backend.NewKey("a", "b", "c"), + assertion: assert.True, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.assertion(t, test.key.HasPrefix(test.prefix)) + }) + } +} + +func TestKeyTrimSuffix(t *testing.T) { + tests := []struct { + name string + key backend.Key + trim backend.Key + expected backend.Key + }{ + { + name: "empty key trims nothing", + trim: backend.NewKey("a", "b"), + }, + { + name: "empty trim trims nothing", + key: backend.NewKey("a", "b"), + expected: backend.NewKey("a", "b"), + }, + { + name: "non-matching trim trims nothing", + key: backend.NewKey("a", "b"), + trim: backend.NewKey("c", "d"), + expected: backend.NewKey("a", "b"), + }, + { + name: "prefix trim trims nothing", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("a", "b"), + expected: backend.NewKey("a", "b", "c"), + }, + { + name: "all trimmed on exact match", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("a", "b", "c"), + expected: backend.NewKey(), + }, + { + name: "partial trim", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("b", "c"), + expected: backend.NewKey("a"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + trimmed := test.key.TrimSuffix(test.trim) + assert.Equal(t, test.expected, trimmed) + }) + } +} + +func TestKeyTrimPrefix(t *testing.T) { + tests := []struct { + name string + key backend.Key + trim backend.Key + expected backend.Key + }{ + { + name: "empty key trims nothing", + trim: backend.NewKey("a", "b"), + }, + { + name: "empty trim trims nothing", + key: backend.NewKey("a", "b"), + expected: backend.NewKey("a", "b"), + }, + { + name: "non-matching trim trims nothing", + key: backend.NewKey("a", "b"), + trim: backend.NewKey("c", "d"), + expected: backend.NewKey("a", "b"), + }, + { + name: "suffix trim trims nothing", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("b", "c"), + expected: backend.NewKey("a", "b", "c"), + }, + { + name: "all trimmed on exact match", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("a", "b", "c"), + expected: backend.NewKey(), + }, + { + name: "partial trim", + key: backend.NewKey("a", "b", "c"), + trim: backend.NewKey("a", "b"), + expected: backend.NewKey("c"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + trimmed := test.key.TrimPrefix(test.trim) + assert.Equal(t, test.expected, trimmed) + }) + } +} + +func TestKeyPrependPrefix(t *testing.T) { + tests := []struct { + name string + key backend.Key + prefix backend.Key + expected backend.Key + }{ + { + name: "empty prefix is noop", + key: backend.NewKey("a", "b"), + expected: backend.NewKey("a", "b"), + }, + { + name: "empty key is prefixed", + prefix: backend.NewKey("a", "b"), + expected: backend.NewKey("a", "b"), + }, + { + name: "prefix applied", + key: backend.NewKey("a", "b"), + prefix: backend.NewKey("1", "2"), + expected: backend.NewKey("1", "2", "a", "b"), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + prefixed := test.key.PrependPrefix(test.prefix) + assert.Equal(t, test.expected, prefixed) + }) + } +} + +func TestKeyCompare(t *testing.T) { + tests := []struct { + name string + key backend.Key + other backend.Key + expected int + }{ + { + name: "equal keys", + key: backend.NewKey("a", "b", "c"), + other: backend.NewKey("a", "b", "c"), + expected: 0, + }, + { + name: "less", + key: backend.NewKey("a", "b", "c"), + other: backend.NewKey("a", "b", "d"), + expected: -1, + }, + { + name: "greater", + key: backend.NewKey("d", "b", "c"), + other: backend.NewKey("a", "b"), + expected: 1, + }, + { + name: "empty key is always less", + other: backend.NewKey("a", "b"), + expected: -1, + }, + { + name: "key is always greater than empty", + key: backend.NewKey("a", "b"), + expected: 1, + }, + { + name: "empty keys are equal", + expected: 0, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, test.key.Compare(test.other)) + }) + } +} diff --git a/lib/backend/lite/lite.go b/lib/backend/lite/lite.go index 0723cf37eb29f..b2ff52eb25e26 100644 --- a/lib/backend/lite/lite.go +++ b/lib/backend/lite/lite.go @@ -437,7 +437,7 @@ func (l *Backend) CompareAndSwap(ctx context.Context, expected backend.Item, rep if len(replaceWith.Key) == 0 { return nil, trace.BadParameter("missing parameter Key") } - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } diff --git a/lib/backend/memory/item.go b/lib/backend/memory/item.go index af4f3df8dea18..3713508855442 100644 --- a/lib/backend/memory/item.go +++ b/lib/backend/memory/item.go @@ -19,8 +19,6 @@ package memory import ( - "bytes" - "github.com/google/btree" "github.com/gravitational/teleport/lib/backend" @@ -39,7 +37,7 @@ type btreeItem struct { func (i *btreeItem) Less(iother btree.Item) bool { switch other := iother.(type) { case *btreeItem: - return bytes.Compare(i.Key, other.Key) < 0 + return i.Key.Compare(other.Key) < 0 case *prefixItem: return !iother.Less(i) default: @@ -56,5 +54,5 @@ type prefixItem struct { // Less is used for Btree operations func (p *prefixItem) Less(iother btree.Item) bool { other := iother.(*btreeItem) - return !bytes.HasPrefix(other.Key, p.prefix) + return !other.Key.HasPrefix(p.prefix) } diff --git a/lib/backend/memory/memory.go b/lib/backend/memory/memory.go index 52ac2d01d1658..89c2ae33bb000 100644 --- a/lib/backend/memory/memory.go +++ b/lib/backend/memory/memory.go @@ -350,7 +350,7 @@ func (m *Memory) CompareAndSwap(ctx context.Context, expected backend.Item, repl if len(replaceWith.Key) == 0 { return nil, trace.BadParameter("missing parameter Key") } - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } m.Lock() diff --git a/lib/backend/pgbk/atomicwrite.go b/lib/backend/pgbk/atomicwrite.go index 78d08e0b648b8..aaaada25d68a9 100644 --- a/lib/backend/pgbk/atomicwrite.go +++ b/lib/backend/pgbk/atomicwrite.go @@ -50,12 +50,12 @@ func (b *Backend) AtomicWrite(ctx context.Context, condacts []backend.Conditiona case backend.KindExists: condBatchItems = append(condBatchItems, batchItem{ "SELECT EXISTS (SELECT FROM kv WHERE key = $1 AND (expires IS NULL OR expires >= now()))", - []any{nonNil(ca.Key)}, + []any{nonNilKey(ca.Key)}, }) case backend.KindNotExists: condBatchItems = append(condBatchItems, batchItem{ "SELECT NOT EXISTS (SELECT FROM kv WHERE key = $1 AND (expires IS NULL OR expires >= now()))", - []any{nonNil(ca.Key)}, + []any{nonNilKey(ca.Key)}, }) case backend.KindRevision: expectedRevision, ok := revisionFromString(ca.Condition.Revision) @@ -64,7 +64,7 @@ func (b *Backend) AtomicWrite(ctx context.Context, condacts []backend.Conditiona } condBatchItems = append(condBatchItems, batchItem{ "SELECT EXISTS (SELECT FROM kv WHERE key = $1 AND revision = $2 AND (expires IS NULL OR expires >= now()))", - []any{nonNil(ca.Key), expectedRevision}, + []any{nonNilKey(ca.Key), expectedRevision}, }) default: // condacts was already checked for validity @@ -80,12 +80,12 @@ func (b *Backend) AtomicWrite(ctx context.Context, condacts []backend.Conditiona "INSERT INTO kv (key, value, expires, revision) VALUES ($1, $2, $3, $4)" + " ON CONFLICT (key) DO UPDATE SET" + " value = excluded.value, expires = excluded.expires, revision = excluded.revision", - []any{nonNil(ca.Key), nonNil(ca.Action.Item.Value), zeronull.Timestamptz(ca.Action.Item.Expires.UTC()), newRevision}, + []any{nonNilKey(ca.Key), nonNil(ca.Action.Item.Value), zeronull.Timestamptz(ca.Action.Item.Expires.UTC()), newRevision}, }) case backend.KindDelete: actBatchItems = append(actBatchItems, batchItem{ "DELETE FROM kv WHERE kv.key = $1 AND (kv.expires IS NULL OR kv.expires > now())", - []any{nonNil(ca.Key)}, + []any{nonNilKey(ca.Key)}, }) default: // condacts was already checked for validity diff --git a/lib/backend/pgbk/pgbk.go b/lib/backend/pgbk/pgbk.go index 5df4a2b4af901..1729f6d69c4e3 100644 --- a/lib/backend/pgbk/pgbk.go +++ b/lib/backend/pgbk/pgbk.go @@ -19,7 +19,6 @@ package pgbk import ( - "bytes" "context" "errors" "log/slog" @@ -259,7 +258,7 @@ func (b *Backend) Create(ctx context.Context, i backend.Item) (*backend.Lease, e " ON CONFLICT (key) DO UPDATE SET"+ " value = excluded.value, expires = excluded.expires, revision = excluded.revision"+ " WHERE kv.expires IS NOT NULL AND kv.expires <= now()", - nonNil(i.Key), nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision) + nonNilKey(i.Key), nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision) if err != nil { return false, trace.Wrap(err) } @@ -286,7 +285,7 @@ func (b *Backend) Put(ctx context.Context, i backend.Item) (*backend.Lease, erro "INSERT INTO kv (key, value, expires, revision) VALUES ($1, $2, $3, $4)"+ " ON CONFLICT (key) DO UPDATE SET"+ " value = excluded.value, expires = excluded.expires, revision = excluded.revision", - nonNil(i.Key), nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision) + nonNilKey(i.Key), nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision) return struct{}{}, trace.Wrap(err) }); err != nil { return nil, trace.Wrap(err) @@ -298,7 +297,7 @@ func (b *Backend) Put(ctx context.Context, i backend.Item) (*backend.Lease, erro // CompareAndSwap implements [backend.Backend]. func (b *Backend) CompareAndSwap(ctx context.Context, expected, replaceWith backend.Item) (*backend.Lease, error) { - if !bytes.Equal(expected.Key, replaceWith.Key) { + if expected.Key.Compare(replaceWith.Key) != 0 { return nil, trace.BadParameter("expected and replaceWith keys should match") } @@ -309,7 +308,7 @@ func (b *Backend) CompareAndSwap(ctx context.Context, expected, replaceWith back "UPDATE kv SET value = $1, expires = $2, revision = $3"+ " WHERE kv.key = $4 AND kv.value = $5 AND (kv.expires IS NULL OR kv.expires > now())", nonNil(replaceWith.Value), zeronull.Timestamptz(replaceWith.Expires), revision, - nonNil(replaceWith.Key), nonNil(expected.Value)) + nonNilKey(replaceWith.Key), nonNil(expected.Value)) if err != nil { return false, trace.Wrap(err) } @@ -335,7 +334,7 @@ func (b *Backend) Update(ctx context.Context, i backend.Item) (*backend.Lease, e tag, err := b.pool.Exec(ctx, "UPDATE kv SET value = $1, expires = $2, revision = $3"+ " WHERE kv.key = $4 AND (kv.expires IS NULL OR kv.expires > now())", - nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision, nonNil(i.Key)) + nonNil(i.Value), zeronull.Timestamptz(i.Expires), revision, nonNilKey(i.Key)) if err != nil { return false, trace.Wrap(err) } @@ -367,7 +366,7 @@ func (b *Backend) ConditionalUpdate(ctx context.Context, i backend.Item) (*backe "WHERE kv.key = $4 AND kv.revision = $5 AND "+ "(kv.expires IS NULL OR kv.expires > now())", nonNil(i.Value), zeronull.Timestamptz(i.Expires), newRevision, - nonNil(i.Key), expectedRevision) + nonNilKey(i.Key), expectedRevision) if err != nil { return false, trace.Wrap(err) } @@ -394,7 +393,7 @@ func (b *Backend) Get(ctx context.Context, key backend.Key) (*backend.Item, erro var item *backend.Item batch.Queue("SELECT kv.value, kv.expires, kv.revision FROM kv"+ - " WHERE kv.key = $1 AND (kv.expires IS NULL OR kv.expires > now())", nonNil(key), + " WHERE kv.key = $1 AND (kv.expires IS NULL OR kv.expires > now())", nonNilKey(key), ).QueryRow(func(row pgx.Row) error { var value []byte var expires time.Time @@ -449,7 +448,7 @@ func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, li "SELECT kv.key, kv.value, kv.expires, kv.revision FROM kv"+ " WHERE kv.key BETWEEN $1 AND $2 AND (kv.expires IS NULL OR kv.expires > now())"+ " ORDER BY kv.key LIMIT $3", - nonNil(startKey), nonNil(endKey), limit, + nonNilKey(startKey), nonNilKey(endKey), limit, ).Query(func(rows pgx.Rows) error { var err error items, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (backend.Item, error) { @@ -486,7 +485,7 @@ func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, li func (b *Backend) Delete(ctx context.Context, key backend.Key) error { deleted, err := pgcommon.Retry(ctx, b.log, func() (bool, error) { tag, err := b.pool.Exec(ctx, - "DELETE FROM kv WHERE kv.key = $1 AND (kv.expires IS NULL OR kv.expires > now())", nonNil(key)) + "DELETE FROM kv WHERE kv.key = $1 AND (kv.expires IS NULL OR kv.expires > now())", nonNilKey(key)) if err != nil { return false, trace.Wrap(err) } @@ -512,7 +511,7 @@ func (b *Backend) ConditionalDelete(ctx context.Context, key backend.Key, rev st tag, err := b.pool.Exec(ctx, "DELETE FROM kv WHERE kv.key = $1 AND kv.revision = $2 AND "+ "(kv.expires IS NULL OR kv.expires > now())", - nonNil(key), expectedRevision) + nonNilKey(key), expectedRevision) if err != nil { return false, trace.Wrap(err) } @@ -538,7 +537,7 @@ func (b *Backend) DeleteRange(ctx context.Context, startKey, endKey backend.Key) if _, err := pgcommon.Retry(ctx, b.log, func() (struct{}, error) { _, err := b.pool.Exec(ctx, "DELETE FROM kv WHERE kv.key BETWEEN $1 AND $2", - nonNil(startKey), nonNil(endKey), + nonNilKey(startKey), nonNilKey(endKey), ) return struct{}{}, trace.Wrap(err) }); err != nil { @@ -555,7 +554,7 @@ func (b *Backend) KeepAlive(ctx context.Context, lease backend.Lease, expires ti tag, err := b.pool.Exec(ctx, "UPDATE kv SET expires = $1, revision = $2"+ " WHERE kv.key = $3 AND (kv.expires IS NULL OR kv.expires > now())", - zeronull.Timestamptz(expires.UTC()), revision, nonNil(lease.Key)) + zeronull.Timestamptz(expires.UTC()), revision, nonNilKey(lease.Key)) if err != nil { return false, trace.Wrap(err) } diff --git a/lib/backend/pgbk/utils.go b/lib/backend/pgbk/utils.go index 68ba291199fd2..00047015a1326 100644 --- a/lib/backend/pgbk/utils.go +++ b/lib/backend/pgbk/utils.go @@ -20,6 +20,8 @@ package pgbk import ( "github.com/google/uuid" + + "github.com/gravitational/teleport/lib/backend" ) // revision is transparently converted to and from Postgres UUIDs. @@ -46,6 +48,15 @@ func revisionFromString(s string) (r revision, ok bool) { return u, true } +// nonNilKey replaces an empty key with a non-nil one. +func nonNilKey(b backend.Key) []byte { + if b == nil { + return []byte{} + } + + return []byte(b.String()) +} + // nonNil replaces a nil slice with an empty, non-nil one. func nonNil(b []byte) []byte { if b == nil { diff --git a/lib/services/local/access.go b/lib/services/local/access.go index c302b9e48197d..915ccd94d2eba 100644 --- a/lib/services/local/access.go +++ b/lib/services/local/access.go @@ -19,7 +19,6 @@ package local import ( - "bytes" "context" "strings" "time" @@ -113,7 +112,7 @@ func (s *AccessService) ListRoles(ctx context.Context, req *proto.ListRolesReque return true, nil } - if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { // Item represents a different resource type in the // same namespace. continue @@ -388,13 +387,13 @@ func (s *AccessService) ReplaceRemoteLocks(ctx context.Context, clusterName stri Expires: lock.Expiry(), Revision: rev, } - newRemoteLocksToStore[string(item.Key)] = item + newRemoteLocksToStore[item.Key.String()] = item } for _, origLockItem := range origRemoteLocks.Items { // If one of the new remote locks to store is already known, // perform a CompareAndSwap. - key := string(origLockItem.Key) + key := origLockItem.Key.String() if newLockItem, ok := newRemoteLocksToStore[key]; ok { if _, err := s.CompareAndSwap(ctx, origLockItem, newLockItem); err != nil { return trace.Wrap(err) diff --git a/lib/services/local/dynamic_access.go b/lib/services/local/dynamic_access.go index 0494385fcab7d..8f61c378bbff5 100644 --- a/lib/services/local/dynamic_access.go +++ b/lib/services/local/dynamic_access.go @@ -19,7 +19,6 @@ package local import ( - "bytes" "context" "slices" "time" @@ -261,7 +260,7 @@ func (s *DynamicAccessService) GetAccessRequests(ctx context.Context, filter typ } var requests []types.AccessRequest for _, item := range result.Items { - if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { // Item represents a different resource type in the // same namespace. continue @@ -341,7 +340,7 @@ func (s *DynamicAccessService) ListAccessRequests(ctx context.Context, req *prot return true, nil } - if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { // Item represents a different resource type in the // same namespace. continue diff --git a/lib/services/local/events.go b/lib/services/local/events.go index b906d264d580a..af4ce9f862189 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -19,9 +19,7 @@ package local import ( - "bytes" "context" - "strings" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -382,7 +380,7 @@ func (p baseParser) prefixes() []backend.Key { func (p baseParser) match(key backend.Key) bool { for _, prefix := range p.matchPrefixes { - if bytes.HasPrefix(key, prefix) { + if key.HasPrefix(prefix) { return true } } @@ -721,8 +719,8 @@ func (p *namespaceParser) match(key backend.Key) bool { // namespaces are stored under key '/namespaces//params' // and this code matches similar pattern return p.baseParser.match(key) && - bytes.HasSuffix(key, []byte(paramsPrefix)) && - bytes.Count(key, []byte{backend.Separator}) == 3 + key.HasSuffix(backend.Key(paramsPrefix)) && + len(key.Components()) == 3 } func (p *namespaceParser) parse(event backend.Event) (types.Resource, error) { @@ -794,10 +792,10 @@ func (p *accessRequestParser) prefixes() []backend.Key { } func (p *accessRequestParser) match(key backend.Key) bool { - if !bytes.HasPrefix(key, p.matchPrefix) { + if !key.HasPrefix(p.matchPrefix) { return false } - if !bytes.HasSuffix(key, p.matchSuffix) { + if !key.HasSuffix(p.matchSuffix) { return false } return true @@ -835,8 +833,8 @@ func (p *userParser) match(key backend.Key) bool { // users are stored under key '/web/users//params' // and this code matches similar pattern return p.baseParser.match(key) && - bytes.HasSuffix(key, []byte(paramsPrefix)) && - bytes.Count(key, []byte{backend.Separator}) == 4 + key.HasSuffix(backend.Key(paramsPrefix)) && + len(key.Components()) == 4 } func (p *userParser) parse(event backend.Event) (types.Resource, error) { @@ -1368,7 +1366,7 @@ func (p *remoteClusterParser) prefixes() []backend.Key { } func (p *remoteClusterParser) match(key backend.Key) bool { - return bytes.HasPrefix(key, p.matchPrefix) + return key.HasPrefix(p.matchPrefix) } func (p *remoteClusterParser) parse(event backend.Event) (types.Resource, error) { @@ -1429,7 +1427,7 @@ func (p *networkRestrictionsParser) prefixes() []backend.Key { } func (p *networkRestrictionsParser) match(key backend.Key) bool { - return bytes.HasPrefix(key, p.matchPrefix) + return key.HasPrefix(p.matchPrefix) } func (p *networkRestrictionsParser) parse(event backend.Event) (types.Resource, error) { @@ -1959,25 +1957,19 @@ type kubeWaitingContainerParser struct { func (p *kubeWaitingContainerParser) parse(event backend.Event) (types.Resource, error) { switch event.Type { case types.OpDelete: - // remove the first separator so no separated parts should be - // empty strings - key := string(event.Item.Key) - if len(key) > 0 && key[0] == backend.Separator { - key = key[1:] - } - parts := strings.Split(key, string(backend.Separator)) + parts := event.Item.Key.Components() if len(parts) != 6 { return nil, trace.BadParameter("malformed key for %s event: %s", types.KindKubeWaitingContainer, event.Item.Key) } resource, err := kubewaitingcontainer.NewKubeWaitingContainer( - parts[5], + string(parts[5]), &kubewaitingcontainerpb.KubernetesWaitingContainerSpec{ - Username: parts[1], - Cluster: parts[2], - Namespace: parts[3], - PodName: parts[4], - ContainerName: parts[5], + Username: string(parts[1]), + Cluster: string(parts[2]), + Namespace: string(parts[3]), + PodName: string(parts[4]), + ContainerName: string(parts[5]), Patch: []byte("{}"), // default to empty patch. It doesn't matter for delete ops. PatchType: kubewaitingcontainer.JSONPatchType, // default to JSON patch. It doesn't matter for delete ops. }, @@ -2042,11 +2034,7 @@ type userNotificationParser struct { func (p *userNotificationParser) parse(event backend.Event) (types.Resource, error) { switch event.Type { case types.OpDelete: - // Remove the first separator so none of the separated parts will be - // empty strings - key := string(event.Item.Key) - key = strings.TrimPrefix(key, string(backend.Separator)) - parts := strings.Split(key, string(backend.Separator)) + parts := event.Item.Key.Components() if len(parts) != 4 { return nil, trace.BadParameter("malformed key for %s event: %s", types.KindNotification, event.Item.Key) } @@ -2055,10 +2043,10 @@ func (p *userNotificationParser) parse(event backend.Event) (types.Resource, err Kind: types.KindNotification, Version: types.V1, Spec: ¬ificationsv1.NotificationSpec{ - Username: parts[2], + Username: string(parts[2]), }, Metadata: &headerv1.Metadata{ - Name: parts[3], + Name: string(parts[3]), }, } @@ -2090,12 +2078,7 @@ type globalNotificationParser struct { func (p *globalNotificationParser) parse(event backend.Event) (types.Resource, error) { switch event.Type { case types.OpDelete: - // Remove the first separator so none of the separated parts will be - // empty strings - key := string(event.Item.Key) - key = strings.TrimPrefix(key, string(backend.Separator)) - // notifications/global/ - parts := strings.Split(key, string(backend.Separator)) + parts := event.Item.Key.Components() if len(parts) != 3 { return nil, trace.BadParameter("malformed key for %s event: %s", types.KindGlobalNotification, event.Item.Key) } @@ -2109,7 +2092,7 @@ func (p *globalNotificationParser) parse(event backend.Event) (types.Resource, e }, }, Metadata: &headerv1.Metadata{ - Name: parts[2], + Name: string(parts[2]), }, } @@ -2141,13 +2124,7 @@ type botInstanceParser struct { func (p *botInstanceParser) parse(event backend.Event) (types.Resource, error) { switch event.Type { case types.OpDelete: - // Remove the first separator so none of the separated parts will be - // empty strings - key := string(event.Item.Key) - key = strings.TrimPrefix(key, string(backend.Separator)) - - // bot_instance/<1: bot name>/<2: uuid> - parts := strings.Split(key, string(backend.Separator)) + parts := event.Item.Key.Components() if len(parts) != 3 { return nil, trace.BadParameter("malformed key for %s event: %s", types.KindBotInstance, event.Item.Key) } @@ -2156,11 +2133,11 @@ func (p *botInstanceParser) parse(event backend.Event) (types.Resource, error) { Kind: types.KindBotInstance, Version: types.V1, Spec: &machineidv1.BotInstanceSpec{ - BotName: parts[1], - InstanceId: parts[2], + BotName: string(parts[1]), + InstanceId: string(parts[2]), }, Metadata: &headerv1.Metadata{ - Name: parts[2], + Name: string(parts[2]), }, } @@ -2242,7 +2219,7 @@ func resourceHeader(event backend.Event, kind, version string, offset int) (type Kind: kind, Version: version, Metadata: types.Metadata{ - Name: string(name), + Name: name, Namespace: apidefaults.Namespace, }, }, nil @@ -2258,7 +2235,7 @@ func resourceHeaderWithTemplate(event backend.Event, hdr types.ResourceHeader, o SubKind: hdr.SubKind, Version: hdr.Version, Metadata: types.Metadata{ - Name: string(name), + Name: name, Namespace: apidefaults.Namespace, }, }, nil @@ -2322,7 +2299,7 @@ func (p *deviceParser) parse(event backend.Event) (types.Resource, error) { Kind: types.KindDevice, Version: types.V1, Metadata: types.Metadata{ - Name: string(name), + Name: name, }, }, } @@ -2449,19 +2426,19 @@ type EventMatcher interface { Match(types.Event) (types.Resource, error) } -// base returns last element delimited by separator, index is -// is an index of the key part to get counting from the end -func base(key backend.Key, offset int) ([]byte, error) { - parts := bytes.Split(key, []byte{backend.Separator}) +// base returns the key component that is offset +// components before the last component. +func base(key backend.Key, offset int) (string, error) { + parts := key.Components() if len(parts) < offset+1 { - return nil, trace.NotFound("failed parsing %v", string(key)) + return "", trace.NotFound("failed parsing %v", string(key)) } - return parts[len(parts)-offset-1], nil + return string(parts[len(parts)-offset-1]), nil } // baseTwoKeys returns two last keys func baseTwoKeys(key backend.Key) (string, string, error) { - parts := bytes.Split(key, []byte{backend.Separator}) + parts := key.Components() if len(parts) < 2 { return "", "", trace.NotFound("failed parsing %v", string(key)) } diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index 2ed1bf9d2ea47..4d45dd5a21cd7 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -212,7 +212,7 @@ func (s *Service[T]) listResourcesReturnNextResourceWithKey(ctx context.Context, next = &out[pageSize] // Truncate the last item that was used to determine next row existence. out = out[:pageSize] - nextKey = trimLastKey(string(lastKey), s.backendPrefix) + nextKey = trimLastKey(lastKey.String(), s.backendPrefix) } return out, next, nextKey, nil @@ -259,7 +259,7 @@ func (s *Service[T]) ListResourcesWithFilter(ctx context.Context, pageSize int, var nextKey string if len(resources) > pageSize { - nextKey = trimLastKey(string(lastKey), s.backendPrefix) + nextKey = trimLastKey(lastKey.String(), s.backendPrefix) // Truncate the last item that was used to determine next row existence. resources = resources[:pageSize] } diff --git a/lib/services/local/plugin_data.go b/lib/services/local/plugin_data.go index 03b5b898e66ba..34cc5934fc774 100644 --- a/lib/services/local/plugin_data.go +++ b/lib/services/local/plugin_data.go @@ -19,7 +19,6 @@ package local import ( - "bytes" "context" "time" @@ -96,7 +95,7 @@ func (p *PluginDataService) getPluginData(ctx context.Context, filter types.Plug } var matches []types.PluginData for _, item := range result.Items { - if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { // Item represents a different resource type in the // same namespace. continue @@ -248,7 +247,7 @@ func itemToPluginData(item backend.Item) (types.PluginData, error) { return data, nil } -func pluginDataKey(kind string, name string) []byte { +func pluginDataKey(kind string, name string) backend.Key { return backend.NewKey(pluginDataPrefix, kind, name, paramsPrefix) } diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index e80735e145c47..dbc03946e20a9 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -19,7 +19,6 @@ package local import ( - "bytes" "context" "sort" "time" @@ -79,7 +78,7 @@ func (s *PresenceService) GetNamespaces() ([]types.Namespace, error) { } out := make([]types.Namespace, 0, len(result.Items)) for _, item := range result.Items { - if !bytes.HasSuffix(item.Key, backend.Key(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { continue } ns, err := services.UnmarshalNamespace( diff --git a/lib/services/local/resource.go b/lib/services/local/resource.go index f820d603efb7e..8cad7aacf993e 100644 --- a/lib/services/local/resource.go +++ b/lib/services/local/resource.go @@ -22,6 +22,7 @@ import ( "context" "encoding/json" "strings" + "unicode/utf8" "github.com/gravitational/trace" @@ -428,20 +429,21 @@ func itemFromLock(l types.Lock) (*backend.Item, error) { // has order N cost. // fullUsersPrefix is the entire string preceding the name of a user in a key -var fullUsersPrefix = string(backend.NewKey(webPrefix, usersPrefix)) + "/" +var fullUsersPrefix = backend.ExactKey(webPrefix, usersPrefix) // splitUsernameAndSuffix is a helper for extracting usernames and suffixes from // backend key values. -func splitUsernameAndSuffix(key string) (name string, suffix string, err error) { - if !strings.HasPrefix(key, fullUsersPrefix) { +func splitUsernameAndSuffix(key backend.Key) (name string, suffix string, err error) { + if !key.HasPrefix(fullUsersPrefix) { return "", "", trace.BadParameter("expected format '%s//', got '%s'", fullUsersPrefix, key) } - key = strings.TrimPrefix(key, fullUsersPrefix) - idx := strings.Index(key, "/") - if idx < 1 || idx >= len(key) { + k := key.TrimPrefix(fullUsersPrefix) + + components := k.Components() + if len(components) < 2 { return "", "", trace.BadParameter("expected format /, got %q", key) } - return key[:idx], key[idx+1:], nil + return string(components[0]), k.String()[len(components[0])+utf8.RuneLen(backend.Separator):], nil } // collectUserItems handles the case where multiple items pertain to the same user resource. @@ -450,12 +452,11 @@ func splitUsernameAndSuffix(key string) (name string, suffix string, err error) func collectUserItems(items []backend.Item) (users map[string]userItems, rem []backend.Item, err error) { users = make(map[string]userItems) for _, item := range items { - key := string(item.Key) - if !strings.HasPrefix(key, fullUsersPrefix) { + if !item.Key.HasPrefix(fullUsersPrefix) { rem = append(rem, item) continue } - name, suffix, err := splitUsernameAndSuffix(key) + name, suffix, err := splitUsernameAndSuffix(item.Key) if err != nil { return nil, nil, err } diff --git a/lib/services/local/trust.go b/lib/services/local/trust.go index c358bb9e9635a..5ceaf0d3c1545 100644 --- a/lib/services/local/trust.go +++ b/lib/services/local/trust.go @@ -203,7 +203,7 @@ func (s *CA) DeleteCertAuthorities(ctx context.Context, ids ...types.CertAuthID) if err := id.Check(); err != nil { return trace.Wrap(err) } - for _, key := range [][]byte{activeCAKey(id), inactiveCAKey(id)} { + for _, key := range []backend.Key{activeCAKey(id), inactiveCAKey(id)} { condacts = append(condacts, backend.ConditionalAction{ Key: key, Condition: backend.Whatever(), diff --git a/lib/services/local/users.go b/lib/services/local/users.go index 2fbab50339bed..cd3ca3c892041 100644 --- a/lib/services/local/users.go +++ b/lib/services/local/users.go @@ -151,7 +151,7 @@ func (s *IdentityService) ListUsers(ctx context.Context, req *userspb.ListUsersR // the next user in the list while still allowing listing to operate // without missing any users. func nextUserToken(user types.User) string { - return string(backend.RangeEnd(backend.ExactKey(user.GetName())))[utf8.RuneLen(backend.Separator):] + return backend.RangeEnd(backend.ExactKey(user.GetName())).String()[utf8.RuneLen(backend.Separator):] } // streamUsersWithSecrets is a helper that converts a stream of backend items over the user key range into a stream @@ -165,7 +165,7 @@ func (s *IdentityService) streamUsersWithSecrets(itemStream stream.Stream[backen var current collector collectorStream := stream.FilterMap(itemStream, func(item backend.Item) (collector, bool) { - name, suffix, err := splitUsernameAndSuffix(string(item.Key)) + name, suffix, err := splitUsernameAndSuffix(item.Key) if err != nil { s.log.Warnf("Failed to extract name/suffix for user item at %q: %v", item.Key, err) return collector{}, false @@ -223,7 +223,7 @@ func (s *IdentityService) streamUsersWithSecrets(itemStream stream.Stream[backen func (s *IdentityService) streamUsersWithoutSecrets(itemStream stream.Stream[backend.Item]) stream.Stream[*types.UserV2] { suffix := backend.Key(paramsPrefix) userStream := stream.FilterMap(itemStream, func(item backend.Item) (*types.UserV2, bool) { - if !bytes.HasSuffix(item.Key, suffix) { + if !item.Key.HasSuffix(suffix) { return nil, false } @@ -251,7 +251,7 @@ func (s *IdentityService) GetUsers(ctx context.Context, withSecrets bool) ([]typ } var out []types.User for _, item := range result.Items { - if !bytes.HasSuffix(item.Key, backend.Key(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { continue } u, err := services.UnmarshalUser( @@ -584,8 +584,8 @@ func (s *IdentityService) getUserWithSecrets(ctx context.Context, user string) ( var items userItems for _, item := range result.Items { - suffix := bytes.TrimPrefix(item.Key, startKey) - items.Set(string(suffix), item) // Result of Set i + suffix := item.Key.TrimPrefix(startKey) + items.Set(suffix.String(), item) // Result of Set i } u, err := userFromUserItems(user, items) diff --git a/lib/services/local/usertoken.go b/lib/services/local/usertoken.go index 5a56adb739636..948206dbbac16 100644 --- a/lib/services/local/usertoken.go +++ b/lib/services/local/usertoken.go @@ -19,7 +19,6 @@ package local import ( - "bytes" "context" "github.com/gravitational/trace" @@ -39,7 +38,7 @@ func (s *IdentityService) GetUserTokens(ctx context.Context) ([]types.UserToken, var tokens []types.UserToken for _, item := range result.Items { - if !bytes.HasSuffix(item.Key, []byte(paramsPrefix)) { + if !item.Key.HasSuffix(backend.Key(paramsPrefix)) { continue } diff --git a/lib/services/unified_resource.go b/lib/services/unified_resource.go index 1d7aa928b9efa..6789f1671c776 100644 --- a/lib/services/unified_resource.go +++ b/lib/services/unified_resource.go @@ -19,7 +19,6 @@ package services import ( - "bytes" "context" "strings" "sync" @@ -145,7 +144,7 @@ func (c *UnifiedResourceCache) putLocked(resource resource) { // from those trees before adding a new one. This can happen // when a node's hostname changes oldSortKey := makeResourceSortKey(oldResource) - if string(oldSortKey.byName) != string(sortKey.byName) { + if oldSortKey.byName.Compare(sortKey.byName) != 0 { c.deleteSortKey(oldSortKey) } } @@ -168,10 +167,10 @@ func putResources[T resource](cache *UnifiedResourceCache, resources []T) { func (c *UnifiedResourceCache) deleteSortKey(sortKey resourceSortKey) error { if _, ok := c.nameTree.Delete(&item{Key: sortKey.byName}); !ok { - return trace.NotFound("key %q is not found in unified cache name sort tree", string(sortKey.byName)) + return trace.NotFound("key %q is not found in unified cache name sort tree", sortKey.byName.String()) } if _, ok := c.typeTree.Delete(&item{Key: sortKey.byType}); !ok { - return trace.NotFound("key %q is not found in unified cache type sort tree", string(sortKey.byType)) + return trace.NotFound("key %q is not found in unified cache type sort tree", sortKey.byType.String()) } return nil } @@ -249,7 +248,7 @@ func (c *UnifiedResourceCache) getRange(ctx context.Context, startKey backend.Ke // do we have all we need? set nextKey and stop iterating // we do this after the matchFn to make sure they have access to the "next" node if req.Limit > 0 && len(res) >= int(req.Limit) { - nextKey = string(item.Key) + nextKey = item.Key.String() return false } res = append(res, resourceFromMap) @@ -699,7 +698,7 @@ func (c *UnifiedResourceCache) defineCollectorAsInitialized() { func (i *item) Less(iother btree.Item) bool { switch other := iother.(type) { case *item: - return bytes.Compare(i.Key, other.Key) < 0 + return i.Key.Compare(other.Key) < 0 case *prefixItem: return !iother.Less(i) default: @@ -716,7 +715,7 @@ type prefixItem struct { // Less is used for Btree operations func (p *prefixItem) Less(iother btree.Item) bool { other := iother.(*item) - return !bytes.HasPrefix(other.Key, p.prefix) + return !other.Key.HasPrefix(p.prefix) } type resource interface {