diff --git a/ast/compile.go b/ast/compile.go index 9025f862b2..ae9c079e86 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -127,6 +127,7 @@ type Compiler struct { maxErrs int sorted []string // list of sorted module names pathExists func([]string) (bool, error) + pathConflictCheckRoots []string after map[string][]CompilerStageDefinition metrics metrics.Metrics capabilities *Capabilities // user-supplied capabilities @@ -383,6 +384,15 @@ func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Comp return c } +// WithPathConflictsCheckRoots enables checking path conflicts from the specified root instead +// of the top root node. Limiting conflict checks to a known set of roots, such as bundle roots, +// improves performance. Each root has the format of a "/"-delimited string, excluding the "data" +// root document. +func (c *Compiler) WithPathConflictsCheckRoots(rootPaths []string) *Compiler { + c.pathConflictCheckRoots = rootPaths + return c +} + // WithStageAfter registers a stage to run during compilation after // the named stage. func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler { diff --git a/ast/compile_test.go b/ast/compile_test.go index dade9ed506..11e7d3d7c4 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "reflect" + "slices" "sort" "strings" "testing" @@ -1989,6 +1990,40 @@ p[r] := 2 if { r := "foo" }`, assertCompilerErrorStrings(t, c, expected) } +func TestCompilerCheckRuleConflictsWithRoots(t *testing.T) { + + c := getCompilerWithParsedModules(map[string]string{ + "mod1.rego": `package badrules.dataoverlap + +p if { true }`, + "mod2.rego": `package badrules.existserr +p if { true }`, + + // this does not trigger conflict check because + // WithPathConflictsCheckRoots limits the root to "badrules". + "mod3.rego": `package badrules_outside_root.dataoverlap +p if { true }`, + }) + + c.WithPathConflictsCheck(func(path []string) (bool, error) { + if slices.Contains(path, "dataoverlap") { + return true, nil + } else if reflect.DeepEqual(path, []string{"badrules", "existserr", "p"}) { + return false, fmt.Errorf("unexpected error") + } + return false, nil + }).WithPathConflictsCheckRoots([]string{"badrules"}) + + compileStages(c, c.checkRuleConflicts) + + expected := []string{ + "rego_compile_error: conflict check for data path badrules/existserr/p: unexpected error", + "rego_compile_error: conflicting rule for data path badrules/dataoverlap/p found", + } + + assertCompilerErrorStrings(t, c, expected) +} + func TestCompilerCheckRuleConflictsDefaultFunction(t *testing.T) { tests := []struct { note string diff --git a/ast/conflicts.go b/ast/conflicts.go index c2713ad576..685cc6b694 100644 --- a/ast/conflicts.go +++ b/ast/conflicts.go @@ -5,6 +5,7 @@ package ast import ( + "slices" "strings" ) @@ -18,8 +19,33 @@ func CheckPathConflicts(c *Compiler, exists func([]string) (bool, error)) Errors return nil } - for _, node := range root.Children { - errs = append(errs, checkDocumentConflicts(node, exists, nil)...) + if len(c.pathConflictCheckRoots) == 0 || slices.Contains(c.pathConflictCheckRoots, "") { + for _, child := range root.Children { + errs = append(errs, checkDocumentConflicts(child, exists, nil)...) + } + return errs + } + + for _, rootPath := range c.pathConflictCheckRoots { + // traverse AST from `path` to go to the new root + paths := strings.Split(rootPath, "/") + node := root + for _, key := range paths { + node = node.Child(String(key)) + if node == nil { + break + } + } + + if node == nil { + // could not find the node from the AST (e.g. `path` is from a data file) + // then no conflict is possible + continue + } + + for _, child := range node.Children { + errs = append(errs, checkDocumentConflicts(child, exists, paths)...) + } } return errs diff --git a/plugins/bundle/plugin.go b/plugins/bundle/plugin.go index 53ccde387c..efd8eedcf4 100644 --- a/plugins/bundle/plugin.go +++ b/plugins/bundle/plugin.go @@ -615,6 +615,10 @@ func (p *Plugin) activate(ctx context.Context, name string, b *bundle.Bundle, is compiler = compiler.WithPathConflictsCheck(storage.NonEmpty(ctx, p.manager.Store, txn)). WithEnablePrintStatements(p.manager.EnablePrintStatements()) + if b.Manifest.Roots != nil { + compiler = compiler.WithPathConflictsCheckRoots(*b.Manifest.Roots) + } + var activateErr error opts := &bundle.ActivateOpts{ diff --git a/plugins/bundle/plugin_test.go b/plugins/bundle/plugin_test.go index 4650689f73..60fafaa762 100644 --- a/plugins/bundle/plugin_test.go +++ b/plugins/bundle/plugin_test.go @@ -6574,6 +6574,194 @@ func TestGetNormalizedBundleName(t *testing.T) { } } +func TestBundleActivationWithRootOverlap(t *testing.T) { + ctx := context.Background() + plugin := getPluginWithExistingLoadedBundle( + t, + "policy-bundle", + []string{"foo/bar"}, + nil, + []testModule{ + { + Path: "foo/bar/bar.rego", + Data: `package foo.bar +result := true`, + }, + }, + ) + + bundleName := "new-bundle" + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + b := getTestBundleWithData( + []string{"foo/bar/baz"}, + []byte(`{"foo": {"bar": 1, "baz": "qux"}}`), + nil, + ) + + b.Manifest.Init() + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New(), Size: snapshotBundleSize}) + + // "foo/bar" and "foo/bar/baz" overlap with each other; activation will fail + status, ok := plugin.status[bundleName] + if !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } + if status.Code != errCode { + t.Fatalf("Expected status code to be %s, found %s", errCode, status.Code) + } + if exp := "detected overlapping roots"; !strings.Contains(status.Message, exp) { + t.Fatalf(`Expected status message to contain "%s", found %s`, exp, status.Message) + } +} + +func TestBundleActivationWithNoManifestRootsButWithPathConflict(t *testing.T) { + ctx := context.Background() + plugin := getPluginWithExistingLoadedBundle( + t, + "policy-bundle", + []string{"foo/bar"}, + nil, + []testModule{ + { + Path: "foo/bar/bar.rego", + Data: `package foo.bar +result := true`, + }, + }, + ) + + bundleName := "new-bundle" + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + b := getTestBundleWithData( + nil, + []byte(`{"foo": {"bar": 1, "baz": "qux"}}`), + nil, + ) + + b.Manifest.Init() + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New(), Size: snapshotBundleSize}) + + // new bundle has path "foo/bar" which overlaps with existing bundle with path "foo/bar"; activation will fail + status, ok := plugin.status[bundleName] + if !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } + if status.Code != errCode { + t.Fatalf("Expected status code to be %s, found %s", errCode, status.Code) + } + if !strings.Contains(status.Message, "detected overlapping") { + t.Fatalf(`Expected status message to contain "detected overlapping roots", found %s`, status.Message) + } +} + +func TestBundleActivationWithNoManifestRootsOverlap(t *testing.T) { + ctx := context.Background() + plugin := getPluginWithExistingLoadedBundle( + t, + "policy-bundle", + []string{"foo/bar"}, + nil, + []testModule{ + { + Path: "foo/bar/bar.rego", + Data: `package foo.bar +result := true`, + }, + }, + ) + + bundleName := "new-bundle" + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + b := getTestBundleWithData( + []string{"foo/baz"}, + nil, + []testModule{ + { + Path: "foo/bar/baz.rego", + Data: `package foo.baz +result := true`, + }, + }, + ) + + b.Manifest.Init() + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New(), Size: snapshotBundleSize}) + + status, ok := plugin.status[bundleName] + if !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } + if status.Code != "" { + t.Fatalf("Expected status code to be empty, found %s", status.Code) + } +} + +type testModule struct { + Path string + Data string +} + +func getTestBundleWithData(roots []string, data []byte, modules []testModule) bundle.Bundle { + b := bundle.Bundle{} + + if len(roots) > 0 { + b.Manifest = bundle.Manifest{Roots: &roots} + } + + if len(data) > 0 { + b.Data = util.MustUnmarshalJSON(data).(map[string]interface{}) + } + + for _, m := range modules { + if len(m.Data) > 0 { + b.Modules = append(b.Modules, + bundle.ModuleFile{ + Path: m.Path, + Parsed: ast.MustParseModule(m.Data), + Raw: []byte(m.Data), + }, + ) + } + } + + b.Manifest.Init() + + return b +} + +func getPluginWithExistingLoadedBundle(t *testing.T, bundleName string, roots []string, data []byte, modules []testModule) *Plugin { + ctx := context.Background() + store := inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false), inmem.OptReturnASTValuesOnRead(true)) + manager := getTestManagerWithOpts(nil, store) + plugin := New(&Config{}, manager) + plugin.status[bundleName] = &Status{Name: bundleName, Metrics: metrics.New()} + plugin.downloaders[bundleName] = download.New(download.Config{}, plugin.manager.Client(""), bundleName) + + ensurePluginState(t, plugin, plugins.StateNotReady) + + b := getTestBundleWithData(roots, data, modules) + + plugin.oneShot(ctx, bundleName, download.Update{Bundle: &b, Metrics: metrics.New(), Size: snapshotBundleSize}) + + ensurePluginState(t, plugin, plugins.StateOK) + + if status, ok := plugin.status[bundleName]; !ok { + t.Fatalf("Expected to find status for %s, found nil", bundleName) + } else if status.Type != bundle.SnapshotBundleType { + t.Fatalf("Expected snapshot bundle but got %v", status.Type) + } else if status.Size != snapshotBundleSize { + t.Fatalf("Expected snapshot bundle size %d but got %d", snapshotBundleSize, status.Size) + } + + return plugin +} + func writeTestBundleToDisk(t *testing.T, srcDir string, signed bool) bundle.Bundle { t.Helper()