From 367ed4adff6394d3a4a527c671eae65e04186112 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 12:33:25 +0200 Subject: [PATCH 01/18] faker update + mcokdata test. --- cmd/cli_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ cmd/mock_data.go | 24 ++++++++++-------- go.mod | 3 ++- go.sum | 7 +++-- 4 files changed, 87 insertions(+), 13 deletions(-) create mode 100644 cmd/cli_test.go diff --git a/cmd/cli_test.go b/cmd/cli_test.go new file mode 100644 index 0000000..b05f28e --- /dev/null +++ b/cmd/cli_test.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "database/sql" + "database/sql/driver" + "github.com/google/uuid" + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +func TestCreateMockData(t *testing.T) { + // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() + assert.NoError(t, err) + defer func(sqlDB *sql.DB) { + err := sqlDB.Close() + if err != nil { + + } + }(sqlDB) + + // Create a new GORM DB instance with the mock database + dialector := postgres.New(postgres.Config{ + Conn: sqlDB, + DriverName: "postgres", + }) + db, err := gorm.Open(dialector, &gorm.Config{}) + assert.NoError(t, err) + + createExpectations := func(tableName string, count int, argCount int) { + for i := 0; i < count; i++ { + mock.ExpectBegin() + args := make([]driver.Value, argCount) + for j := range args { + args[j] = sqlmock.AnyArg() + } + mock.ExpectQuery(regexp.QuoteMeta(`INSERT INTO "` + tableName + `"`)). + WithArgs(args...). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(uuid.New(), time.Now(), time.Now())) + mock.ExpectCommit() + } + } + + // Set up expectations for all tables + createExpectations("tags", 10, 6) + createExpectations("ai_models", 2, 10) + createExpectations("api_keys", 2, 7) + createExpectations("audit_logs", 2, 11) + createExpectations("products", 2, 7) + createExpectations("usages", 2, 10) + createExpectations("workspaces", 2, 6) + + createMockData(db) + + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/cmd/mock_data.go b/cmd/mock_data.go index b843f7e..bf032e1 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -2,7 +2,7 @@ package cmd import ( "fmt" - "github.com/bxcodec/faker/v4" + "github.com/go-faker/faker/v4" "github.com/openshieldai/openshield/lib" "github.com/openshieldai/openshield/models" "gorm.io/gorm" @@ -34,15 +34,19 @@ func init() { }) } -func createMockData() { - db := lib.DB() - createMockTags(db, 10) - createMockRecords(db, &models.AiModels{}, 2) - createMockRecords(db, &models.ApiKeys{}, 2) - createMockRecords(db, &models.AuditLogs{}, 2) - createMockRecords(db, &models.Products{}, 2) - createMockRecords(db, &models.Usage{}, 2) - createMockRecords(db, &models.Workspaces{}, 2) +func createMockData(db ...*gorm.DB) { + database := lib.DB() + if len(db) > 0 && db[0] != nil { + database = db[0] + } + database = database.Debug() + createMockTags(database, 10) + createMockRecords(database, &models.AiModels{}, 2) + createMockRecords(database, &models.ApiKeys{}, 2) + createMockRecords(database, &models.AuditLogs{}, 2) + createMockRecords(database, &models.Products{}, 2) + createMockRecords(database, &models.Usage{}, 2) + createMockRecords(database, &models.Workspaces{}, 2) } func createMockTags(db *gorm.DB, count int) { diff --git a/go.mod b/go.mod index 1d80d6d..0fa0715 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,10 @@ module github.com/openshieldai/openshield go 1.22 require ( - github.com/bxcodec/faker/v4 v4.0.0-beta.3 + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/fsnotify/fsnotify v1.7.0 + github.com/go-faker/faker/v4 v4.4.2 github.com/gofiber/fiber/v2 v2.52.5 github.com/gofiber/storage/redis/v3 v3.1.2 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 98c54ed..de51bbe 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,11 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/bxcodec/faker/v4 v4.0.0-beta.3 h1:gqYNBvN72QtzKkYohNDKQlm+pg+uwBDVMN28nWHS18k= -github.com/bxcodec/faker/v4 v4.0.0-beta.3/go.mod h1:m6+Ch1Lj3fqW/unZmvkXIdxWS5+XQWPWxcbbQW2X+Ho= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -21,6 +21,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-faker/faker/v4 v4.4.2 h1:96WeU9QKEqRUVYdjHquY2/5bAqmVM0IfGKHV5mbfqmQ= +github.com/go-faker/faker/v4 v4.4.2/go.mod h1:4K3v4AbKXYNHMQNaREMc9/kRB9j5JJzpFo6KHRvrcIw= github.com/gofiber/fiber/v2 v2.52.5 h1:tWoP1MJQjGEe4GB5TUGOi7P2E0ZMMRx5ZTG4rT+yGMo= github.com/gofiber/fiber/v2 v2.52.5/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= github.com/gofiber/storage/redis/v3 v3.1.2 h1:qYHSRbkRQCD9HovLOOoswe+DoGF28/hwD4d8kmxDNcs= @@ -45,6 +47,7 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= From 8e9aef49d4cbc09b84a1079e8351cef1e33eac77 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 12:38:22 +0200 Subject: [PATCH 02/18] handle warnings --- cmd/mock_data.go | 55 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/cmd/mock_data.go b/cmd/mock_data.go index bf032e1..778bb66 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -15,23 +15,42 @@ var generatedTags []string func init() { // Faker providers setup - faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { - return string(models.OpenAI), nil - }) - - faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { - statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} - return statuses[rand.Intn(len(statuses))], nil - }) - - faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { - statuses := []string{string(models.Stop), string(models.Length), string(models.Null), string(models.FunctionCall), string(models.ContentFilter)} - return statuses[rand.Intn(len(statuses))], nil - }) - - faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { - return getRandomTags(), nil - }) + { + err := faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { + return string(models.OpenAI), nil + }) + if err != nil { + return + } + } + { + err := faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} + return statuses[rand.Intn(len(statuses))], nil + }) + { + if err != nil { + return + } + } + } + { + err := faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Stop), string(models.Length), string(models.Null), string(models.FunctionCall), string(models.ContentFilter)} + return statuses[rand.Intn(len(statuses))], nil + }) + if err != nil { + return + } + } + { + err := faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { + return getRandomTags(), nil + }) + if err != nil { + return + } + } } func createMockData(db ...*gorm.DB) { @@ -93,7 +112,7 @@ func createMockRecords(db *gorm.DB, model interface{}, count int) { fmt.Printf("%+v\n\n", model) result := db.Create(model) if result.Error != nil { - fmt.Errorf("error inserting fake data for %T: %v", model, result.Error) + _ = fmt.Errorf("error inserting fake data for %T: %v", model, result.Error) } } } From cb8966b3f9ca1e43e13e8f222945bd767c2a5bea Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 14:56:50 +0200 Subject: [PATCH 03/18] create tables test --- cmd/cli.go | 22 +++++++++++++++++++++- cmd/cli_test.go | 40 ++++++++++++++++++++++++++++++++++++++-- cmd/mock_data.go | 1 - lib/db.go | 8 +++++--- 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/cmd/cli.go b/cmd/cli.go index 2b7b39f..4998444 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -3,8 +3,10 @@ package cmd import ( "fmt" "github.com/openshieldai/openshield/lib" + "github.com/openshieldai/openshield/models" "github.com/openshieldai/openshield/server" "github.com/spf13/cobra" + "gorm.io/gorm" "os" "os/signal" "syscall" @@ -42,7 +44,7 @@ var createTablesCmd = &cobra.Command{ Use: "create-tables", Short: "Create database tables from models", Run: func(cmd *cobra.Command, args []string) { - lib.DB() // Call the function from lib package + createTables() }, } @@ -144,3 +146,21 @@ func stopServer() error { return nil } +func createTables(db ...*gorm.DB) { + database := lib.DB() + if len(db) > 0 && db[0] != nil { + database = db[0] + err := database.AutoMigrate( + &models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, + &models.AuditLogs{}, + &models.Products{}, + &models.Usage{}, + &models.Workspaces{}, + ) + if err != nil { + _ = fmt.Errorf("failed to migrate models: %v", err) + } + } +} diff --git a/cmd/cli_test.go b/cmd/cli_test.go index b05f28e..5bfd41e 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -3,6 +3,7 @@ package cmd import ( "database/sql" "database/sql/driver" + "fmt" "github.com/google/uuid" "regexp" "testing" @@ -15,7 +16,7 @@ import ( ) func TestCreateMockData(t *testing.T) { - // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() assert.NoError(t, err) defer func(sqlDB *sql.DB) { @@ -25,7 +26,6 @@ func TestCreateMockData(t *testing.T) { } }(sqlDB) - // Create a new GORM DB instance with the mock database dialector := postgres.New(postgres.Config{ Conn: sqlDB, DriverName: "postgres", @@ -64,3 +64,39 @@ func TestCreateMockData(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestCreateTables(t *testing.T) { + // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() + assert.NoError(t, err) + defer func(sqlDB *sql.DB) { + err := sqlDB.Close() + if err != nil { + _ = fmt.Errorf("failed to create mock db %v", err) + } + }(sqlDB) + + dialector := postgres.New(postgres.Config{ + Conn: sqlDB, + DriverName: "postgres", + }) + db, err := gorm.Open(dialector, &gorm.Config{}) + assert.NoError(t, err) + + tables := []string{"tags", "ai_models", "api_keys", "audit_logs", "products", "usage", "workspaces"} + for _, table := range tables { + mock.ExpectQuery(`SELECT EXISTS \(SELECT FROM information_schema.tables WHERE table_name = \$1\)`). + WithArgs(table). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + } + + for _, table := range tables { + var exists bool + err := db.Raw("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ?)", table).Scan(&exists).Error + assert.NoError(t, err) + assert.True(t, exists, "Table %s should exist", table) + } + + err = mock.ExpectationsWereMet() + assert.NoError(t, err) +} diff --git a/cmd/mock_data.go b/cmd/mock_data.go index 778bb66..6f36ebc 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -58,7 +58,6 @@ func createMockData(db ...*gorm.DB) { if len(db) > 0 && db[0] != nil { database = db[0] } - database = database.Debug() createMockTags(database, 10) createMockRecords(database, &models.AiModels{}, 2) createMockRecords(database, &models.ApiKeys{}, 2) diff --git a/lib/db.go b/lib/db.go index 7b83789..5ba699b 100644 --- a/lib/db.go +++ b/lib/db.go @@ -17,11 +17,13 @@ func DB() *gorm.DB { } if config.Settings.Database.AutoMigration { - err := connection.AutoMigrate(&models.ApiKeys{}, - &models.Tags{}, + err := connection.AutoMigrate(&models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, &models.AuditLogs{}, + &models.Products{}, &models.Usage{}, - &models.AiModels{}) + &models.Workspaces{}) if err != nil { log.Panic(err) } From c3ab70175e2b3662ee564d99e29ddc2ed73fcd10 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 16:53:49 +0200 Subject: [PATCH 04/18] test for removerule @ start for addrule --- cmd/cli_test.go | 221 +++++++++++++++++++++++++++++++++++++++++++++++- cmd/config.go | 27 ++++-- 2 files changed, 241 insertions(+), 7 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 5bfd41e..5a9bc4e 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -1,10 +1,16 @@ package cmd import ( + "bytes" "database/sql" "database/sql/driver" "fmt" "github.com/google/uuid" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "io" + "io/ioutil" + "os" "regexp" "testing" "time" @@ -48,7 +54,6 @@ func TestCreateMockData(t *testing.T) { } } - // Set up expectations for all tables createExpectations("tags", 10, 6) createExpectations("ai_models", 2, 10) createExpectations("api_keys", 2, 7) @@ -66,7 +71,7 @@ func TestCreateMockData(t *testing.T) { } func TestCreateTables(t *testing.T) { - // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() assert.NoError(t, err) defer func(sqlDB *sql.DB) { @@ -100,3 +105,215 @@ func TestCreateTables(t *testing.T) { err = mock.ExpectationsWereMet() assert.NoError(t, err) } +func TestAddAndRemoveRuleConfig(t *testing.T) { + + tmpfile, err := ioutil.TempFile("", "config.*.yaml") + if err != nil { + t.Fatal(err) + } + + defer func(name string) { + err := os.Remove(name) + if err != nil { + return + } + }(tmpfile.Name()) + { + err := os.Setenv("OPENSHIELD_CONFIG_FILE", tmpfile.Name()) + if err != nil { + return + } + defer func() { + err := os.Unsetenv("OPENSHIELD_CONFIG_FILE") + if err != nil { + return + } + }() + } + initialConfig := ` +filters: + input: + - name: existing_rule + type: pii_filter + enabled: true + action: + type: redact + config: + plugin_name: pii_plugin + threshold: 80 +` + if _, err := tmpfile.Write([]byte(initialConfig)); err != nil { + t.Fatal(err) + } + { + err := tmpfile.Close() + if err != nil { + return + } + } + viper.Reset() + viper.SetConfigFile(tmpfile.Name()) + if err := viper.ReadInConfig(); err != nil { + t.Fatalf("Error reading config file: %v", err) + } + + t.Run("AddRule", func(t *testing.T) { + input := "input\nnew_rule\nsentiment_filter\nblock\nsentiment_plugin\n90\n" + t.Logf("AddRule Input:\n%s", input) + inputBuffer := bytes.NewBufferString(input) + output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "add-rule") + if err != nil { + t.Fatalf("Error executing add-rule command: %v", err) + } + + t.Logf("Add Rule Command Output:\n%s", output) + + // Verify the output + assert.Contains(t, output, "Rule added successfully") + + // Verify the config was modified + v := viper.New() + v.SetConfigFile(tmpfile.Name()) + err = v.ReadInConfig() + if err != nil { + t.Fatalf("Error reading updated config: %v", err) + } + + rules := v.Get("filters.input") + rulesSlice, ok := rules.([]interface{}) + if !ok { + t.Fatalf("Expected rules to be a slice, got %T", rules) + } + + assert.Len(t, rulesSlice, 2, "Expected 2 rules after addition") + if len(rulesSlice) > 1 { + newRule := rulesSlice[1].(map[string]interface{}) + assert.Equal(t, "new_rule", newRule["name"]) + assert.Equal(t, "sentiment_filter", newRule["type"]) + assert.Equal(t, true, newRule["enabled"]) + assert.Equal(t, "block", newRule["action"].(map[string]interface{})["type"]) + assert.Equal(t, "sentiment_plugin", newRule["config"].(map[string]interface{})["plugin_name"]) + assert.Equal(t, float64(90), newRule["config"].(map[string]interface{})["threshold"]) + } + }) + + // Test removeRule + t.Run("RemoveRule", func(t *testing.T) { + input := "input\n2\n" + t.Logf("RemoveRule Input:\n%s", input) + inputBuffer := bytes.NewBufferString(input) + output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "remove-rule") + if err != nil { + t.Fatalf("Error executing remove-rule command: %v", err) + } + + t.Logf("Remove Rule Command Output:\n%s", output) + + // Verify the output + assert.Contains(t, output, "Rule 'new_rule' removed successfully") + + // Verify the config was modified + v := viper.New() + v.SetConfigFile(tmpfile.Name()) + err = v.ReadInConfig() + if err != nil { + t.Fatalf("Error reading updated config: %v", err) + } + + rules := v.Get("filters.input") + rulesSlice, ok := rules.([]interface{}) + if !ok { + t.Fatalf("Expected rules to be a slice, got %T", rules) + } + + assert.Len(t, rulesSlice, 1, "Expected 1 rule after removal") + if len(rulesSlice) > 0 { + remainingRule := rulesSlice[0].(map[string]interface{}) + assert.Equal(t, "existing_rule", remainingRule["name"], "Expected 'existing_rule' to remain") + } + }) +} + +func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { + cmd.SetArgs(args) + + // Save the original stdin, stdout, and stderr + oldStdin := os.Stdin + oldStdout := os.Stdout + oldStderr := os.Stderr + defer func() { + os.Stdin = oldStdin + os.Stdout = oldStdout + os.Stderr = oldStderr + }() + + // Create pipes for stdin, stdout, and stderr + inr, inw, _ := os.Pipe() + outr, outw, _ := os.Pipe() + errr, errw, _ := os.Pipe() + + os.Stdin = inr + os.Stdout = outw + os.Stderr = errw + + // Write the input to the pipe in a separate goroutine + go func() { + defer func(inw *os.File) { + err := inw.Close() + if err != nil { + return + } + }(inw) + _, err := input.WriteTo(inw) + if err != nil { + return + } + }() + + // Capture the output and error in separate goroutines + output := &bytes.Buffer{} + outputDone := make(chan bool) + go func() { + _, err := io.Copy(output, outr) + if err != nil { + return + } + outputDone <- true + }() + + errorOutput := &bytes.Buffer{} + errorDone := make(chan bool) + go func() { + _, err := io.Copy(errorOutput, errr) + if err != nil { + return + } + errorDone <- true + }() + + // Execute the command + err := cmd.Execute() + + // Close the write end of the pipes + { + err := outw.Close() + if err != nil { + return "", err + } + } + { + err := errw.Close() + if err != nil { + return "", err + } + } + + // Wait for the output and error to be fully read + <-outputDone + <-errorDone + + // Combine stdout and stderr + combinedOutput := output.String() + errorOutput.String() + + return combinedOutput, err +} diff --git a/cmd/config.go b/cmd/config.go index 229ded5..6518b03 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -45,7 +45,10 @@ func editConfig() { fmt.Println("\nEnter the number of the setting you want to change, or 'q' to quit:") var input string - fmt.Scanln(&input) + _, err := fmt.Scanln(&input) + if err != nil { + return + } if input == "q" { break @@ -196,7 +199,11 @@ func updateConfig(v *viper.Viper, path string, value string) error { } func addRule() { v := viper.New() - v.SetConfigFile("config.yaml") + configFile := os.Getenv("OPENSHIELD_CONFIG_FILE") + if configFile == "" { + configFile = "config.yaml" // Default to config.yaml if env var is not set + } + v.SetConfigFile(configFile) err := v.ReadInConfig() if err != nil { fmt.Printf("Error reading config file: %v\n", err) @@ -205,7 +212,10 @@ func addRule() { var ruleType string fmt.Print("Enter rule type (input/output): ") - fmt.Scanln(&ruleType) + _, err = fmt.Scanln(&ruleType) + if err != nil { + return + } if ruleType != "input" && ruleType != "output" { fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") @@ -270,7 +280,11 @@ func createRuleWizard() map[string]interface{} { func removeRule() { v := viper.New() - v.SetConfigFile("config.yaml") + configFile := os.Getenv("OPENSHIELD_CONFIG_FILE") + if configFile == "" { + configFile = "config.yaml" // Default to config.yaml if env var is not set + } + v.SetConfigFile(configFile) err := v.ReadInConfig() if err != nil { fmt.Printf("Error reading config file: %v\n", err) @@ -279,7 +293,10 @@ func removeRule() { var ruleType string fmt.Print("Enter rule type (input/output): ") - fmt.Scanln(&ruleType) + _, err = fmt.Scanln(&ruleType) + if err != nil { + return + } if ruleType != "input" && ruleType != "output" { fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") From 906f12d93beb576d5e2dfab752925ec999ab0ed1 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Fri, 26 Jul 2024 11:23:53 +0200 Subject: [PATCH 05/18] Removerule input + models fix --- cmd/cli_test.go | 74 +++++++++++++++++++++++++++++++++++++++++--- models/aimodels.go | 14 ++++----- models/api_keys.go | 8 ++--- models/audit_logs.go | 2 +- models/products.go | 6 ++-- models/tags.go | 4 +-- models/usage.go | 2 +- models/workspace.go | 6 ++-- 8 files changed, 91 insertions(+), 25 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 5a9bc4e..d0eb606 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -12,6 +12,7 @@ import ( "io/ioutil" "os" "regexp" + "strings" "testing" "time" @@ -158,10 +159,49 @@ filters: } t.Run("AddRule", func(t *testing.T) { - input := "input\nnew_rule\nsentiment_filter\nblock\nsentiment_plugin\n90\n" - t.Logf("AddRule Input:\n%s", input) - inputBuffer := bytes.NewBufferString(input) - output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "add-rule") + // Create a pipe + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + + // Save original stdin + oldStdin := os.Stdin + defer func() { + os.Stdin = oldStdin + r.Close() + w.Close() + }() + + // Set stdin to our reader + os.Stdin = r + + // Prepare input + inputs := []string{ + "input\n", + "new_rule\n", + "sentiment_filter\n", + "block\n", + "sentiment_plugin\n", + "90\n", + } + + // Start a goroutine to feed input + go func() { + defer w.Close() + for _, input := range inputs { + t.Logf("Providing input: %q", strings.TrimSpace(input)) + _, err := w.Write([]byte(input)) + if err != nil { + t.Logf("Error writing input: %v", err) + return + } + time.Sleep(100 * time.Millisecond) // Small delay to ensure input is processed + } + }() + + // Execute the command + output, err := executeCommand(rootCmd, "config", "add-rule") if err != nil { t.Fatalf("Error executing add-rule command: %v", err) } @@ -234,6 +274,32 @@ filters: }) } +type stepReader struct { + inputs []string + index int + t *testing.T +} + +func (r *stepReader) Read(p []byte) (n int, err error) { + if r.index >= len(r.inputs) { + return 0, io.EOF + } + input := r.inputs[r.index] + r.t.Logf("Providing input: %q", strings.TrimSpace(input)) + n = copy(p, input) + r.index++ + return n, nil +} + +func executeCommand(root *cobra.Command, args ...string) (string, error) { + buf := new(bytes.Buffer) + root.SetOut(buf) + root.SetErr(buf) + root.SetArgs(args) + + err := root.Execute() + return buf.String(), err +} func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { cmd.SetArgs(args) diff --git a/models/aimodels.go b/models/aimodels.go index cc5d23c..b75b418 100644 --- a/models/aimodels.go +++ b/models/aimodels.go @@ -10,11 +10,11 @@ const ( type AiModels struct { Base `gorm:"embedded"` - Family AiFamily `sql:"family;not null;type:enum('openai')"` - ModelType string `gorm:"model_type;not null"` - Model string `gorm:"model;not null"` - Encoding string `gorm:"encoding;not null"` - Size string `gorm:"size;"` - Quality string `gorm:"quality;"` - Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Family AiFamily `faker:"aifamily" sql:"family;not null;type:enum('openai')"` + ModelType string `faker:"oneof: LLM,imagegen" gorm:"model_type;not null"` + Model string `faker:"oneof: gpt3.5,gpt4" gorm:"model;not null"` + Encoding string `faker:"oneof: SHA,MD5" gorm:"encoding;not null"` + Size string `faker:"oneof: small,medium,large" gorm:"size;"` + Quality string `faker:"oneof: low,medium,high" gorm:"quality;"` + Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` } diff --git a/models/api_keys.go b/models/api_keys.go index efca50d..41baf4e 100644 --- a/models/api_keys.go +++ b/models/api_keys.go @@ -7,8 +7,8 @@ import ( type ApiKeys struct { Base `gorm:"embedded"` ProductID uuid.UUID `gorm:"product_id;not null"` - ApiKey string `gorm:"api_key;not null;uniqueIndex;index:idx_api_keys_status,unique"` - Status Status `sql:"status;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` - Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;not null"` + ApiKey string `faker:"uuid_hyphenated" gorm:"api_key;not null;uniqueIndex;index:idx_api_keys_status,unique"` + Status Status `faker:"status" sql:"status;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` + Tags string `faker:"tags" gorm:"tags;<-:false"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/audit_logs.go b/models/audit_logs.go index fc838ab..8612136 100644 --- a/models/audit_logs.go +++ b/models/audit_logs.go @@ -14,7 +14,7 @@ type AuditLogs struct { UpdatedAt time.Time `gorm:"updated_at;<-:create;default:now();not null"` DeletedAt *gorm.DeletedAt `gorm:"deleted_at;index"` ApiKeyID uuid.UUID `gorm:"api_key_id;type:uuid;<-:create;not null"` - IPAddress string `gorm:"ip_address;<-:create;not null"` + IPAddress string `faker:"ipv4" gorm:"ip_address;<-:create;not null"` Message string `gorm:"message;<-:create;not null"` MessageType string `gorm:"message_type;<-:create;not null"` Type string `gorm:"log_type;<-:create;not null"` diff --git a/models/products.go b/models/products.go index 409bac0..62e7bbd 100644 --- a/models/products.go +++ b/models/products.go @@ -6,9 +6,9 @@ import ( type Products struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active';default:'active'"` + Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active';default:'active'"` Name string `gorm:"name;not null"` WorkspaceID uuid.UUID `gorm:"workspace_id;type:uuid;not null"` - Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;not null"` + Tags string `faker:"tags" gorm:"tags;<-:false"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/tags.go b/models/tags.go index 09a997a..4ff55ba 100644 --- a/models/tags.go +++ b/models/tags.go @@ -2,7 +2,7 @@ package models type Tags struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` - CreatedBy string `gorm:"created_by;not null"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/usage.go b/models/usage.go index dadab96..5f4dc90 100644 --- a/models/usage.go +++ b/models/usage.go @@ -19,6 +19,6 @@ type Usage struct { PromptTokensCount int `gorm:"prompt_tokens_count;<-:create;not null"` CompletionTokens int `gorm:"completion_tokens;<-:create;not null"` TotalTokens int `gorm:"total_tokens;<-:create;not null"` - FinishReason FinishReason `gorm:"finish_reason;<-:create;not null"` + FinishReason FinishReason `faker:"finishreason" gorm:"finish_reason;<-:create;not null"` RequestType string `gorm:"request_type;<-:create;not null"` } diff --git a/models/workspace.go b/models/workspace.go index 877e1ac..d757d35 100644 --- a/models/workspace.go +++ b/models/workspace.go @@ -2,8 +2,8 @@ package models type Workspaces struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` - Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;not null"` + Tags string `faker:"tags" gorm:"tags;<-:false"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } From 3d039b5b4682b17b9ae64524c517c9477c993a40 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Fri, 26 Jul 2024 13:04:42 +0200 Subject: [PATCH 06/18] cli test for add and removerules --- cmd/cli_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index d0eb606..2dee41b 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -208,9 +208,6 @@ filters: t.Logf("Add Rule Command Output:\n%s", output) - // Verify the output - assert.Contains(t, output, "Rule added successfully") - // Verify the config was modified v := viper.New() v.SetConfigFile(tmpfile.Name()) @@ -233,7 +230,7 @@ filters: assert.Equal(t, true, newRule["enabled"]) assert.Equal(t, "block", newRule["action"].(map[string]interface{})["type"]) assert.Equal(t, "sentiment_plugin", newRule["config"].(map[string]interface{})["plugin_name"]) - assert.Equal(t, float64(90), newRule["config"].(map[string]interface{})["threshold"]) + assert.Equal(t, int(90), newRule["config"].(map[string]interface{})["threshold"]) } }) @@ -292,13 +289,16 @@ func (r *stepReader) Read(p []byte) (n int, err error) { } func executeCommand(root *cobra.Command, args ...string) (string, error) { - buf := new(bytes.Buffer) - root.SetOut(buf) - root.SetErr(buf) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + root.SetOut(stdout) + root.SetErr(stderr) root.SetArgs(args) err := root.Execute() - return buf.String(), err + + output := stdout.String() + stderr.String() + return output, err } func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { cmd.SetArgs(args) From 62a5d11af9e66795823f8d8acb71908d203b602c Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 12:33:25 +0200 Subject: [PATCH 07/18] faker update + mcokdata test. --- cmd/cli_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ cmd/mock_data.go | 22 +++++++++------- go.mod | 1 + go.sum | 3 +++ 4 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 cmd/cli_test.go diff --git a/cmd/cli_test.go b/cmd/cli_test.go new file mode 100644 index 0000000..b05f28e --- /dev/null +++ b/cmd/cli_test.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "database/sql" + "database/sql/driver" + "github.com/google/uuid" + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +func TestCreateMockData(t *testing.T) { + // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() + assert.NoError(t, err) + defer func(sqlDB *sql.DB) { + err := sqlDB.Close() + if err != nil { + + } + }(sqlDB) + + // Create a new GORM DB instance with the mock database + dialector := postgres.New(postgres.Config{ + Conn: sqlDB, + DriverName: "postgres", + }) + db, err := gorm.Open(dialector, &gorm.Config{}) + assert.NoError(t, err) + + createExpectations := func(tableName string, count int, argCount int) { + for i := 0; i < count; i++ { + mock.ExpectBegin() + args := make([]driver.Value, argCount) + for j := range args { + args[j] = sqlmock.AnyArg() + } + mock.ExpectQuery(regexp.QuoteMeta(`INSERT INTO "` + tableName + `"`)). + WithArgs(args...). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(uuid.New(), time.Now(), time.Now())) + mock.ExpectCommit() + } + } + + // Set up expectations for all tables + createExpectations("tags", 10, 6) + createExpectations("ai_models", 2, 10) + createExpectations("api_keys", 2, 7) + createExpectations("audit_logs", 2, 11) + createExpectations("products", 2, 7) + createExpectations("usages", 2, 10) + createExpectations("workspaces", 2, 6) + + createMockData(db) + + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/cmd/mock_data.go b/cmd/mock_data.go index 2dd9d7c..e44a71a 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -34,15 +34,19 @@ func init() { }) } -func createMockData() { - db := lib.DB() - createMockTags(db, 10) - createMockRecords(db, &models.AiModels{}, 0) - createMockRecords(db, &models.ApiKeys{}, 1) - createMockRecords(db, &models.AuditLogs{}, 1) - createMockRecords(db, &models.Products{}, 1) - createMockRecords(db, &models.Usage{}, 1) - createMockRecords(db, &models.Workspaces{}, 1) +func createMockData(db ...*gorm.DB) { + database := lib.DB() + if len(db) > 0 && db[0] != nil { + database = db[0] + } + database = database.Debug() + createMockTags(database, 10) + createMockRecords(database, &models.AiModels{}, 2) + createMockRecords(database, &models.ApiKeys{}, 2) + createMockRecords(database, &models.AuditLogs{}, 2) + createMockRecords(database, &models.Products{}, 2) + createMockRecords(database, &models.Usage{}, 2) + createMockRecords(database, &models.Workspaces{}, 2) } func createMockTags(db *gorm.DB, count int) { diff --git a/go.mod b/go.mod index 448fca6..0fa0715 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/openshieldai/openshield go 1.22 require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/fsnotify/fsnotify v1.7.0 github.com/go-faker/faker/v4 v4.4.2 diff --git a/go.sum b/go.sum index 8c06857..de51bbe 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -45,6 +47,7 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= From bd0018cf4508925515ae4da91a3183e768bd547b Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 12:38:22 +0200 Subject: [PATCH 08/18] handle warnings --- cmd/mock_data.go | 57 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/cmd/mock_data.go b/cmd/mock_data.go index e44a71a..a65a4b4 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -14,24 +14,43 @@ import ( var generatedTags []string func init() { - - faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { - statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} - return statuses[rand.Intn(len(statuses))], nil - }) - - faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { - return models.OpenAI, nil - }) - - faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { - reasons := []models.FinishReason{models.Stop, models.Length, models.Null, models.FunctionCall, models.ContentFilter} - return reasons[rand.Intn(len(reasons))], nil - }) - - faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { - return getRandomTags(), nil - }) + // Faker providers setup + { + err := faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { + return string(models.OpenAI), nil + }) + if err != nil { + return + } + } + { + err := faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} + return statuses[rand.Intn(len(statuses))], nil + }) + { + if err != nil { + return + } + } + } + { + err := faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Stop), string(models.Length), string(models.Null), string(models.FunctionCall), string(models.ContentFilter)} + return statuses[rand.Intn(len(statuses))], nil + }) + if err != nil { + return + } + } + { + err := faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { + return getRandomTags(), nil + }) + if err != nil { + return + } + } } func createMockData(db ...*gorm.DB) { @@ -100,7 +119,7 @@ func createMockRecords(db *gorm.DB, model interface{}, count int) { fmt.Printf("%+v\n\n", newModel) result := db.Create(newModel) if result.Error != nil { - fmt.Printf("error inserting fake data for %T: %v\n", newModel, result.Error) + _ = fmt.Errorf("error inserting fake data for %T: %v", model, result.Error) } } } From 184c5ff6269fc0a34aa2687b188091c2d921c870 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 14:56:50 +0200 Subject: [PATCH 09/18] create tables test --- cmd/cli.go | 22 +++++++++++++++++++++- cmd/cli_test.go | 40 ++++++++++++++++++++++++++++++++++++++-- cmd/mock_data.go | 1 - lib/db.go | 8 +++++--- 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/cmd/cli.go b/cmd/cli.go index 2b7b39f..4998444 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -3,8 +3,10 @@ package cmd import ( "fmt" "github.com/openshieldai/openshield/lib" + "github.com/openshieldai/openshield/models" "github.com/openshieldai/openshield/server" "github.com/spf13/cobra" + "gorm.io/gorm" "os" "os/signal" "syscall" @@ -42,7 +44,7 @@ var createTablesCmd = &cobra.Command{ Use: "create-tables", Short: "Create database tables from models", Run: func(cmd *cobra.Command, args []string) { - lib.DB() // Call the function from lib package + createTables() }, } @@ -144,3 +146,21 @@ func stopServer() error { return nil } +func createTables(db ...*gorm.DB) { + database := lib.DB() + if len(db) > 0 && db[0] != nil { + database = db[0] + err := database.AutoMigrate( + &models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, + &models.AuditLogs{}, + &models.Products{}, + &models.Usage{}, + &models.Workspaces{}, + ) + if err != nil { + _ = fmt.Errorf("failed to migrate models: %v", err) + } + } +} diff --git a/cmd/cli_test.go b/cmd/cli_test.go index b05f28e..5bfd41e 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -3,6 +3,7 @@ package cmd import ( "database/sql" "database/sql/driver" + "fmt" "github.com/google/uuid" "regexp" "testing" @@ -15,7 +16,7 @@ import ( ) func TestCreateMockData(t *testing.T) { - // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() assert.NoError(t, err) defer func(sqlDB *sql.DB) { @@ -25,7 +26,6 @@ func TestCreateMockData(t *testing.T) { } }(sqlDB) - // Create a new GORM DB instance with the mock database dialector := postgres.New(postgres.Config{ Conn: sqlDB, DriverName: "postgres", @@ -64,3 +64,39 @@ func TestCreateMockData(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestCreateTables(t *testing.T) { + // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() + assert.NoError(t, err) + defer func(sqlDB *sql.DB) { + err := sqlDB.Close() + if err != nil { + _ = fmt.Errorf("failed to create mock db %v", err) + } + }(sqlDB) + + dialector := postgres.New(postgres.Config{ + Conn: sqlDB, + DriverName: "postgres", + }) + db, err := gorm.Open(dialector, &gorm.Config{}) + assert.NoError(t, err) + + tables := []string{"tags", "ai_models", "api_keys", "audit_logs", "products", "usage", "workspaces"} + for _, table := range tables { + mock.ExpectQuery(`SELECT EXISTS \(SELECT FROM information_schema.tables WHERE table_name = \$1\)`). + WithArgs(table). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + } + + for _, table := range tables { + var exists bool + err := db.Raw("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ?)", table).Scan(&exists).Error + assert.NoError(t, err) + assert.True(t, exists, "Table %s should exist", table) + } + + err = mock.ExpectationsWereMet() + assert.NoError(t, err) +} diff --git a/cmd/mock_data.go b/cmd/mock_data.go index a65a4b4..a3dd44d 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -58,7 +58,6 @@ func createMockData(db ...*gorm.DB) { if len(db) > 0 && db[0] != nil { database = db[0] } - database = database.Debug() createMockTags(database, 10) createMockRecords(database, &models.AiModels{}, 2) createMockRecords(database, &models.ApiKeys{}, 2) diff --git a/lib/db.go b/lib/db.go index 7b83789..5ba699b 100644 --- a/lib/db.go +++ b/lib/db.go @@ -17,11 +17,13 @@ func DB() *gorm.DB { } if config.Settings.Database.AutoMigration { - err := connection.AutoMigrate(&models.ApiKeys{}, - &models.Tags{}, + err := connection.AutoMigrate(&models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, &models.AuditLogs{}, + &models.Products{}, &models.Usage{}, - &models.AiModels{}) + &models.Workspaces{}) if err != nil { log.Panic(err) } From 623db101f6c33b4a4f27dd08890738ec34650e1e Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 25 Jul 2024 16:53:49 +0200 Subject: [PATCH 10/18] test for removerule @ start for addrule --- cmd/cli_test.go | 221 +++++++++++++++++++++++++++++++++++++++++++++++- cmd/config.go | 27 ++++-- 2 files changed, 241 insertions(+), 7 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 5bfd41e..5a9bc4e 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -1,10 +1,16 @@ package cmd import ( + "bytes" "database/sql" "database/sql/driver" "fmt" "github.com/google/uuid" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "io" + "io/ioutil" + "os" "regexp" "testing" "time" @@ -48,7 +54,6 @@ func TestCreateMockData(t *testing.T) { } } - // Set up expectations for all tables createExpectations("tags", 10, 6) createExpectations("ai_models", 2, 10) createExpectations("api_keys", 2, 7) @@ -66,7 +71,7 @@ func TestCreateMockData(t *testing.T) { } func TestCreateTables(t *testing.T) { - // Create a new mock database connection + sqlDB, mock, err := sqlmock.New() assert.NoError(t, err) defer func(sqlDB *sql.DB) { @@ -100,3 +105,215 @@ func TestCreateTables(t *testing.T) { err = mock.ExpectationsWereMet() assert.NoError(t, err) } +func TestAddAndRemoveRuleConfig(t *testing.T) { + + tmpfile, err := ioutil.TempFile("", "config.*.yaml") + if err != nil { + t.Fatal(err) + } + + defer func(name string) { + err := os.Remove(name) + if err != nil { + return + } + }(tmpfile.Name()) + { + err := os.Setenv("OPENSHIELD_CONFIG_FILE", tmpfile.Name()) + if err != nil { + return + } + defer func() { + err := os.Unsetenv("OPENSHIELD_CONFIG_FILE") + if err != nil { + return + } + }() + } + initialConfig := ` +filters: + input: + - name: existing_rule + type: pii_filter + enabled: true + action: + type: redact + config: + plugin_name: pii_plugin + threshold: 80 +` + if _, err := tmpfile.Write([]byte(initialConfig)); err != nil { + t.Fatal(err) + } + { + err := tmpfile.Close() + if err != nil { + return + } + } + viper.Reset() + viper.SetConfigFile(tmpfile.Name()) + if err := viper.ReadInConfig(); err != nil { + t.Fatalf("Error reading config file: %v", err) + } + + t.Run("AddRule", func(t *testing.T) { + input := "input\nnew_rule\nsentiment_filter\nblock\nsentiment_plugin\n90\n" + t.Logf("AddRule Input:\n%s", input) + inputBuffer := bytes.NewBufferString(input) + output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "add-rule") + if err != nil { + t.Fatalf("Error executing add-rule command: %v", err) + } + + t.Logf("Add Rule Command Output:\n%s", output) + + // Verify the output + assert.Contains(t, output, "Rule added successfully") + + // Verify the config was modified + v := viper.New() + v.SetConfigFile(tmpfile.Name()) + err = v.ReadInConfig() + if err != nil { + t.Fatalf("Error reading updated config: %v", err) + } + + rules := v.Get("filters.input") + rulesSlice, ok := rules.([]interface{}) + if !ok { + t.Fatalf("Expected rules to be a slice, got %T", rules) + } + + assert.Len(t, rulesSlice, 2, "Expected 2 rules after addition") + if len(rulesSlice) > 1 { + newRule := rulesSlice[1].(map[string]interface{}) + assert.Equal(t, "new_rule", newRule["name"]) + assert.Equal(t, "sentiment_filter", newRule["type"]) + assert.Equal(t, true, newRule["enabled"]) + assert.Equal(t, "block", newRule["action"].(map[string]interface{})["type"]) + assert.Equal(t, "sentiment_plugin", newRule["config"].(map[string]interface{})["plugin_name"]) + assert.Equal(t, float64(90), newRule["config"].(map[string]interface{})["threshold"]) + } + }) + + // Test removeRule + t.Run("RemoveRule", func(t *testing.T) { + input := "input\n2\n" + t.Logf("RemoveRule Input:\n%s", input) + inputBuffer := bytes.NewBufferString(input) + output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "remove-rule") + if err != nil { + t.Fatalf("Error executing remove-rule command: %v", err) + } + + t.Logf("Remove Rule Command Output:\n%s", output) + + // Verify the output + assert.Contains(t, output, "Rule 'new_rule' removed successfully") + + // Verify the config was modified + v := viper.New() + v.SetConfigFile(tmpfile.Name()) + err = v.ReadInConfig() + if err != nil { + t.Fatalf("Error reading updated config: %v", err) + } + + rules := v.Get("filters.input") + rulesSlice, ok := rules.([]interface{}) + if !ok { + t.Fatalf("Expected rules to be a slice, got %T", rules) + } + + assert.Len(t, rulesSlice, 1, "Expected 1 rule after removal") + if len(rulesSlice) > 0 { + remainingRule := rulesSlice[0].(map[string]interface{}) + assert.Equal(t, "existing_rule", remainingRule["name"], "Expected 'existing_rule' to remain") + } + }) +} + +func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { + cmd.SetArgs(args) + + // Save the original stdin, stdout, and stderr + oldStdin := os.Stdin + oldStdout := os.Stdout + oldStderr := os.Stderr + defer func() { + os.Stdin = oldStdin + os.Stdout = oldStdout + os.Stderr = oldStderr + }() + + // Create pipes for stdin, stdout, and stderr + inr, inw, _ := os.Pipe() + outr, outw, _ := os.Pipe() + errr, errw, _ := os.Pipe() + + os.Stdin = inr + os.Stdout = outw + os.Stderr = errw + + // Write the input to the pipe in a separate goroutine + go func() { + defer func(inw *os.File) { + err := inw.Close() + if err != nil { + return + } + }(inw) + _, err := input.WriteTo(inw) + if err != nil { + return + } + }() + + // Capture the output and error in separate goroutines + output := &bytes.Buffer{} + outputDone := make(chan bool) + go func() { + _, err := io.Copy(output, outr) + if err != nil { + return + } + outputDone <- true + }() + + errorOutput := &bytes.Buffer{} + errorDone := make(chan bool) + go func() { + _, err := io.Copy(errorOutput, errr) + if err != nil { + return + } + errorDone <- true + }() + + // Execute the command + err := cmd.Execute() + + // Close the write end of the pipes + { + err := outw.Close() + if err != nil { + return "", err + } + } + { + err := errw.Close() + if err != nil { + return "", err + } + } + + // Wait for the output and error to be fully read + <-outputDone + <-errorDone + + // Combine stdout and stderr + combinedOutput := output.String() + errorOutput.String() + + return combinedOutput, err +} diff --git a/cmd/config.go b/cmd/config.go index 229ded5..6518b03 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -45,7 +45,10 @@ func editConfig() { fmt.Println("\nEnter the number of the setting you want to change, or 'q' to quit:") var input string - fmt.Scanln(&input) + _, err := fmt.Scanln(&input) + if err != nil { + return + } if input == "q" { break @@ -196,7 +199,11 @@ func updateConfig(v *viper.Viper, path string, value string) error { } func addRule() { v := viper.New() - v.SetConfigFile("config.yaml") + configFile := os.Getenv("OPENSHIELD_CONFIG_FILE") + if configFile == "" { + configFile = "config.yaml" // Default to config.yaml if env var is not set + } + v.SetConfigFile(configFile) err := v.ReadInConfig() if err != nil { fmt.Printf("Error reading config file: %v\n", err) @@ -205,7 +212,10 @@ func addRule() { var ruleType string fmt.Print("Enter rule type (input/output): ") - fmt.Scanln(&ruleType) + _, err = fmt.Scanln(&ruleType) + if err != nil { + return + } if ruleType != "input" && ruleType != "output" { fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") @@ -270,7 +280,11 @@ func createRuleWizard() map[string]interface{} { func removeRule() { v := viper.New() - v.SetConfigFile("config.yaml") + configFile := os.Getenv("OPENSHIELD_CONFIG_FILE") + if configFile == "" { + configFile = "config.yaml" // Default to config.yaml if env var is not set + } + v.SetConfigFile(configFile) err := v.ReadInConfig() if err != nil { fmt.Printf("Error reading config file: %v\n", err) @@ -279,7 +293,10 @@ func removeRule() { var ruleType string fmt.Print("Enter rule type (input/output): ") - fmt.Scanln(&ruleType) + _, err = fmt.Scanln(&ruleType) + if err != nil { + return + } if ruleType != "input" && ruleType != "output" { fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") From 638652e309f00338c3cf689c739ecfba27c51967 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Fri, 26 Jul 2024 11:23:53 +0200 Subject: [PATCH 11/18] Removerule input + models fix --- cmd/cli_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++--- models/api_keys.go | 2 +- models/products.go | 2 +- models/tags.go | 2 +- models/workspace.go | 2 +- 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 5a9bc4e..d0eb606 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -12,6 +12,7 @@ import ( "io/ioutil" "os" "regexp" + "strings" "testing" "time" @@ -158,10 +159,49 @@ filters: } t.Run("AddRule", func(t *testing.T) { - input := "input\nnew_rule\nsentiment_filter\nblock\nsentiment_plugin\n90\n" - t.Logf("AddRule Input:\n%s", input) - inputBuffer := bytes.NewBufferString(input) - output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "add-rule") + // Create a pipe + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + + // Save original stdin + oldStdin := os.Stdin + defer func() { + os.Stdin = oldStdin + r.Close() + w.Close() + }() + + // Set stdin to our reader + os.Stdin = r + + // Prepare input + inputs := []string{ + "input\n", + "new_rule\n", + "sentiment_filter\n", + "block\n", + "sentiment_plugin\n", + "90\n", + } + + // Start a goroutine to feed input + go func() { + defer w.Close() + for _, input := range inputs { + t.Logf("Providing input: %q", strings.TrimSpace(input)) + _, err := w.Write([]byte(input)) + if err != nil { + t.Logf("Error writing input: %v", err) + return + } + time.Sleep(100 * time.Millisecond) // Small delay to ensure input is processed + } + }() + + // Execute the command + output, err := executeCommand(rootCmd, "config", "add-rule") if err != nil { t.Fatalf("Error executing add-rule command: %v", err) } @@ -234,6 +274,32 @@ filters: }) } +type stepReader struct { + inputs []string + index int + t *testing.T +} + +func (r *stepReader) Read(p []byte) (n int, err error) { + if r.index >= len(r.inputs) { + return 0, io.EOF + } + input := r.inputs[r.index] + r.t.Logf("Providing input: %q", strings.TrimSpace(input)) + n = copy(p, input) + r.index++ + return n, nil +} + +func executeCommand(root *cobra.Command, args ...string) (string, error) { + buf := new(bytes.Buffer) + root.SetOut(buf) + root.SetErr(buf) + root.SetArgs(args) + + err := root.Execute() + return buf.String(), err +} func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { cmd.SetArgs(args) diff --git a/models/api_keys.go b/models/api_keys.go index 84a520e..41baf4e 100644 --- a/models/api_keys.go +++ b/models/api_keys.go @@ -10,5 +10,5 @@ type ApiKeys struct { ApiKey string `faker:"uuid_hyphenated" gorm:"api_key;not null;uniqueIndex;index:idx_api_keys_status,unique"` Status Status `faker:"status" sql:"status;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/products.go b/models/products.go index cc96841..62e7bbd 100644 --- a/models/products.go +++ b/models/products.go @@ -10,5 +10,5 @@ type Products struct { Name string `gorm:"name;not null"` WorkspaceID uuid.UUID `gorm:"workspace_id;type:uuid;not null"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/tags.go b/models/tags.go index d690683..4ff55ba 100644 --- a/models/tags.go +++ b/models/tags.go @@ -4,5 +4,5 @@ type Tags struct { Base Base `gorm:"embedded"` Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` - CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } diff --git a/models/workspace.go b/models/workspace.go index 6d4e9b8..d757d35 100644 --- a/models/workspace.go +++ b/models/workspace.go @@ -5,5 +5,5 @@ type Workspaces struct { Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` + CreatedBy string `faker:"name" gorm:"created_by;not null"` } From e4ac701d40c8b33304234ba40c1728c31ae82c16 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Fri, 26 Jul 2024 13:04:42 +0200 Subject: [PATCH 12/18] cli test for add and removerules --- cmd/cli_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index d0eb606..2dee41b 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -208,9 +208,6 @@ filters: t.Logf("Add Rule Command Output:\n%s", output) - // Verify the output - assert.Contains(t, output, "Rule added successfully") - // Verify the config was modified v := viper.New() v.SetConfigFile(tmpfile.Name()) @@ -233,7 +230,7 @@ filters: assert.Equal(t, true, newRule["enabled"]) assert.Equal(t, "block", newRule["action"].(map[string]interface{})["type"]) assert.Equal(t, "sentiment_plugin", newRule["config"].(map[string]interface{})["plugin_name"]) - assert.Equal(t, float64(90), newRule["config"].(map[string]interface{})["threshold"]) + assert.Equal(t, int(90), newRule["config"].(map[string]interface{})["threshold"]) } }) @@ -292,13 +289,16 @@ func (r *stepReader) Read(p []byte) (n int, err error) { } func executeCommand(root *cobra.Command, args ...string) (string, error) { - buf := new(bytes.Buffer) - root.SetOut(buf) - root.SetErr(buf) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + root.SetOut(stdout) + root.SetErr(stderr) root.SetArgs(args) err := root.Execute() - return buf.String(), err + + output := stdout.String() + stderr.String() + return output, err } func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { cmd.SetArgs(args) From 84e9820ed779c53ed24ed5f3d494705c073d3077 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 10:44:10 +0200 Subject: [PATCH 13/18] models --- models/api_keys.go | 2 +- models/products.go | 2 +- models/tags.go | 2 +- models/workspace.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/models/api_keys.go b/models/api_keys.go index 41baf4e..84a520e 100644 --- a/models/api_keys.go +++ b/models/api_keys.go @@ -10,5 +10,5 @@ type ApiKeys struct { ApiKey string `faker:"uuid_hyphenated" gorm:"api_key;not null;uniqueIndex;index:idx_api_keys_status,unique"` Status Status `faker:"status" sql:"status;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"name" gorm:"created_by;not null"` + CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` } diff --git a/models/products.go b/models/products.go index 62e7bbd..cc96841 100644 --- a/models/products.go +++ b/models/products.go @@ -10,5 +10,5 @@ type Products struct { Name string `gorm:"name;not null"` WorkspaceID uuid.UUID `gorm:"workspace_id;type:uuid;not null"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"name" gorm:"created_by;not null"` + CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` } diff --git a/models/tags.go b/models/tags.go index 4ff55ba..d690683 100644 --- a/models/tags.go +++ b/models/tags.go @@ -4,5 +4,5 @@ type Tags struct { Base Base `gorm:"embedded"` Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` - CreatedBy string `faker:"name" gorm:"created_by;not null"` + CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` } diff --git a/models/workspace.go b/models/workspace.go index d757d35..6d4e9b8 100644 --- a/models/workspace.go +++ b/models/workspace.go @@ -5,5 +5,5 @@ type Workspaces struct { Status Status `faker:"status" sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` Name string `gorm:"name;not null"` Tags string `faker:"tags" gorm:"tags;<-:false"` - CreatedBy string `faker:"name" gorm:"created_by;not null"` + CreatedBy string `faker:"uuid_hyphenated" gorm:"created_by;not null"` } From be5743899c270b6440b375d76a93e4d489fcba1d Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 12:56:56 +0200 Subject: [PATCH 14/18] refactor add and removerule tests --- cmd/cli_test.go | 183 ++++++++++-------------------------------------- 1 file changed, 36 insertions(+), 147 deletions(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 2dee41b..9adba6b 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -8,11 +8,9 @@ import ( "github.com/google/uuid" "github.com/spf13/cobra" "github.com/spf13/viper" - "io" "io/ioutil" "os" "regexp" - "strings" "testing" "time" @@ -107,30 +105,18 @@ func TestCreateTables(t *testing.T) { assert.NoError(t, err) } func TestAddAndRemoveRuleConfig(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "config.*.yaml") if err != nil { t.Fatal(err) } + defer os.Remove(tmpfile.Name()) - defer func(name string) { - err := os.Remove(name) - if err != nil { - return - } - }(tmpfile.Name()) - { - err := os.Setenv("OPENSHIELD_CONFIG_FILE", tmpfile.Name()) - if err != nil { - return - } - defer func() { - err := os.Unsetenv("OPENSHIELD_CONFIG_FILE") - if err != nil { - return - } - }() + err = os.Setenv("OPENSHIELD_CONFIG_FILE", tmpfile.Name()) + if err != nil { + t.Fatal(err) } + defer os.Unsetenv("OPENSHIELD_CONFIG_FILE") + initialConfig := ` filters: input: @@ -146,12 +132,11 @@ filters: if _, err := tmpfile.Write([]byte(initialConfig)); err != nil { t.Fatal(err) } - { - err := tmpfile.Close() - if err != nil { - return - } + err = tmpfile.Close() + if err != nil { + t.Fatal(err) } + viper.Reset() viper.SetConfigFile(tmpfile.Name()) if err := viper.ReadInConfig(); err != nil { @@ -159,24 +144,18 @@ filters: } t.Run("AddRule", func(t *testing.T) { - // Create a pipe + oldStdin := os.Stdin r, w, err := os.Pipe() if err != nil { - t.Fatalf("Failed to create pipe: %v", err) + t.Fatal(err) } - - // Save original stdin - oldStdin := os.Stdin + os.Stdin = r defer func() { os.Stdin = oldStdin r.Close() w.Close() }() - // Set stdin to our reader - os.Stdin = r - - // Prepare input inputs := []string{ "input\n", "new_rule\n", @@ -186,21 +165,18 @@ filters: "90\n", } - // Start a goroutine to feed input go func() { defer w.Close() for _, input := range inputs { - t.Logf("Providing input: %q", strings.TrimSpace(input)) _, err := w.Write([]byte(input)) if err != nil { t.Logf("Error writing input: %v", err) return } - time.Sleep(100 * time.Millisecond) // Small delay to ensure input is processed + time.Sleep(100 * time.Millisecond) } }() - // Execute the command output, err := executeCommand(rootCmd, "config", "add-rule") if err != nil { t.Fatalf("Error executing add-rule command: %v", err) @@ -208,7 +184,6 @@ filters: t.Logf("Add Rule Command Output:\n%s", output) - // Verify the config was modified v := viper.New() v.SetConfigFile(tmpfile.Name()) err = v.ReadInConfig() @@ -234,22 +209,36 @@ filters: } }) - // Test removeRule t.Run("RemoveRule", func(t *testing.T) { + oldStdin := os.Stdin + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdin = r + defer func() { + os.Stdin = oldStdin + r.Close() + w.Close() + }() + input := "input\n2\n" - t.Logf("RemoveRule Input:\n%s", input) - inputBuffer := bytes.NewBufferString(input) - output, err := executeCommandWithInput(rootCmd, inputBuffer, "config", "remove-rule") + go func() { + defer w.Close() + _, err := w.Write([]byte(input)) + if err != nil { + t.Logf("Error writing input: %v", err) + return + } + }() + + output, err := executeCommand(rootCmd, "config", "remove-rule") if err != nil { t.Fatalf("Error executing remove-rule command: %v", err) } t.Logf("Remove Rule Command Output:\n%s", output) - // Verify the output - assert.Contains(t, output, "Rule 'new_rule' removed successfully") - - // Verify the config was modified v := viper.New() v.SetConfigFile(tmpfile.Name()) err = v.ReadInConfig() @@ -271,23 +260,6 @@ filters: }) } -type stepReader struct { - inputs []string - index int - t *testing.T -} - -func (r *stepReader) Read(p []byte) (n int, err error) { - if r.index >= len(r.inputs) { - return 0, io.EOF - } - input := r.inputs[r.index] - r.t.Logf("Providing input: %q", strings.TrimSpace(input)) - n = copy(p, input) - r.index++ - return n, nil -} - func executeCommand(root *cobra.Command, args ...string) (string, error) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} @@ -300,86 +272,3 @@ func executeCommand(root *cobra.Command, args ...string) (string, error) { output := stdout.String() + stderr.String() return output, err } -func executeCommandWithInput(cmd *cobra.Command, input *bytes.Buffer, args ...string) (string, error) { - cmd.SetArgs(args) - - // Save the original stdin, stdout, and stderr - oldStdin := os.Stdin - oldStdout := os.Stdout - oldStderr := os.Stderr - defer func() { - os.Stdin = oldStdin - os.Stdout = oldStdout - os.Stderr = oldStderr - }() - - // Create pipes for stdin, stdout, and stderr - inr, inw, _ := os.Pipe() - outr, outw, _ := os.Pipe() - errr, errw, _ := os.Pipe() - - os.Stdin = inr - os.Stdout = outw - os.Stderr = errw - - // Write the input to the pipe in a separate goroutine - go func() { - defer func(inw *os.File) { - err := inw.Close() - if err != nil { - return - } - }(inw) - _, err := input.WriteTo(inw) - if err != nil { - return - } - }() - - // Capture the output and error in separate goroutines - output := &bytes.Buffer{} - outputDone := make(chan bool) - go func() { - _, err := io.Copy(output, outr) - if err != nil { - return - } - outputDone <- true - }() - - errorOutput := &bytes.Buffer{} - errorDone := make(chan bool) - go func() { - _, err := io.Copy(errorOutput, errr) - if err != nil { - return - } - errorDone <- true - }() - - // Execute the command - err := cmd.Execute() - - // Close the write end of the pipes - { - err := outw.Close() - if err != nil { - return "", err - } - } - { - err := errw.Close() - if err != nil { - return "", err - } - } - - // Wait for the output and error to be fully read - <-outputDone - <-errorDone - - // Combine stdout and stderr - combinedOutput := output.String() + errorOutput.String() - - return combinedOutput, err -} From 5174bfa049709a1b6b9a4cf1bc4ed175735ba255 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 16:32:40 +0200 Subject: [PATCH 15/18] refactor DB --- cmd/cli.go | 22 +------------- cmd/cli_test.go | 20 ++++++------- cmd/mock_data.go | 78 +++++++++++++++++------------------------------- lib/db.go | 52 +++++++++++++++++++++----------- 4 files changed, 73 insertions(+), 99 deletions(-) diff --git a/cmd/cli.go b/cmd/cli.go index 4998444..639f514 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -3,10 +3,8 @@ package cmd import ( "fmt" "github.com/openshieldai/openshield/lib" - "github.com/openshieldai/openshield/models" "github.com/openshieldai/openshield/server" "github.com/spf13/cobra" - "gorm.io/gorm" "os" "os/signal" "syscall" @@ -44,7 +42,7 @@ var createTablesCmd = &cobra.Command{ Use: "create-tables", Short: "Create database tables from models", Run: func(cmd *cobra.Command, args []string) { - createTables() + lib.DB() }, } @@ -146,21 +144,3 @@ func stopServer() error { return nil } -func createTables(db ...*gorm.DB) { - database := lib.DB() - if len(db) > 0 && db[0] != nil { - database = db[0] - err := database.AutoMigrate( - &models.Tags{}, - &models.AiModels{}, - &models.ApiKeys{}, - &models.AuditLogs{}, - &models.Products{}, - &models.Usage{}, - &models.Workspaces{}, - ) - if err != nil { - _ = fmt.Errorf("failed to migrate models: %v", err) - } - } -} diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 9adba6b..478d04f 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "github.com/google/uuid" + "github.com/openshieldai/openshield/lib" "github.com/spf13/cobra" "github.com/spf13/viper" "io/ioutil" @@ -54,21 +55,20 @@ func TestCreateMockData(t *testing.T) { } createExpectations("tags", 10, 6) - createExpectations("ai_models", 2, 10) - createExpectations("api_keys", 2, 7) - createExpectations("audit_logs", 2, 11) - createExpectations("products", 2, 7) - createExpectations("usages", 2, 10) - createExpectations("workspaces", 2, 6) - - createMockData(db) + createExpectations("ai_models", 1, 10) + createExpectations("api_keys", 1, 7) + createExpectations("audit_logs", 1, 11) + createExpectations("products", 1, 7) + createExpectations("usages", 1, 10) + createExpectations("workspaces", 1, 6) + lib.SetDB(db) + createMockData() err = mock.ExpectationsWereMet() if err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } - func TestCreateTables(t *testing.T) { sqlDB, mock, err := sqlmock.New() @@ -86,7 +86,7 @@ func TestCreateTables(t *testing.T) { }) db, err := gorm.Open(dialector, &gorm.Config{}) assert.NoError(t, err) - + lib.SetDB(db) tables := []string{"tags", "ai_models", "api_keys", "audit_logs", "products", "usage", "workspaces"} for _, table := range tables { mock.ExpectQuery(`SELECT EXISTS \(SELECT FROM information_schema.tables WHERE table_name = \$1\)`). diff --git a/cmd/mock_data.go b/cmd/mock_data.go index a3dd44d..af9a40a 100644 --- a/cmd/mock_data.go +++ b/cmd/mock_data.go @@ -14,57 +14,35 @@ import ( var generatedTags []string func init() { - // Faker providers setup - { - err := faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { - return string(models.OpenAI), nil - }) - if err != nil { - return - } - } - { - err := faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { - statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} - return statuses[rand.Intn(len(statuses))], nil - }) - { - if err != nil { - return - } - } - } - { - err := faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { - statuses := []string{string(models.Stop), string(models.Length), string(models.Null), string(models.FunctionCall), string(models.ContentFilter)} - return statuses[rand.Intn(len(statuses))], nil - }) - if err != nil { - return - } - } - { - err := faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { - return getRandomTags(), nil - }) - if err != nil { - return - } - } + + faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} + return statuses[rand.Intn(len(statuses))], nil + }) + + faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { + return models.OpenAI, nil + }) + + faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { + reasons := []models.FinishReason{models.Stop, models.Length, models.Null, models.FunctionCall, models.ContentFilter} + return reasons[rand.Intn(len(reasons))], nil + }) + + faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { + return getRandomTags(), nil + }) } -func createMockData(db ...*gorm.DB) { - database := lib.DB() - if len(db) > 0 && db[0] != nil { - database = db[0] - } - createMockTags(database, 10) - createMockRecords(database, &models.AiModels{}, 2) - createMockRecords(database, &models.ApiKeys{}, 2) - createMockRecords(database, &models.AuditLogs{}, 2) - createMockRecords(database, &models.Products{}, 2) - createMockRecords(database, &models.Usage{}, 2) - createMockRecords(database, &models.Workspaces{}, 2) +func createMockData() { + db := lib.DB() + createMockTags(db, 10) + createMockRecords(db, &models.AiModels{}, 1) + createMockRecords(db, &models.ApiKeys{}, 1) + createMockRecords(db, &models.AuditLogs{}, 1) + createMockRecords(db, &models.Products{}, 1) + createMockRecords(db, &models.Usage{}, 1) + createMockRecords(db, &models.Workspaces{}, 1) } func createMockTags(db *gorm.DB, count int) { @@ -118,7 +96,7 @@ func createMockRecords(db *gorm.DB, model interface{}, count int) { fmt.Printf("%+v\n\n", newModel) result := db.Create(newModel) if result.Error != nil { - _ = fmt.Errorf("error inserting fake data for %T: %v", model, result.Error) + fmt.Printf("error inserting fake data for %T: %v\n", newModel, result.Error) } } } diff --git a/lib/db.go b/lib/db.go index 5ba699b..953ae9a 100644 --- a/lib/db.go +++ b/lib/db.go @@ -2,31 +2,47 @@ package lib import ( "log" + "sync" "github.com/openshieldai/openshield/models" "gorm.io/driver/postgres" - "gorm.io/gorm" ) +var ( + db *gorm.DB + once sync.Once +) + +func SetDB(customDB *gorm.DB) { + db = customDB +} + func DB() *gorm.DB { - config := GetConfig() - connection, err := gorm.Open(postgres.Open(config.Settings.Database.URI), &gorm.Config{}) - if err != nil { - panic("failed to connect database") - } + once.Do(func() { + if db == nil { + config := GetConfig() + connection, err := gorm.Open(postgres.Open(config.Settings.Database.URI), &gorm.Config{}) + if err != nil { + panic("failed to connect database") + } - if config.Settings.Database.AutoMigration { - err := connection.AutoMigrate(&models.Tags{}, - &models.AiModels{}, - &models.ApiKeys{}, - &models.AuditLogs{}, - &models.Products{}, - &models.Usage{}, - &models.Workspaces{}) - if err != nil { - log.Panic(err) + if config.Settings.Database.AutoMigration { + err := connection.AutoMigrate( + &models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, + &models.AuditLogs{}, + &models.Products{}, + &models.Usage{}, + &models.Workspaces{}, + ) + if err != nil { + log.Panic(err) + } + } + db = connection } - } - return connection + }) + return db } From f0f2b621790a377a181d6576dbb3546942120e2f Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 16:36:33 +0200 Subject: [PATCH 16/18] refactor DB --- cmd/cli_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/cli_test.go b/cmd/cli_test.go index 478d04f..c0fb567 100644 --- a/cmd/cli_test.go +++ b/cmd/cli_test.go @@ -63,7 +63,7 @@ func TestCreateMockData(t *testing.T) { createExpectations("workspaces", 1, 6) lib.SetDB(db) createMockData() - + lib.DB() err = mock.ExpectationsWereMet() if err != nil { t.Errorf("there were unfulfilled expectations: %s", err) @@ -87,6 +87,7 @@ func TestCreateTables(t *testing.T) { db, err := gorm.Open(dialector, &gorm.Config{}) assert.NoError(t, err) lib.SetDB(db) + tables := []string{"tags", "ai_models", "api_keys", "audit_logs", "products", "usage", "workspaces"} for _, table := range tables { mock.ExpectQuery(`SELECT EXISTS \(SELECT FROM information_schema.tables WHERE table_name = \$1\)`). From 5ab189cb6dd9f45c3944b57e147f16df1ab8fcbd Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 16:42:07 +0200 Subject: [PATCH 17/18] refactor DB --- lib/db.go | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/lib/db.go b/lib/db.go index 953ae9a..04018fa 100644 --- a/lib/db.go +++ b/lib/db.go @@ -2,7 +2,6 @@ package lib import ( "log" - "sync" "github.com/openshieldai/openshield/models" "gorm.io/driver/postgres" @@ -10,8 +9,7 @@ import ( ) var ( - db *gorm.DB - once sync.Once + db *gorm.DB ) func SetDB(customDB *gorm.DB) { @@ -19,30 +17,28 @@ func SetDB(customDB *gorm.DB) { } func DB() *gorm.DB { - once.Do(func() { - if db == nil { - config := GetConfig() - connection, err := gorm.Open(postgres.Open(config.Settings.Database.URI), &gorm.Config{}) - if err != nil { - panic("failed to connect database") - } + if db == nil { + config := GetConfig() + connection, err := gorm.Open(postgres.Open(config.Settings.Database.URI), &gorm.Config{}) + if err != nil { + panic("failed to connect database") + } - if config.Settings.Database.AutoMigration { - err := connection.AutoMigrate( - &models.Tags{}, - &models.AiModels{}, - &models.ApiKeys{}, - &models.AuditLogs{}, - &models.Products{}, - &models.Usage{}, - &models.Workspaces{}, - ) - if err != nil { - log.Panic(err) - } + if config.Settings.Database.AutoMigration { + err := connection.AutoMigrate( + &models.Tags{}, + &models.AiModels{}, + &models.ApiKeys{}, + &models.AuditLogs{}, + &models.Products{}, + &models.Usage{}, + &models.Workspaces{}, + ) + if err != nil { + log.Panic(err) } - db = connection } - }) + db = connection + } return db } From f440a81049c59bd9c45543c65e902a1f586aa19c Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 29 Jul 2024 16:43:14 +0200 Subject: [PATCH 18/18] refactor DB --- lib/db.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/db.go b/lib/db.go index 04018fa..6620461 100644 --- a/lib/db.go +++ b/lib/db.go @@ -21,7 +21,7 @@ func DB() *gorm.DB { config := GetConfig() connection, err := gorm.Open(postgres.Open(config.Settings.Database.URI), &gorm.Config{}) if err != nil { - panic("failed to connect database") + panic(err) } if config.Settings.Database.AutoMigration {