From a432d45d59ee07dc7e49a11bf1973c97e94ca184 Mon Sep 17 00:00:00 2001 From: Matt Toohey Date: Tue, 5 Mar 2024 13:35:32 +1100 Subject: [PATCH] fix: prevent schema dependency cycles (#894) (#1015) --- backend/schema/schema_test.go | 110 ++++++++++++++++++++++++++++++++++ backend/schema/validate.go | 78 ++++++++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/backend/schema/schema_test.go b/backend/schema/schema_test.go index 086e7862ef..5545ad6ab9 100644 --- a/backend/schema/schema_test.go +++ b/backend/schema/schema_test.go @@ -481,3 +481,113 @@ var testSchema = MustValidate(&Schema{ }, }, }) + +func TestValidateDependencies(t *testing.T) { + tests := []struct { + name string + schema string + err string + }{ + { + // one <--> two, cyclical + name: "TwoModuleCycle", + schema: ` + module one { + verb one(builtin.Empty) builtin.Empty + calls two.two + } + + module two { + verb two(builtin.Empty) builtin.Empty + calls one.one + } + `, + err: "found cycle in dependencies: two -> one -> two", + }, + { + // one --> two --> three, noncyclical + name: "ThreeModulesNoCycle", + schema: ` + module one { + verb one(builtin.Empty) builtin.Empty + calls two.two + } + + module two { + verb two(builtin.Empty) builtin.Empty + calls three.three + } + + module three { + verb three(builtin.Empty) builtin.Empty + } + `, + err: "", + }, + { + // one --> two --> three -> one, cyclical + name: "ThreeModulesCycle", + schema: ` + module one { + verb one(builtin.Empty) builtin.Empty + calls two.two + } + + module two { + verb two(builtin.Empty) builtin.Empty + calls three.three + } + + module three { + verb three(builtin.Empty) builtin.Empty + calls one.one + } + `, + err: "found cycle in dependencies: two -> three -> one -> two", + }, + { + // one.a --> two.a + // one.b <--- + // cyclical (does not depend on verbs used) + name: "TwoModuleCycleDiffVerbs", + schema: ` + module one { + verb a(builtin.Empty) builtin.Empty + calls two.a + verb b(builtin.Empty) builtin.Empty + } + + module two { + verb a(builtin.Empty) builtin.Empty + calls one.b + } + `, + err: "found cycle in dependencies: two -> one -> two", + }, + { + // one --> one, this is allowed + name: "SelfReference", + schema: ` + module one { + verb a(builtin.Empty) builtin.Empty + calls one.b + + verb b(builtin.Empty) builtin.Empty + calls one.a + } + `, + err: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := ParseString("", test.schema) + if test.err == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, test.err) + } + }) + } +} diff --git a/backend/schema/validate.go b/backend/schema/validate.go index 5bc866a9dd..109b62b57b 100644 --- a/backend/schema/validate.go +++ b/backend/schema/validate.go @@ -58,6 +58,11 @@ func Validate(schema *Schema) (*Schema, error) { scopes := NewScopes() + // Validate dependencies + if err := validateDependencies(schema); err != nil { + merr = append(merr, err) + } + // First pass, add all the modules. for _, module := range schema.Modules { if module == builtins { @@ -309,3 +314,76 @@ func cleanErrors(merr []error) []error { }) return merr } + +type dependencyVertex struct { + from, to string +} + +type dependencyVertexState int + +const ( + notExplored dependencyVertexState = iota + exploring + fullyExplored +) + +func validateDependencies(schema *Schema) error { + // go through schema's modules, find cycles in modules' dependencies + + // First pass, set up direct imports and vertex states for each module + // We need each import array and vertex array to be sorted to make the output deterministic + imports := map[string][]string{} + vertexes := []dependencyVertex{} + vertexStates := map[dependencyVertex]dependencyVertexState{} + + for _, module := range schema.Modules { + currentImports := module.Imports() + sort.Strings(currentImports) + imports[module.Name] = currentImports + + for _, imp := range currentImports { + v := dependencyVertex{module.Name, imp} + vertexes = append(vertexes, v) + vertexStates[v] = notExplored + } + } + + sort.Slice(vertexes, func(i, j int) bool { + lhs := vertexes[i] + rhs := vertexes[j] + return lhs.from < rhs.from || (lhs.from == rhs.from && lhs.to < rhs.to) + }) + + // DFS to find cycles + for _, v := range vertexes { + if cycle := dfsForDependencyCycle(imports, vertexStates, v); cycle != nil { + return fmt.Errorf("found cycle in dependencies: %s", strings.Join(cycle, " -> ")) + } + } + + return nil +} + +func dfsForDependencyCycle(imports map[string][]string, vertexStates map[dependencyVertex]dependencyVertexState, v dependencyVertex) []string { + switch vertexStates[v] { + case notExplored: + vertexStates[v] = exploring + + for _, toModule := range imports[v.to] { + nextV := dependencyVertex{v.to, toModule} + if cycle := dfsForDependencyCycle(imports, vertexStates, nextV); cycle != nil { + // found cycle. prepend current module to cycle and return + cycle = append([]string{nextV.from}, cycle...) + return cycle + } + } + vertexStates[v] = fullyExplored + return nil + case exploring: + return []string{v.to} + case fullyExplored: + return nil + } + + return nil +}