diff --git a/cmd/schema_gen.go b/cmd/schema_gen.go
index 37169fd..cc9d688 100644
--- a/cmd/schema_gen.go
+++ b/cmd/schema_gen.go
@@ -146,8 +146,8 @@ func suggestDocsForGraph(graph *fs.Graph) error {
 }
 
 // generateSchemaForModel generates a schema and writes yml for modelName.
-func generateSchemaForModel(ctx context.Context, model *fs.File) error {
-	target, err := model.GetTarget()
+func generateSchemaForModel(ctx context.Context, file *fs.File) error {
+	target, err := file.GetTarget()
 	if err != nil {
 		fmt.Println("could not get target for schema")
 		return err
@@ -155,7 +155,7 @@ func generateSchemaForModel(ctx context.Context, model *fs.File) error {
 	fmt.Println("\n🎯 Target for retrieving schema:", target.ProjectID+"."+target.DataSet)
 
 	// retrieve columns from BigQuery
-	bqColumns, err := getColumnsForModel(ctx, model.Name, target)
+	bqColumns, err := getColumnsForModel(ctx, file.Name, target)
 	if err != nil {
 		fmt.Println("Could not retrieve schema")
 		return err
@@ -163,16 +163,16 @@ func generateSchemaForModel(ctx context.Context, model *fs.File) error {
 	fmt.Println("✅ BQ Schema retrieved. Number of columns in BQ table:", len(bqColumns))
 
 	// create schema file
-	ymlPath, schemaFile := generateEmptySchemaFile(model)
+	ymlPath, schemaFile := generateEmptySchemaFile(file)
 	var schemaModel *properties.Model
 
-	if model.Schema == nil {
-		fmt.Println("\n🔍 " + model.Name + " schema file not found.. 🌱 Generating new schema file")
-		schemaModel = generateNewSchemaModel(model.Name, bqColumns)
+	if file.Schema == nil {
+		fmt.Println("\n🔍 " + file.Name + " schema file not found.. 🌱 Generating new schema file")
+		schemaModel = generateNewSchemaModel(file.Name, bqColumns)
 	} else {
-		fmt.Println("\n🔍 " + model.Name + " schema file found.. 🛠 Updating schema file")
+		fmt.Println("\n🔍 " + file.Name + " schema file found.. 🛠 Updating schema file")
 		// set working schema model to current schema model
-		schemaModel = model.Schema
+		schemaModel = file.Schema
 		// add and remove columns in-place
 		addMissingColumnsToSchema(schemaModel, bqColumns)
 		removeOutdatedColumnsFromSchema(schemaModel, bqColumns)
@@ -184,7 +184,7 @@ func generateSchemaForModel(ctx context.Context, model *fs.File) error {
 		fmt.Println("Error writing YML to file in path")
 		return err
 	}
-	fmt.Println("\n✅ " + model.Name + "schema successfully updated at path: " + ymlPath)
+	fmt.Println("\n✅ " + file.Name + " schema successfully updated at path: " + ymlPath)
 	return nil
 }
 
@@ -205,8 +205,8 @@ func getColumnsForModel(ctx context.Context, modelName string, target *config.Ta
 
 // generate an empty schema file which will be populated according to existing yml schemas and the bigquery schema.
 // Returns the local path for the yml file and the yml file struct
-func generateEmptySchemaFile(model *fs.File) (ymlPath string, schemaFile properties.File) {
-	ymlPath = strings.Replace(model.Path, ".sql", ".yml", 1)
+func generateEmptySchemaFile(file *fs.File) (ymlPath string, schemaFile properties.File) {
+	ymlPath = strings.Replace(file.Path, ".sql", ".yml", 1)
 	schemaFile = properties.File{}
 	schemaFile.Version = properties.FileVersion
 	return ymlPath, schemaFile
diff --git a/cmd/test_gen.go b/cmd/test_gen.go
new file mode 100644
index 0000000..c0365e1
--- /dev/null
+++ b/cmd/test_gen.go
@@ -0,0 +1,340 @@
+package cmd
+
+import (
+	"context"
+	"ddbt/bigquery"
+	"ddbt/config"
+	"ddbt/fs"
+	"ddbt/properties"
+	schemaTestMacros "ddbt/schemaTestMacros"
+	"ddbt/utils"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+
+	"github.com/spf13/cobra"
+)
+
+// TODO:
+// only output suggestions to terminal for new tests
+// Parse macro files
+// Test with value inputs e.g. accepted values
+
+func init() {
+	rootCmd.AddCommand(testGenCmd)
+	addModelsFlag(testGenCmd)
+}
+
+// ColumnTestQuery pairs a column with the name of a schema test and the SQL
+// query that evaluates that test against the column.
+type ColumnTestQuery struct {
+	Column    string
+	TestName  string
+	TestQuery string
+}
+
+// TestSuggestions is a mutex-guarded map of model name -> column name ->
+// names of the tests that passed for that column.
+type TestSuggestions struct {
+	mu          sync.Mutex
+	suggestions map[string]map[string][]string
+}
+
+// SetSuggestion records the passing tests of modelName, keyed by column name.
+func (d *TestSuggestions) SetSuggestion(modelName string, testSuggestions map[string][]string) {
+	d.mu.Lock()
+	d.suggestions[modelName] = testSuggestions
+	d.mu.Unlock()
+}
+
+// Init allocates the underlying map; it must be called before SetSuggestion.
+func (d *TestSuggestions) Init() {
+	d.mu.Lock()
+	d.suggestions = make(map[string]map[string][]string)
+	d.mu.Unlock()
+}
+
+// Value returns the underlying map itself (not a copy).
+func (d *TestSuggestions) Value() (suggestions map[string]map[string][]string) {
+	d.mu.Lock()
+	suggestions = d.suggestions
+	d.mu.Unlock()
+	return
+}
+
+var testGenCmd = &cobra.Command{
+	Use:               "test-gen [model name]",
+	Short:             "Suggests tests to add to the YML schema file for a given model",
+	Args:              cobra.RangeArgs(0, 1),
+	ValidArgsFunction: completeModelFn,
+	Run: func(cmd *cobra.Command, args []string) {
+		switch {
+		case len(args) == 0 && len(ModelFilters) == 0:
+			fmt.Println("Please specify model with test-gen -m model-name")
+			os.Exit(1)
+		case len(args) == 1 && len(ModelFilters) > 0:
+			fmt.Println("Please specify model with either test-gen model-name or test-gen -m model-name but not both")
+			os.Exit(1)
+		case len(args) == 1:
+			// This will actually allow something weird like
+			// ddbt test-gen +model+
+			ModelFilters = append(ModelFilters, args[0])
+		}
+
+		fmt.Println(`ℹ️ test-gen requires models without sampling to be accurate
+❓ Have you run the provided models with sampling off? (y/N)`)
+		var userPrompt string
+		fmt.Scanln(&userPrompt)
+		if userPrompt != "y" {
+			fmt.Println("❌ Please run your model with sampling off and then use test-gen")
+			os.Exit(1)
+		}
+
+		// Build a graph from the given filter.
+		fileSystem, _ := compileAllModels()
+		graph := buildGraph(fileSystem, ModelFilters)
+
+		// Generate tests for every model in the graph concurrently.
+		if err := generateTestsForModelsGraph(graph); err != nil {
+			fmt.Printf("❌ %s\n", err)
+			os.Exit(1)
+		}
+		os.Exit(0)
+	},
+}
+
+// generateTestsForModelsGraph runs generateTestsForModel for every model file
+// in graph, then offers the collected suggestions to the user.
+func generateTestsForModelsGraph(graph *fs.Graph) error {
+	pb := utils.NewProgressBar("🖨 Generating tests for models in graph", graph.Len())
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// Release the context on every return path, not only on failure.
+	defer cancel()
+	var testSugs TestSuggestions
+	testSugs.Init()
+
+	err := graph.Execute(func(file *fs.File) error {
+		if file.Type == fs.ModelFile {
+			testSuggestions, err := generateTestsForModel(ctx, file)
+			if err != nil {
+				pb.Stop()
+				if err != context.Canceled {
+					fmt.Printf("❌ %s\n", err)
+				}
+				cancel()
+				return err
+			}
+			testSugs.SetSuggestion(file.Name, testSuggestions)
+		}
+
+		pb.Increment()
+		return nil
+	}, config.NumberThreads(), pb)
+
+	if err != nil {
+		return err
+	}
+	pb.Stop()
+
+	return userPromptTests(graph, testSugs.Value())
+}
+
+// generateTestsForModel runs every candidate schema test against every column
+// of the model's BigQuery table and returns, per column, the names of the
+// tests that passed. Passing tests are also added to file.Schema in memory.
+func generateTestsForModel(ctx context.Context, file *fs.File) (map[string][]string, error) {
+	target, err := file.GetTarget()
+	if err != nil {
+		fmt.Println("could not get target for schema")
+		return nil, err
+	}
+	fmt.Println("\n🎯 Target for retrieving schema:", target.ProjectID+"."+target.DataSet)
+
+	// retrieve columns from BigQuery
+	bqColumns, err := getColumnsForModel(ctx, file.Name, target)
+	if err != nil {
+		fmt.Println("Could not retrieve schema")
+		return nil, err
+	}
+	fmt.Println("✅ BQ Schema retrieved. Number of columns in BQ table:", len(bqColumns))
+
+	// iterate through functions which return test sql and definition
+	testFuncs := []func(string, string, string, string) (string, string){
+		schemaTestMacros.TestNotNullMacro,
+		schemaTestMacros.TestUniqueMacro,
+	}
+
+	var allTestQueries []ColumnTestQuery
+	for _, col := range bqColumns {
+		for _, test := range testFuncs {
+			testQuery, testName := test(target.ProjectID, target.DataSet, file.Name, col)
+			allTestQueries = append(allTestQueries, ColumnTestQuery{
+				Column:    col,
+				TestName:  testName,
+				TestQuery: testQuery,
+			})
+		}
+	}
+
+	passedTestQueries, err := runQueriesParallel(ctx, target, allTestQueries)
+	if err != nil {
+		return nil, err
+	}
+	updateSchemaFile(passedTestQueries, file)
+
+	return passedTestQueries, nil
+}
+
+// runQueriesParallel evaluates allTestQueries on a pool of goroutines and
+// returns the passing test names grouped by column. It returns the first
+// error received, if any.
+func runQueriesParallel(ctx context.Context, target *config.Target, allTestQueries []ColumnTestQuery) (map[string][]string, error) {
+	// number of parallel query runners
+	numQueryRunners := 100
+
+	queries := make(chan ColumnTestQuery)
+	go func() {
+		for _, q := range allTestQueries {
+			queries <- q
+		}
+		close(queries)
+	}()
+
+	out := make(chan ColumnTestQuery, len(allTestQueries))
+	errs := make(chan error, len(allTestQueries))
+	wg := sync.WaitGroup{}
+
+	wg.Add(numQueryRunners)
+	for i := 0; i < numQueryRunners; i++ {
+		go func(i int) {
+			defer wg.Done()
+			for query := range queries {
+				evaluateTestQuery(ctx, target, query, out, errs, i)
+			}
+		}(i)
+	}
+
+	go func() {
+		wg.Wait()
+		close(out)
+		close(errs)
+	}()
+
+	if err := <-errs; err != nil {
+		// there is at least one error, but we ignore the rest
+		return nil, err
+	}
+
+	passedTestQueries := make(map[string][]string)
+	for passedTestQuery := range out {
+		// append on a missing key starts from a nil slice, so no existence check is needed
+		passedTestQueries[passedTestQuery.Column] = append(passedTestQueries[passedTestQuery.Column], passedTestQuery.TestName)
+	}
+
+	return passedTestQueries, nil
+}
+
+// evaluateTestQuery runs ctq.TestQuery and sends ctq on out when the single
+// returned value is 0; query errors and unexpected result shapes go to errs.
+func evaluateTestQuery(ctx context.Context, target *config.Target, ctq ColumnTestQuery, out chan ColumnTestQuery,
+	errs chan error, workerIndex int) {
+	results, _, err := bigquery.GetRows(ctx, ctq.TestQuery, target)
+	if err != nil {
+		errs <- err
+		return
+	}
+
+	switch {
+	case len(results) != 1:
+		errs <- fmt.Errorf(
+			"a schema test should only return 1 row, got %d for %s test on column %s by worker %v",
+			len(results), ctq.TestName, ctq.Column, workerIndex,
+		)
+	case len(results[0]) != 1:
+		// report the column count of the row, not the row count
+		errs <- fmt.Errorf(
+			"a schema test should only return 1 column, got %d for %s test on column %s by worker %v",
+			len(results[0]), ctq.TestName, ctq.Column, workerIndex,
+		)
+	default:
+		// NOTE(review): the second return value of ValueAsUint64 is discarded;
+		// if it signals a failed conversion, rows may be 0 and the test would
+		// be counted as passed - confirm.
+		rows, _ := bigquery.ValueAsUint64(results[0][0])
+		if rows == 0 {
+			out <- ctq
+		}
+	}
+}
+
+// updateSchemaFile adds every passed test that is not already listed to the
+// matching columns of model.Schema, in memory only.
+func updateSchemaFile(passedTestQueries map[string][]string, model *fs.File) {
+	// a model without a schema file has nothing to update
+	if model.Schema == nil {
+		return
+	}
+
+	updatedColumns := model.Schema.Columns
+	for colIndex, column := range model.Schema.Columns {
+		if _, exists := passedTestQueries[column.Name]; exists {
+
+			// search for test in existing tests
+			for _, test := range passedTestQueries[column.Name] {
+				testFound := false
+				for _, existingTest := range column.Tests {
+					if existingTest.Name == test {
+						testFound = true
+						break
+					}
+				}
+				if !testFound {
+					column.Tests = append(column.Tests, &properties.Test{
+						Name: test,
+					})
+				}
+			}
+		}
+		updatedColumns[colIndex] = column
+	}
+	model.Schema.Columns = updatedColumns
+}
+
+// userPromptTests prints the suggested tests per model and column and, when
+// the user answers "y", rewrites the yml schema file of each listed model.
+func userPromptTests(graph *fs.Graph, testSugsMap map[string]map[string][]string) error {
+	if len(testSugsMap) > 0 {
+		fmt.Println("\n🧪 Valid tests found for the following models: ")
+		for model, columnTests := range testSugsMap {
+			fmt.Println("\n🧬 Model:", model)
+			for column, tests := range columnTests {
+				fmt.Println("🏛 Column:", column)
+				testPrint := strings.Join(tests[:], "\n - ")
+				fmt.Println(" -", testPrint)
+			}
+		}
+		fmt.Println("\n❔ Would you like to add these tests to the schema (y/N)?")
+
+		var userPrompt string
+		fmt.Scanln(&userPrompt)
+
+		if userPrompt == "y" {
+			for file := range graph.ListNodes() {
+				// skip models that were not analysed or have no schema to write
+				if _, contains := testSugsMap[file.Name]; !contains || file.Schema == nil {
+					continue
+				}
+				ymlPath, schemaFile := generateEmptySchemaFile(file)
+				schemaFile.Models = properties.Models{file.Schema}
+				if err := schemaFile.WriteToFile(ymlPath); err != nil {
+					fmt.Println("Error writing YML to file in path")
+					return err
+				}
+			}
+			fmt.Println("✅ Tests added to schema files")
+		}
+	}
+	return nil
+}
diff --git a/schemaTestMacros/testNotNullMacro.go b/schemaTestMacros/testNotNullMacro.go
new file mode 100644
index 0000000..bda5a1e
--- /dev/null
+++ b/schemaTestMacros/testNotNullMacro.go
@@ -0,0 +1,11 @@
+package schemaTestMacros
+
+import "fmt"
+
+// TestNotNullMacro returns the SQL counting rows of project.dataset.model
+// whose columnName is null, together with the test name "not_null".
+func TestNotNullMacro(project string, dataset string, model string, columnName string) (string, string) {
+	return fmt.Sprintf(`select count(*)
+	from %s.%s.%s where %s is null
+	`, project, dataset, model, columnName), "not_null"
+}
diff --git a/schemaTestMacros/testUniqueMacro.go b/schemaTestMacros/testUniqueMacro.go
new file mode 100644
index 0000000..0c7a168
--- /dev/null
+++ b/schemaTestMacros/testUniqueMacro.go
@@ -0,0 +1,19 @@
+package schemaTestMacros
+
+import "fmt"
+
+// TestUniqueMacro returns the SQL counting the non-null values of columnName
+// that occur more than once in project.dataset.model, together with the test
+// name "unique".
+func TestUniqueMacro(project string, dataset string, model string, columnName string) (string, string) {
+	return fmt.Sprintf(`select count(*)
+	from (
+		select
+			%s
+		from %s.%s.%s
+		where %s is not null
+		group by %s
+		having count(*) > 1
+	) validation_errors
+	`, columnName, project, dataset, model, columnName, columnName), "unique"
+}
diff --git a/utils/version.go b/utils/version.go
index b54bfad..33ee5fd 100644
--- a/utils/version.go
+++ b/utils/version.go
@@ -1,3 +1,3 @@
 package utils
 
-const DdbtVersion = "0.5.1"
+const DdbtVersion = "0.6.0"