From 8d2d5f83ec9960bdd34972a5c9cff6295be2373a Mon Sep 17 00:00:00 2001 From: Michael Wilson Date: Wed, 6 Mar 2024 21:02:11 -0500 Subject: [PATCH] Add CompareResources support for custom equals functions. The `cmp` functions are not performant. The reconciler will now support an IsEqual function if it is defined, and will use it if a resource has one. Additionally, one has been added for Okta assignments and dependent objects. --- .github/workflows/lint.yaml | 4 + Makefile | 13 ++ api/types/derived.gen.go | 88 +++++++ api/types/okta.go | 8 + api/types/okta_test.go | 176 ++++++++++++++ api/types/resource.go | 10 + api/types/resource_test.go | 220 ++++++++++++++++++ build.assets/tooling/cmd/goderive/main.go | 53 +++++ .../plugin/teleportequal/teleportequal.go | 140 +++++++++++ build.assets/tooling/go.mod | 3 + build.assets/tooling/go.sum | 5 + lib/services/compare.go | 26 ++- lib/services/compare_test.go | 52 +++++ 13 files changed, 790 insertions(+), 8 deletions(-) create mode 100644 api/types/derived.gen.go create mode 100644 build.assets/tooling/cmd/goderive/main.go create mode 100644 build.assets/tooling/cmd/goderive/plugin/teleportequal/teleportequal.go create mode 100644 lib/services/compare_test.go diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 9bae7b8604378..ccb85835b023c 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -102,3 +102,7 @@ jobs: - name: Check if Operator CRDs are up to date # We have to add the current directory as a safe directory or else git commands will not work as expected. run: git config --global --add safe.directory $(realpath .) && make crds-up-to-date + + - name: Check if derived functions are up to date + # We have to add the current directory as a safe directory or else git commands will not work as expected. + run: git config --global --add safe.directory $(realpath .) && make derive-up-to-date diff --git a/Makefile b/Makefile index 0aa3beb548283..31ffc8da590d4 100644 --- a/Makefile +++ b/Makefile @@ -1303,6 +1303,19 @@ buf/installed: exit 1; \ fi +# derive will generate derived functions for our API. +.PHONY: derive +derive: + cd $(TOOLINGDIR) && go run ./cmd/goderive/main.go ../../api/types + +# derive-up-to-date checks if the generated derived functions are up to date. +.PHONY: derive-up-to-date +derive-up-to-date: must-start-clean/host derive + @if ! $(GIT) diff --quiet; then \ + echo 'Please run make derive.'; \ + exit 1; \ + fi + # grpc generates gRPC stubs from service definitions. # This target runs in the buildbox container. .PHONY: grpc diff --git a/api/types/derived.gen.go b/api/types/derived.gen.go new file mode 100644 index 0000000000000..e5556cd283221 --- /dev/null +++ b/api/types/derived.gen.go @@ -0,0 +1,88 @@ +// Code generated by goderive DO NOT EDIT. + +package types + +// deriveTeleportEqualOktaAssignmentV1 returns whether this and that are equal. +func deriveTeleportEqualOktaAssignmentV1(this, that *OktaAssignmentV1) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + deriveTeleportEqualResourceHeader(&this.ResourceHeader, &that.ResourceHeader) && + deriveTeleportEqual(&this.Spec, &that.Spec) +} + +// deriveTeleportEqualResourceHeader returns whether this and that are equal. +func deriveTeleportEqualResourceHeader(this, that *ResourceHeader) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + this.Kind == that.Kind && + this.SubKind == that.SubKind && + this.Version == that.Version && + deriveTeleportEqualMetadata(&this.Metadata, &that.Metadata) +} + +// deriveTeleportEqualMetadata returns whether this and that are equal. +func deriveTeleportEqualMetadata(this, that *Metadata) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + this.Name == that.Name && + this.Namespace == that.Namespace && + this.Description == that.Description && + deriveTeleportEqual_(this.Labels, that.Labels) && + ((this.Expires == nil && that.Expires == nil) || (this.Expires != nil && that.Expires != nil && (*(this.Expires)).Equal(*(that.Expires)))) +} + +// deriveTeleportEqual returns whether this and that are equal. +func deriveTeleportEqual(this, that *OktaAssignmentSpecV1) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + this.User == that.User && + deriveTeleportEqual_1(this.Targets, that.Targets) && + this.CleanupTime.Equal(that.CleanupTime) && + this.Status == that.Status && + this.LastTransition.Equal(that.LastTransition) && + this.Finalized == that.Finalized +} + +// deriveTeleportEqual_ returns whether this and that are equal. +func deriveTeleportEqual_(this, that map[string]string) bool { + if this == nil || that == nil { + return this == nil && that == nil + } + if len(this) != len(that) { + return false + } + for k, v := range this { + thatv, ok := that[k] + if !ok { + return false + } + if !(v == thatv) { + return false + } + } + return true +} + +// deriveTeleportEqual_1 returns whether this and that are equal. +func deriveTeleportEqual_1(this, that []*OktaAssignmentTargetV1) bool { + if this == nil || that == nil { + return this == nil && that == nil + } + if len(this) != len(that) { + return false + } + for i := 0; i < len(this); i++ { + if !(deriveTeleportEqual_2(this[i], that[i])) { + return false + } + } + return true +} + +// deriveTeleportEqual_2 returns whether this and that are equal. +func deriveTeleportEqual_2(this, that *OktaAssignmentTargetV1) bool { + return (this == nil && that == nil) || + this != nil && that != nil && + this.Type == that.Type && + this.Id == that.Id +} diff --git a/api/types/okta.go b/api/types/okta.go index 1be51a74f68e3..5adace4cc244a 100644 --- a/api/types/okta.go +++ b/api/types/okta.go @@ -396,6 +396,14 @@ func (o *OktaAssignmentV1) CheckAndSetDefaults() error { return nil } +// IsEqual determines if two okta assignment resources are equivalent to one another. +func (o *OktaAssignmentV1) IsEqual(i OktaAssignment) bool { + if other, ok := i.(*OktaAssignmentV1); ok { + return deriveTeleportEqualOktaAssignmentV1(o, other) + } + return false +} + // OktaAssignmentTarget is an target for an Okta assignment. type OktaAssignmentTarget interface { // GetTargetType returns the target type. diff --git a/api/types/okta_test.go b/api/types/okta_test.go index 904f6792da214..833b0709cc15e 100644 --- a/api/types/okta_test.go +++ b/api/types/okta_test.go @@ -19,6 +19,7 @@ package types import ( "fmt" "testing" + "time" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -79,6 +80,181 @@ func TestOktaAssignments_SetStatus(t *testing.T) { } } +func TestOktAssignmentIsEqual(t *testing.T) { + newAssignment := func(changeFns ...func(*OktaAssignmentV1)) *OktaAssignmentV1 { + assignment := &OktaAssignmentV1{ + ResourceHeader: ResourceHeader{ + Kind: KindOktaAssignment, + Version: V1, + Metadata: Metadata{ + Name: "name", + }, + }, + Spec: OktaAssignmentSpecV1{ + User: "user", + Targets: []*OktaAssignmentTargetV1{ + {Id: "1", Type: OktaAssignmentTargetV1_APPLICATION}, + {Id: "2", Type: OktaAssignmentTargetV1_GROUP}, + }, + CleanupTime: time.Time{}, + Status: OktaAssignmentSpecV1_PENDING, + LastTransition: time.Time{}, + Finalized: true, + }, + } + require.NoError(t, assignment.CheckAndSetDefaults()) + + for _, fn := range changeFns { + fn(assignment) + } + + return assignment + } + tests := []struct { + name string + o1 *OktaAssignmentV1 + o2 *OktaAssignmentV1 + expected bool + }{ + { + name: "empty equals", + o1: &OktaAssignmentV1{}, + o2: &OktaAssignmentV1{}, + expected: true, + }, + { + name: "nil equals", + o1: nil, + o2: (*OktaAssignmentV1)(nil), + expected: true, + }, + { + name: "one is nil", + o1: &OktaAssignmentV1{}, + o2: (*OktaAssignmentV1)(nil), + expected: false, + }, + { + name: "populated equals", + o1: newAssignment(), + o2: newAssignment(), + expected: true, + }, + { + name: "resource header is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.ResourceHeader.Kind = "different-kind" + }), + expected: false, + }, + { + name: "user is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.User = "different-user" + }), + expected: false, + }, + { + name: "targets different id", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = []*OktaAssignmentTargetV1{ + {Id: "2", Type: OktaAssignmentTargetV1_APPLICATION}, + {Id: "2", Type: OktaAssignmentTargetV1_GROUP}, + } + }), + expected: false, + }, + { + name: "targets different type", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = []*OktaAssignmentTargetV1{ + {Id: "1", Type: OktaAssignmentTargetV1_GROUP}, + {Id: "2", Type: OktaAssignmentTargetV1_GROUP}, + } + }), + expected: false, + }, + { + name: "targets different sizes", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = []*OktaAssignmentTargetV1{ + {Id: "1", Type: OktaAssignmentTargetV1_APPLICATION}, + } + }), + expected: false, + }, + { + name: "targets both nil", + o1: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = nil + }), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = nil + }), + expected: true, + }, + { + name: "targets o1 is nil", + o1: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = nil + }), + o2: newAssignment(), + expected: false, + }, + { + name: "targets o2 is nil", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Targets = nil + }), + expected: false, + }, + { + name: "cleanup time is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.CleanupTime = time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC) + }), + expected: false, + }, + { + name: "status is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Status = OktaAssignmentSpecV1_PROCESSING + }), + expected: false, + }, + { + name: "last transition is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.CleanupTime = time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC) + }), + expected: false, + }, + { + name: "finalized is different", + o1: newAssignment(), + o2: newAssignment(func(o *OktaAssignmentV1) { + o.Spec.Finalized = false + }), + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.expected, test.o1.IsEqual(test.o2)) + }) + } +} + func newOktaAssignment(t *testing.T, status string) OktaAssignment { assignment := &OktaAssignmentV1{} diff --git a/api/types/resource.go b/api/types/resource.go index 2183cb7a0ed97..57293f4729046 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -378,6 +378,11 @@ func (h *ResourceHeader) GetAllLabels() map[string]string { return h.Metadata.Labels } +// IsEqual determines if two resource header resources are equivalent to one another. +func (h *ResourceHeader) IsEqual(other *ResourceHeader) bool { + return deriveTeleportEqualResourceHeader(h, other) +} + func (h *ResourceHeader) CheckAndSetDefaults() error { if h.Kind == "" { return trace.BadParameter("resource has an empty Kind field") @@ -452,6 +457,11 @@ func (m *Metadata) SetOrigin(origin string) { m.Labels[OriginLabel] = origin } +// IsEqual determines if two metadata resources are equivalent to one another. +func (m *Metadata) IsEqual(other *Metadata) bool { + return deriveTeleportEqualMetadata(m, other) +} + // CheckAndSetDefaults checks validity of all parameters and sets defaults func (m *Metadata) CheckAndSetDefaults() error { if m.Name == "" { diff --git a/api/types/resource_test.go b/api/types/resource_test.go index 62d6a73b43fa0..33174783e3bc4 100644 --- a/api/types/resource_test.go +++ b/api/types/resource_test.go @@ -18,6 +18,7 @@ package types import ( "testing" + "time" "github.com/stretchr/testify/require" ) @@ -586,3 +587,222 @@ func TestFriendlyName(t *testing.T) { }) } } + +func TestMetadataIsEqual(t *testing.T) { + newMetadata := func(changeFns ...func(*Metadata)) *Metadata { + metadata := &Metadata{ + Name: "name", + Namespace: "namespace", + Description: "description", + Labels: map[string]string{"label1": "value1"}, + Expires: &time.Time{}, + ID: 1234, + Revision: "aaaa", + } + + for _, fn := range changeFns { + fn(metadata) + } + + return metadata + } + tests := []struct { + name string + m1 *Metadata + m2 *Metadata + expected bool + }{ + { + name: "empty equals", + m1: &Metadata{}, + m2: &Metadata{}, + expected: true, + }, + { + name: "nil equals", + m1: nil, + m2: (*Metadata)(nil), + expected: true, + }, + { + name: "one is nil", + m1: &Metadata{}, + m2: (*Metadata)(nil), + expected: false, + }, + { + name: "populated equals", + m1: newMetadata(), + m2: newMetadata(), + expected: true, + }, + { + name: "id and revision have no effect", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + m.ID = 7890 + m.Revision = "bbbb" + }), + expected: true, + }, + { + name: "name is different", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + m.Name = "different-name" + }), + expected: false, + }, + { + name: "namespace is different", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + m.Namespace = "different-namespace" + }), + expected: false, + }, + { + name: "description is different", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + m.Description = "different-description" + }), + expected: false, + }, + { + name: "labels is different", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + m.Labels = map[string]string{"label2": "value2"} + }), + expected: false, + }, + { + name: "expires is different", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { + newTime := time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC) + m.Expires = &newTime + }), + expected: false, + }, + { + name: "expires both nil", + m1: newMetadata(func(m *Metadata) { m.Expires = nil }), + m2: newMetadata(func(m *Metadata) { m.Expires = nil }), + expected: true, + }, + { + name: "expires m1 nil", + m1: newMetadata(func(m *Metadata) { m.Expires = nil }), + m2: newMetadata(), + expected: false, + }, + { + name: "expires m2 nil", + m1: newMetadata(), + m2: newMetadata(func(m *Metadata) { m.Expires = nil }), + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.expected, test.m1.IsEqual(test.m2)) + }) + } +} + +func TestResourceHeaderIsEqual(t *testing.T) { + newHeader := func(changeFns ...func(*ResourceHeader)) *ResourceHeader { + header := &ResourceHeader{ + Kind: "kind", + SubKind: "subkind", + Version: "v1", + Metadata: Metadata{ + Name: "name", + Namespace: "namespace", + Description: "description", + Labels: map[string]string{"label1": "value1"}, + Expires: &time.Time{}, + ID: 1234, + Revision: "aaaa", + }, + } + + for _, fn := range changeFns { + fn(header) + } + + return header + } + tests := []struct { + name string + h1 *ResourceHeader + h2 *ResourceHeader + expected bool + }{ + { + name: "empty equals", + h1: &ResourceHeader{}, + h2: &ResourceHeader{}, + expected: true, + }, + { + name: "nil equals", + h1: nil, + h2: (*ResourceHeader)(nil), + expected: true, + }, + { + name: "one is nil", + h1: &ResourceHeader{}, + h2: (*ResourceHeader)(nil), + expected: false, + }, + { + name: "populated equals", + h1: newHeader(), + h2: newHeader(), + expected: true, + }, + { + name: "kind is different", + h1: newHeader(), + h2: newHeader(func(h *ResourceHeader) { + h.Kind = "different-kind" + }), + expected: false, + }, + { + name: "subkind is different", + h1: newHeader(), + h2: newHeader(func(h *ResourceHeader) { + h.SubKind = "different-subkind" + }), + expected: false, + }, + { + name: "metadata is different", + h1: newHeader(), + h2: newHeader(func(h *ResourceHeader) { + h.Metadata = Metadata{} + }), + expected: false, + }, + { + name: "version is different", + h1: newHeader(), + h2: newHeader(func(h *ResourceHeader) { + h.Version = "different-version" + }), + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.expected, test.h1.IsEqual(test.h2)) + }) + } +} diff --git a/build.assets/tooling/cmd/goderive/main.go b/build.assets/tooling/cmd/goderive/main.go new file mode 100644 index 0000000000000..1260866add8bc --- /dev/null +++ b/build.assets/tooling/cmd/goderive/main.go @@ -0,0 +1,53 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/awalterschulze/goderive/derive" + + "github.com/gravitational/teleport/build.assets/tooling/cmd/goderive/plugin/teleportequal" +) + +func main() { + // Establish Teleport derive plugins of interest. + plugins := []derive.Plugin{ + teleportequal.NewPlugin(), + } + + // Parse args, which are just paths at the moment.. + flag.Parse() + paths := derive.ImportPaths(flag.Args()) + + // Load the given paths into the generator. + g, err := derive.NewPlugins(plugins, false, false).Load(paths) + if err != nil { + fmt.Printf("Error creating new plugins: %v\n", err) + os.Exit(1) + } + + // Generate the derived code. + if err := g.Generate(); err != nil { + fmt.Printf("Error generating code: %v\n", err) + os.Exit(1) + } +} diff --git a/build.assets/tooling/cmd/goderive/plugin/teleportequal/teleportequal.go b/build.assets/tooling/cmd/goderive/plugin/teleportequal/teleportequal.go new file mode 100644 index 0000000000000..9d5cbb68fd05c --- /dev/null +++ b/build.assets/tooling/cmd/goderive/plugin/teleportequal/teleportequal.go @@ -0,0 +1,140 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package teleportequal + +import ( + "go/types" + "strings" + + "github.com/awalterschulze/goderive/derive" + "github.com/awalterschulze/goderive/plugin/equal" +) + +// NewPlugin will create a Teleport equals plugin. This is the default derive equals plugin +// with protobuf fields, ID/Revision, and Status fields filtered out. +func NewPlugin() derive.Plugin { + return derive.NewPlugin("teleport-equal", "deriveTeleportEqual", New) +} + +func New(typesMap derive.TypesMap, p derive.Printer, deps map[string]derive.Dependency) derive.Generator { + gen := &gen{ + TypesMap: typesMap, + } + gen.equalGen = equal.New(gen, p, deps) + return gen +} + +type gen struct { + derive.TypesMap + equalGen derive.Generator +} + +func (g *gen) Add(name string, typs []types.Type) (string, error) { + return g.equalGen.Add(name, typs) +} + +func (g *gen) Generate(typs []types.Type) error { + return g.equalGen.Generate(typs) +} + +func (g *gen) ToGenerate() [][]types.Type { + return filterTypes(g.TypesMap.ToGenerate()) +} + +// filterTypes will take the given types from ToGenerate and filter out fields from +// messages that we want to ignore. +func filterTypes(typsOfTyps [][]types.Type) [][]types.Type { + for i, typs := range typsOfTyps { + for j, typ := range typs { + typsOfTyps[i][j] = removeIgnoredFields("", typ) + } + } + return typsOfTyps +} + +// removeIgnoredFields will remove fields we want to ignore from any given types. +func removeIgnoredFields(name string, typ types.Type) types.Type { + // If the current type is a pointer, call removeIgnoredFields for the type the pointer points to. + if ptr, ok := typ.(*types.Pointer); ok { + return types.NewPointer(removeIgnoredFields(name, ptr.Elem())) + } + + // If this is named, call removeIgnoredFields for the underlying type and pass along the name. + if named, ok := typ.(*types.Named); ok { + methods := make([]*types.Func, named.NumMethods()) + for i := 0; i < named.NumMethods(); i++ { + methods[i] = named.Method(i) + } + return types.NewNamed(named.Obj(), removeIgnoredFields(named.Obj().Name(), named.Underlying()), methods) + } + + // If this is a struct, filter out ignored fields from the struct. + if strct, ok := typ.(*types.Struct); ok { + return removeIgnoredFieldsFromStruct(name, strct) + } + + return typ +} + +// removeIgnoredFieldsFromStruct will remove the following fields from structs it encounters: +// - ID/Revision from the Metadata struct. +// - protobuf XXX_ fields. +// - Status fields that are sitting aside a Spec field. +func removeIgnoredFieldsFromStruct(name string, strct *types.Struct) types.Type { + numFields := strct.NumFields() + var filteredFields []*types.Var + var filteredTags []string + var hasSpec bool + + // Figure out if the field has a spec. If it does, we should ignore any found status fields. + for i := 0; i < numFields; i++ { + if strct.Field(i).Name() == "Spec" { + hasSpec = true + } + } + + for i := 0; i < numFields; i++ { + field := strct.Field(i) + fieldName := field.Name() + + // Ignore status fields that sit aside spec fields. + if hasSpec && fieldName == "Status" { + continue + } + + // Ignore XXX_ fields, which are proto fields. + if strings.HasPrefix(fieldName, "XXX_") { + continue + } + + // If this is the metadata struct, disregard the ID and Revision fields. + if strings.HasPrefix(name, "Metadata") && (fieldName == "ID" || fieldName == "Revision") { + continue + } + + filteredFields = append(filteredFields, field) + filteredTags = append(filteredTags, strct.Tag(i)) + + } + if len(filteredFields) != numFields { + return types.NewStruct(filteredFields, filteredTags) + } + + return strct +} diff --git a/build.assets/tooling/go.mod b/build.assets/tooling/go.mod index bc34cb59d848a..3c2fae8431659 100644 --- a/build.assets/tooling/go.mod +++ b/build.assets/tooling/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/alecthomas/kingpin/v2 v2.3.2 // replaced + github.com/awalterschulze/goderive v0.0.0-20230417115348-bbb2c8c30585 github.com/bmatcuk/doublestar/v4 v4.6.1 github.com/bradleyfalzon/ghinstallation/v2 v2.9.0 github.com/gogo/protobuf v1.3.2 @@ -41,6 +42,7 @@ require ( github.com/imdario/mergo v0.3.13 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kisielk/gotool v1.0.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect @@ -58,6 +60,7 @@ require ( golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.16.1 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/build.assets/tooling/go.sum b/build.assets/tooling/go.sum index dce83eab3c799..e3e5f70d3b1a0 100644 --- a/build.assets/tooling/go.sum +++ b/build.assets/tooling/go.sum @@ -14,6 +14,8 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apparentlymart/go-dump v0.0.0-20180507223929-23540a00eaa3/go.mod h1:oL81AME2rN47vu18xqj1S1jPIPuN7afo62yKTNn3XMM= github.com/apparentlymart/go-textseg v1.0.0/go.mod h1:z96Txxhf3xSFMPmb5X/1W05FF/Nj9VFpLOpjS5yuumk= +github.com/awalterschulze/goderive v0.0.0-20230417115348-bbb2c8c30585 h1:NBBz3zlM7i5awyrsgR7n6SC3D7YfmVOPOhSFrVapuVY= +github.com/awalterschulze/goderive v0.0.0-20230417115348-bbb2c8c30585/go.mod h1:rXccmDQDJN/4aGqWxWhq+UmBJeQEkFV/2/rkluP+ipA= github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/bradleyfalzon/ghinstallation/v2 v2.9.0 h1:HmxIYqnxubRYcYGRc5v3wUekmo5Wv2uX3gukmWJ0AFk= @@ -127,6 +129,7 @@ github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -309,6 +312,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/lib/services/compare.go b/lib/services/compare.go index 1dedf4f2a595e..c9e72df60783b 100644 --- a/lib/services/compare.go +++ b/lib/services/compare.go @@ -28,16 +28,26 @@ import ( "github.com/gravitational/teleport/api/types" ) +// IsEqual[T] will be used instead of cmp.Equal if a resource implements it. +type IsEqual[T any] interface { + IsEqual(T) bool +} + // CompareResources compares two resources by all significant fields. func CompareResources[T any](resA, resB T) int { - equal := cmp.Equal(resA, resB, - ignoreProtoXXXFields(), - cmpopts.IgnoreFields(types.Metadata{}, "ID", "Revision"), - cmpopts.IgnoreFields(types.DatabaseV3{}, "Status"), - cmpopts.IgnoreFields(types.UserSpecV2{}, "Status"), - cmpopts.IgnoreUnexported(headerv1.Metadata{}), - cmpopts.EquateEmpty(), - ) + var equal bool + if hasEqual, ok := any(resA).(IsEqual[T]); ok { + equal = hasEqual.IsEqual(resB) + } else { + equal = cmp.Equal(resA, resB, + ignoreProtoXXXFields(), + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Revision"), + cmpopts.IgnoreFields(types.DatabaseV3{}, "Status"), + cmpopts.IgnoreFields(types.UserSpecV2{}, "Status"), + cmpopts.IgnoreUnexported(headerv1.Metadata{}), + cmpopts.EquateEmpty(), + ) + } if equal { return Equal } diff --git a/lib/services/compare_test.go b/lib/services/compare_test.go new file mode 100644 index 0000000000000..392b808df8433 --- /dev/null +++ b/lib/services/compare_test.go @@ -0,0 +1,52 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package services + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCompareResources(t *testing.T) { + compareTestCase(t, "cmp equal", compareResource{true}, compareResource{true}, Equal) + compareTestCase(t, "cmp not equal", compareResource{true}, compareResource{false}, Different) + + // These results should be forced since we're going through a custom compare function. + compareTestCase(t, "IsEqual equal", &compareResourceWithEqual{true}, &compareResourceWithEqual{false}, Equal) + compareTestCase(t, "IsEqual not equal", &compareResourceWithEqual{false}, &compareResourceWithEqual{false}, Different) +} + +func compareTestCase[T any](t *testing.T, name string, resA, resB T, expected int) { + t.Run(name, func(t *testing.T) { + require.Equal(t, expected, CompareResources(resA, resB)) + }) +} + +type compareResource struct { + Field bool +} + +type compareResourceWithEqual struct { + ForceCompare bool +} + +func (r *compareResourceWithEqual) IsEqual(_ *compareResourceWithEqual) bool { + return r.ForceCompare +}