Skip to content

Commit

Permalink
Cli test (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
pigri authored Jul 29, 2024
2 parents c36a667 + f440a81 commit 3902ab3
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cmd/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ var createTablesCmd = &cobra.Command{
Use: "create-tables",
Short: "Create database tables from models",
Run: func(cmd *cobra.Command, args []string) {
lib.DB() // Call the function from lib package
lib.DB()
},
}

Expand Down
275 changes: 275 additions & 0 deletions cmd/cli_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package cmd

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"github.com/google/uuid"
"github.com/openshieldai/openshield/lib"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"io/ioutil"
"os"
"regexp"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

func TestCreateMockData(t *testing.T) {

sqlDB, mock, err := sqlmock.New()
assert.NoError(t, err)
defer func(sqlDB *sql.DB) {
err := sqlDB.Close()
if err != nil {

}
}(sqlDB)

dialector := postgres.New(postgres.Config{
Conn: sqlDB,
DriverName: "postgres",
})
db, err := gorm.Open(dialector, &gorm.Config{})
assert.NoError(t, err)

createExpectations := func(tableName string, count int, argCount int) {
for i := 0; i < count; i++ {
mock.ExpectBegin()
args := make([]driver.Value, argCount)
for j := range args {
args[j] = sqlmock.AnyArg()
}
mock.ExpectQuery(regexp.QuoteMeta(`INSERT INTO "` + tableName + `"`)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).
AddRow(uuid.New(), time.Now(), time.Now()))
mock.ExpectCommit()
}
}

createExpectations("tags", 10, 6)
createExpectations("ai_models", 1, 10)
createExpectations("api_keys", 1, 7)
createExpectations("audit_logs", 1, 11)
createExpectations("products", 1, 7)
createExpectations("usages", 1, 10)
createExpectations("workspaces", 1, 6)
lib.SetDB(db)
createMockData()
lib.DB()
err = mock.ExpectationsWereMet()
if err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
func TestCreateTables(t *testing.T) {

sqlDB, mock, err := sqlmock.New()
assert.NoError(t, err)
defer func(sqlDB *sql.DB) {
err := sqlDB.Close()
if err != nil {
_ = fmt.Errorf("failed to create mock db %v", err)
}
}(sqlDB)

dialector := postgres.New(postgres.Config{
Conn: sqlDB,
DriverName: "postgres",
})
db, err := gorm.Open(dialector, &gorm.Config{})
assert.NoError(t, err)
lib.SetDB(db)

tables := []string{"tags", "ai_models", "api_keys", "audit_logs", "products", "usage", "workspaces"}
for _, table := range tables {
mock.ExpectQuery(`SELECT EXISTS \(SELECT FROM information_schema.tables WHERE table_name = \$1\)`).
WithArgs(table).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
}

for _, table := range tables {
var exists bool
err := db.Raw("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ?)", table).Scan(&exists).Error
assert.NoError(t, err)
assert.True(t, exists, "Table %s should exist", table)
}

err = mock.ExpectationsWereMet()
assert.NoError(t, err)
}
func TestAddAndRemoveRuleConfig(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "config.*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())

err = os.Setenv("OPENSHIELD_CONFIG_FILE", tmpfile.Name())
if err != nil {
t.Fatal(err)
}
defer os.Unsetenv("OPENSHIELD_CONFIG_FILE")

initialConfig := `
filters:
input:
- name: existing_rule
type: pii_filter
enabled: true
action:
type: redact
config:
plugin_name: pii_plugin
threshold: 80
`
if _, err := tmpfile.Write([]byte(initialConfig)); err != nil {
t.Fatal(err)
}
err = tmpfile.Close()
if err != nil {
t.Fatal(err)
}

viper.Reset()
viper.SetConfigFile(tmpfile.Name())
if err := viper.ReadInConfig(); err != nil {
t.Fatalf("Error reading config file: %v", err)
}

t.Run("AddRule", func(t *testing.T) {
oldStdin := os.Stdin
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdin = r
defer func() {
os.Stdin = oldStdin
r.Close()
w.Close()
}()

inputs := []string{
"input\n",
"new_rule\n",
"sentiment_filter\n",
"block\n",
"sentiment_plugin\n",
"90\n",
}

go func() {
defer w.Close()
for _, input := range inputs {
_, err := w.Write([]byte(input))
if err != nil {
t.Logf("Error writing input: %v", err)
return
}
time.Sleep(100 * time.Millisecond)
}
}()

output, err := executeCommand(rootCmd, "config", "add-rule")
if err != nil {
t.Fatalf("Error executing add-rule command: %v", err)
}

t.Logf("Add Rule Command Output:\n%s", output)

v := viper.New()
v.SetConfigFile(tmpfile.Name())
err = v.ReadInConfig()
if err != nil {
t.Fatalf("Error reading updated config: %v", err)
}

rules := v.Get("filters.input")
rulesSlice, ok := rules.([]interface{})
if !ok {
t.Fatalf("Expected rules to be a slice, got %T", rules)
}

assert.Len(t, rulesSlice, 2, "Expected 2 rules after addition")
if len(rulesSlice) > 1 {
newRule := rulesSlice[1].(map[string]interface{})
assert.Equal(t, "new_rule", newRule["name"])
assert.Equal(t, "sentiment_filter", newRule["type"])
assert.Equal(t, true, newRule["enabled"])
assert.Equal(t, "block", newRule["action"].(map[string]interface{})["type"])
assert.Equal(t, "sentiment_plugin", newRule["config"].(map[string]interface{})["plugin_name"])
assert.Equal(t, int(90), newRule["config"].(map[string]interface{})["threshold"])
}
})

t.Run("RemoveRule", func(t *testing.T) {
oldStdin := os.Stdin
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdin = r
defer func() {
os.Stdin = oldStdin
r.Close()
w.Close()
}()

input := "input\n2\n"
go func() {
defer w.Close()
_, err := w.Write([]byte(input))
if err != nil {
t.Logf("Error writing input: %v", err)
return
}
}()

output, err := executeCommand(rootCmd, "config", "remove-rule")
if err != nil {
t.Fatalf("Error executing remove-rule command: %v", err)
}

t.Logf("Remove Rule Command Output:\n%s", output)

v := viper.New()
v.SetConfigFile(tmpfile.Name())
err = v.ReadInConfig()
if err != nil {
t.Fatalf("Error reading updated config: %v", err)
}

rules := v.Get("filters.input")
rulesSlice, ok := rules.([]interface{})
if !ok {
t.Fatalf("Expected rules to be a slice, got %T", rules)
}

assert.Len(t, rulesSlice, 1, "Expected 1 rule after removal")
if len(rulesSlice) > 0 {
remainingRule := rulesSlice[0].(map[string]interface{})
assert.Equal(t, "existing_rule", remainingRule["name"], "Expected 'existing_rule' to remain")
}
})
}

func executeCommand(root *cobra.Command, args ...string) (string, error) {
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
root.SetOut(stdout)
root.SetErr(stderr)
root.SetArgs(args)

err := root.Execute()

output := stdout.String() + stderr.String()
return output, err
}
27 changes: 22 additions & 5 deletions cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ func editConfig() {

fmt.Println("\nEnter the number of the setting you want to change, or 'q' to quit:")
var input string
fmt.Scanln(&input)
_, err := fmt.Scanln(&input)
if err != nil {
return
}

if input == "q" {
break
Expand Down Expand Up @@ -196,7 +199,11 @@ func updateConfig(v *viper.Viper, path string, value string) error {
}
func addRule() {
v := viper.New()
v.SetConfigFile("config.yaml")
configFile := os.Getenv("OPENSHIELD_CONFIG_FILE")
if configFile == "" {
configFile = "config.yaml" // Default to config.yaml if env var is not set
}
v.SetConfigFile(configFile)
err := v.ReadInConfig()
if err != nil {
fmt.Printf("Error reading config file: %v\n", err)
Expand All @@ -205,7 +212,10 @@ func addRule() {

var ruleType string
fmt.Print("Enter rule type (input/output): ")
fmt.Scanln(&ruleType)
_, err = fmt.Scanln(&ruleType)
if err != nil {
return
}

if ruleType != "input" && ruleType != "output" {
fmt.Println("Invalid rule type. Please enter 'input' or 'output'.")
Expand Down Expand Up @@ -270,7 +280,11 @@ func createRuleWizard() map[string]interface{} {

func removeRule() {
v := viper.New()
v.SetConfigFile("config.yaml")
configFile := os.Getenv("OPENSHIELD_CONFIG_FILE")
if configFile == "" {
configFile = "config.yaml" // Default to config.yaml if env var is not set
}
v.SetConfigFile(configFile)
err := v.ReadInConfig()
if err != nil {
fmt.Printf("Error reading config file: %v\n", err)
Expand All @@ -279,7 +293,10 @@ func removeRule() {

var ruleType string
fmt.Print("Enter rule type (input/output): ")
fmt.Scanln(&ruleType)
_, err = fmt.Scanln(&ruleType)
if err != nil {
return
}

if ruleType != "input" && ruleType != "output" {
fmt.Println("Invalid rule type. Please enter 'input' or 'output'.")
Expand Down
2 changes: 1 addition & 1 deletion cmd/mock_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func init() {
func createMockData() {
db := lib.DB()
createMockTags(db, 10)
createMockRecords(db, &models.AiModels{}, 0)
createMockRecords(db, &models.AiModels{}, 1)
createMockRecords(db, &models.ApiKeys{}, 1)
createMockRecords(db, &models.AuditLogs{}, 1)
createMockRecords(db, &models.Products{}, 1)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/openshieldai/openshield
go 1.22

require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/fsnotify/fsnotify v1.7.0
github.com/go-faker/faker/v4 v4.4.2
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
Expand Down Expand Up @@ -45,6 +47,7 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
Expand Down
Loading

0 comments on commit 3902ab3

Please sign in to comment.