diff --git a/cmd/cli.go b/cmd/cli.go new file mode 100644 index 0000000..2b7b39f --- /dev/null +++ b/cmd/cli.go @@ -0,0 +1,146 @@ +package cmd + +import ( + "fmt" + "github.com/openshieldai/openshield/lib" + "github.com/openshieldai/openshield/server" + "github.com/spf13/cobra" + "os" + "os/signal" + "syscall" + "time" +) + +var rootCmd = &cobra.Command{ + Use: "openshield", + Short: "OpenShield CLI", +} + +func Execute() error { + return rootCmd.Execute() +} + +func init() { + rootCmd.AddCommand(dbCmd) + rootCmd.AddCommand(configCmd) + rootCmd.AddCommand(startServerCmd) + rootCmd.AddCommand(stopServerCmd) + dbCmd.AddCommand(createTablesCmd) + dbCmd.AddCommand(createMockDataCmd) + configCmd.AddCommand(editConfigCmd) + configCmd.AddCommand(addRuleCmd) + configCmd.AddCommand(removeRuleCmd) + configCmd.AddCommand(configWizardCmd) +} + +var dbCmd = &cobra.Command{ + Use: "db", + Short: "Database related commands", +} + +var createTablesCmd = &cobra.Command{ + Use: "create-tables", + Short: "Create database tables from models", + Run: func(cmd *cobra.Command, args []string) { + lib.DB() // Call the function from lib package + }, +} + +var createMockDataCmd = &cobra.Command{ + Use: "create-mock-data", + Short: "Create mock data in the database", + Run: func(cmd *cobra.Command, args []string) { + createMockData() + }, +} + +var configCmd = &cobra.Command{ + Use: "config", + Short: "Configuration related commands", +} + +var editConfigCmd = &cobra.Command{ + Use: "edit", + Short: "Edit the config.yaml file", + Run: func(cmd *cobra.Command, args []string) { + editConfig() + }, +} + +var addRuleCmd = &cobra.Command{ + Use: "add-rule", + Short: "Add a new rule to the configuration", + Run: func(cmd *cobra.Command, args []string) { + addRule() + }, +} + +var removeRuleCmd = &cobra.Command{ + Use: "remove-rule", + Short: "Remove a rule from the configuration", + Run: func(cmd *cobra.Command, args []string) { + removeRule() + }, +} + +var configWizardCmd = &cobra.Command{ + Use: "wizard", + Short: "Interactive wizard to create or update config.yaml", + Run: func(cmd *cobra.Command, args []string) { + runConfigWizard() + }, +} + +var startServerCmd = &cobra.Command{ + Use: "start", + Short: "Start the server", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Starting the server...") + if err := server.StartServer(); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + }, +} + +var stopServerCmd = &cobra.Command{ + Use: "stop", + Short: "Stop the server", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Sending stop signal to the server...") + if err := stopServer(); err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + fmt.Println("Stop signal sent to the server") + }, +} + +func stopServer() error { + // Create a channel to receive OS signals + sigs := make(chan os.Signal, 1) + + // Register for SIGINT + signal.Notify(sigs, syscall.SIGINT) + + // Send SIGINT to the current process + p, err := os.FindProcess(os.Getpid()) + if err != nil { + return fmt.Errorf("failed to find current process: %v", err) + } + + err = p.Signal(syscall.SIGINT) + if err != nil { + return fmt.Errorf("failed to send interrupt signal: %v", err) + } + + // Wait for a short time to allow the signal to be processed + select { + case <-sigs: + fmt.Println("Interrupt signal received") + case <-time.After(2 * time.Second): + fmt.Println("No confirmation received, but signal was sent") + } + + return nil +} diff --git a/cmd/config.go b/cmd/config.go new file mode 100644 index 0000000..229ded5 --- /dev/null +++ b/cmd/config.go @@ -0,0 +1,489 @@ +package cmd + +import ( + "bufio" + "fmt" + "os" + "reflect" + "strconv" + "strings" + + "github.com/openshieldai/openshield/lib" + "github.com/spf13/viper" + "gopkg.in/yaml.v3" +) + +var ( + configOptions []configOption + optionCounter int +) + +type configOption struct { + number int + path string + value interface{} +} + +func editConfig() { + v := viper.New() + v.SetConfigFile("config.yaml") + err := v.ReadInConfig() + if err != nil { + fmt.Printf("Error reading config file: %v\n", err) + return + } + + for { + configOptions = []configOption{} // Clear the configOptions slice + optionCounter = 1 + generateConfigOptions(v.AllSettings(), "") + + fmt.Println("\nCurrent configuration:") + for _, option := range configOptions { + fmt.Printf("%d. %s: %v\n", option.number, option.path, option.value) + } + + fmt.Println("\nEnter the number of the setting you want to change, or 'q' to quit:") + var input string + fmt.Scanln(&input) + + if input == "q" { + break + } + + number, err := strconv.Atoi(input) + if err != nil || number < 1 || number > len(configOptions) { + fmt.Println("Invalid input. Please enter a valid number.") + continue + } + + option := configOptions[number-1] + fmt.Printf("Enter new value for %s: ", option.path) + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + newValue := scanner.Text() + if err := updateConfig(v, option.path, newValue); err != nil { + fmt.Printf("Error updating config: %v\n", err) + } else { + fmt.Println("Configuration updated successfully.") + if err := v.WriteConfig(); err != nil { + fmt.Printf("Error writing config file: %v\n", err) + } else { + fmt.Println("Configuration file updated.") + } + } + } + + // Reload the configuration + err = v.ReadInConfig() + if err != nil { + fmt.Printf("Error reading updated config file: %v\n", err) + return + } + } +} + +func generateConfigOptions(value interface{}, prefix string) { + v := reflect.ValueOf(value) + + switch { + case value == nil: + configOptions = append(configOptions, configOption{ + number: optionCounter, + path: prefix, + value: nil, + }) + optionCounter++ + case v.Kind() == reflect.Map: + for _, key := range v.MapKeys() { + keyStr := key.String() + newPrefix := prefix + if newPrefix != "" { + newPrefix += "." + } + newPrefix += keyStr + generateConfigOptions(v.MapIndex(key).Interface(), newPrefix) + } + case v.Kind() == reflect.Slice: + for i := 0; i < v.Len(); i++ { + newPrefix := fmt.Sprintf("%s[%d]", prefix, i) + generateConfigOptions(v.Index(i).Interface(), newPrefix) + } + default: + configOptions = append(configOptions, configOption{ + number: optionCounter, + path: prefix, + value: v.Interface(), + }) + optionCounter++ + } +} + +func updateConfig(v *viper.Viper, path string, value string) error { + var parsedValue interface{} + err := yaml.Unmarshal([]byte(value), &parsedValue) + if err != nil { + return fmt.Errorf("invalid YAML: %v", err) + } + + parts := strings.Split(path, ".") + + // Create a function to recursively update nested maps + var updateNestedMap func(m map[string]interface{}, parts []string, value interface{}) error + updateNestedMap = func(m map[string]interface{}, parts []string, value interface{}) error { + if len(parts) == 1 { + m[parts[0]] = value + return nil + } + + key := parts[0] + if strings.HasSuffix(key, "]") { + arrayName := strings.Split(key, "[")[0] + indexStr := strings.TrimSuffix(strings.Split(key, "[")[1], "]") + index, err := strconv.Atoi(indexStr) + if err != nil { + return fmt.Errorf("invalid array index: %v", err) + } + + array, ok := m[arrayName].([]interface{}) + if !ok { + return fmt.Errorf("invalid array at %s", arrayName) + } + + if index < 0 || index >= len(array) { + return fmt.Errorf("array index out of bounds: %d", index) + } + + element, ok := array[index].(map[string]interface{}) + if !ok { + element = make(map[string]interface{}) + } + + err = updateNestedMap(element, parts[1:], value) + if err != nil { + return err + } + + array[index] = element + m[arrayName] = array + } else { + nextMap, ok := m[key].(map[string]interface{}) + if !ok { + nextMap = make(map[string]interface{}) + m[key] = nextMap + } + return updateNestedMap(nextMap, parts[1:], value) + } + + return nil + } + + // Get the current configuration + config := v.AllSettings() + + // Update the nested structure + err = updateNestedMap(config, parts, parsedValue) + if err != nil { + return err + } + + // Update the Viper instance with the modified configuration + for key, val := range config { + v.Set(key, val) + } + + return nil +} +func addRule() { + v := viper.New() + v.SetConfigFile("config.yaml") + err := v.ReadInConfig() + if err != nil { + fmt.Printf("Error reading config file: %v\n", err) + return + } + + var ruleType string + fmt.Print("Enter rule type (input/output): ") + fmt.Scanln(&ruleType) + + if ruleType != "input" && ruleType != "output" { + fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") + return + } + + newRule := createRuleWizard() + + rules := v.Get(fmt.Sprintf("filters.%s", ruleType)) + var ruleSlice []interface{} + + if rules != nil { + var ok bool + ruleSlice, ok = rules.([]interface{}) + if !ok { + fmt.Printf("Error: unexpected format for %s rules.\n", ruleType) + return + } + } + + ruleSlice = append(ruleSlice, newRule) + v.Set(fmt.Sprintf("filters.%s", ruleType), ruleSlice) + + if err := v.WriteConfig(); err != nil { + fmt.Printf("Error writing config file: %v\n", err) + return + } + + fmt.Println("Rule added successfully.") +} + +func createRuleWizard() map[string]interface{} { + rule := make(map[string]interface{}) + rule["enabled"] = true + + fmt.Print("Enter rule name: ") + rule["name"] = getInput() + + fmt.Print("Enter rule type (e.g., pii_filter): ") + rule["type"] = getInput() + + action := make(map[string]interface{}) + fmt.Print("Enter action type: ") + action["type"] = getInput() + rule["action"] = action + + config := make(map[string]interface{}) + fmt.Print("Enter plugin name: ") + config["plugin_name"] = getInput() + + fmt.Print("Enter threshold (0-100): ") + threshold, err := strconv.Atoi(getInput()) + if err != nil || threshold < 0 || threshold > 100 { + fmt.Println("Invalid threshold. Using default value of 50.") + threshold = 50 + } + config["threshold"] = threshold + rule["config"] = config + + return rule +} + +func removeRule() { + v := viper.New() + v.SetConfigFile("config.yaml") + err := v.ReadInConfig() + if err != nil { + fmt.Printf("Error reading config file: %v\n", err) + return + } + + var ruleType string + fmt.Print("Enter rule type (input/output): ") + fmt.Scanln(&ruleType) + + if ruleType != "input" && ruleType != "output" { + fmt.Println("Invalid rule type. Please enter 'input' or 'output'.") + return + } + + rules := v.Get(fmt.Sprintf("filters.%s", ruleType)) + if rules == nil { + fmt.Printf("No %s rules found.\n", ruleType) + return + } + + ruleSlice, ok := rules.([]interface{}) + if !ok { + fmt.Printf("Error: unexpected format for %s rules.\n", ruleType) + return + } + + if len(ruleSlice) == 0 { + fmt.Printf("No %s rules found.\n", ruleType) + return + } + + fmt.Printf("Current %s rules:\n", ruleType) + for i, rule := range ruleSlice { + r, ok := rule.(map[string]interface{}) + if !ok { + fmt.Printf("%d. Unknown rule format\n", i+1) + continue + } + fmt.Printf("%d. %s (%s)\n", i+1, r["name"], r["type"]) + } + + var ruleIndex int + fmt.Print("Enter the number of the rule to remove: ") + _, err = fmt.Scanf("%d", &ruleIndex) + if err != nil || ruleIndex < 1 || ruleIndex > len(ruleSlice) { + fmt.Println("Invalid rule number.") + return + } + + removedRule := ruleSlice[ruleIndex-1] + ruleSlice = append(ruleSlice[:ruleIndex-1], ruleSlice[ruleIndex:]...) + v.Set(fmt.Sprintf("filters.%s", ruleType), ruleSlice) + + if err := v.WriteConfig(); err != nil { + fmt.Printf("Error writing config file: %v\n", err) + return + } + + fmt.Printf("Rule '%v' removed successfully.\n", removedRule.(map[string]interface{})["name"]) +} +func runConfigWizard() { + config := lib.Configuration{} + v := reflect.ValueOf(&config).Elem() + + fmt.Println("Do you want to change default values? (y/n):") + changeDefaults := confirmInput() + + fmt.Println("Please provide values for the following settings:") + + fillStructure(v, "", changeDefaults) + + yamlData, err := yaml.Marshal(config) + if err != nil { + fmt.Printf("Error marshaling config to YAML: %v\n", err) + return + } + + err = os.WriteFile("config.yaml", yamlData, 0644) + if err != nil { + fmt.Printf("Error writing config file: %v\n", err) + return + } + + fmt.Println("Configuration file 'config.yaml' has been created successfully!") +} + +func fillStructure(v reflect.Value, prefix string, changeDefaults bool) { + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + + if !field.CanSet() { + continue + } + + fieldName := fieldType.Name + fullName := prefix + fieldName + + switch field.Kind() { + case reflect.Struct: + fillStructure(field, fullName+".", changeDefaults) + case reflect.Ptr: + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + fillStructure(field.Elem(), fullName+".", changeDefaults) + case reflect.Slice: + handleSlice(field, fullName, changeDefaults) + default: + handleField(field, fieldType, fullName, changeDefaults) + } + } +} + +func handleSlice(field reflect.Value, fullName string, changeDefaults bool) { + fmt.Printf("Enter the number of elements for %s: ", fullName) + countStr := getInput() + count, err := strconv.Atoi(countStr) + if err != nil || count < 0 { + fmt.Println("Invalid input. Using 0 elements.") + return + } + + sliceType := field.Type().Elem() + newSlice := reflect.MakeSlice(field.Type(), count, count) + + for i := 0; i < count; i++ { + fmt.Printf("Element %d of %s:\n", i+1, fullName) + elem := reflect.New(sliceType).Elem() + fillStructure(elem, fmt.Sprintf("%s[%d].", fullName, i), changeDefaults) + newSlice.Index(i).Set(elem) + } + + field.Set(newSlice) +} + +func handleField(field reflect.Value, fieldType reflect.StructField, fullName string, changeDefaults bool) { + tag := fieldType.Tag.Get("mapstructure") + defaultValue := getDefaultValue(tag) + + if !changeDefaults && defaultValue != "" { + setValue(field, defaultValue) + return + } + + if strings.Contains(tag, "omitempty") && !changeDefaults { + return + } + + prompt := fmt.Sprintf("Enter value for %s (%v)", fullName, fieldType.Type) + if defaultValue != "" { + prompt += fmt.Sprintf(" [default: %s]", defaultValue) + } + prompt += ": " + + var value string + for { + fmt.Print(prompt) + value = getInput() + + if value == "" && defaultValue != "" { + value = defaultValue + } + + if setValue(field, value) { + break + } + fmt.Println("Invalid input. Please try again.") + } +} + +func getDefaultValue(tag string) string { + parts := strings.Split(tag, ",") + for _, part := range parts { + if strings.HasPrefix(part, "default=") { + return strings.TrimPrefix(part, "default=") + } + } + return "" +} + +func setValue(field reflect.Value, value string) bool { + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Int: + if intValue, err := strconv.Atoi(value); err == nil { + field.SetInt(int64(intValue)) + } else { + return false + } + case reflect.Bool: + if boolValue, err := strconv.ParseBool(value); err == nil { + field.SetBool(boolValue) + } else { + return false + } + default: + return false + } + return true +} + +func getInput() string { + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + return strings.TrimSpace(input) +} + +func confirmInput() bool { + input := getInput() + return strings.ToLower(input) == "y" || strings.ToLower(input) == "yes" +} diff --git a/cmd/mock_data.go b/cmd/mock_data.go new file mode 100644 index 0000000..b843f7e --- /dev/null +++ b/cmd/mock_data.go @@ -0,0 +1,95 @@ +package cmd + +import ( + "fmt" + "github.com/bxcodec/faker/v4" + "github.com/openshieldai/openshield/lib" + "github.com/openshieldai/openshield/models" + "gorm.io/gorm" + "math/rand" + "reflect" + "strings" +) + +var generatedTags []string + +func init() { + // Faker providers setup + faker.AddProvider("aifamily", func(v reflect.Value) (interface{}, error) { + return string(models.OpenAI), nil + }) + + faker.AddProvider("status", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Active), string(models.Inactive), string(models.Archived)} + return statuses[rand.Intn(len(statuses))], nil + }) + + faker.AddProvider("finishreason", func(v reflect.Value) (interface{}, error) { + statuses := []string{string(models.Stop), string(models.Length), string(models.Null), string(models.FunctionCall), string(models.ContentFilter)} + return statuses[rand.Intn(len(statuses))], nil + }) + + faker.AddProvider("tags", func(v reflect.Value) (interface{}, error) { + return getRandomTags(), nil + }) +} + +func createMockData() { + db := lib.DB() + createMockTags(db, 10) + createMockRecords(db, &models.AiModels{}, 2) + createMockRecords(db, &models.ApiKeys{}, 2) + createMockRecords(db, &models.AuditLogs{}, 2) + createMockRecords(db, &models.Products{}, 2) + createMockRecords(db, &models.Usage{}, 2) + createMockRecords(db, &models.Workspaces{}, 2) +} + +func createMockTags(db *gorm.DB, count int) { + for i := 0; i < count; i++ { + tag := &models.Tags{} + if err := faker.FakeData(tag); err != nil { + fmt.Printf("error generating fake data for Tag: %v\n", err) + continue + } + tag.Name = fmt.Sprintf("Tag%d", i+1) // Ensure unique names + generatedTags = append(generatedTags, tag.Name) + + fmt.Printf("Generated data for Tag:\n") + fmt.Printf("%+v\n\n", tag) + + result := db.Create(tag) + if result.Error != nil { + fmt.Printf("error inserting fake data for Tag: %v\n", result.Error) + } + } +} + +func getRandomTags() string { + numTags := rand.Intn(3) + 1 + tagsCopy := make([]string, len(generatedTags)) + copy(tagsCopy, generatedTags) + + rand.Shuffle(len(tagsCopy), func(i, j int) { + tagsCopy[i], tagsCopy[j] = tagsCopy[j], tagsCopy[i] + }) + + selectedTags := tagsCopy[:numTags] + + return strings.Join(selectedTags, ",") +} + +func createMockRecords(db *gorm.DB, model interface{}, count int) { + for i := 0; i < count; i++ { + if err := faker.FakeData(model); err != nil { + fmt.Printf("error generating fake data for %T: %v\n", model, err) + continue + } + fmt.Printf("Generated data for %T:\n", model) + fmt.Printf("%+v\n\n", model) + result := db.Create(model) + if result.Error != nil { + fmt.Errorf("error inserting fake data for %T: %v", model, result.Error) + } + } +} diff --git a/go.mod b/go.mod index 6405713..1d80d6d 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,20 @@ module github.com/openshieldai/openshield go 1.22 require ( + github.com/bxcodec/faker/v4 v4.0.0-beta.3 github.com/cespare/xxhash/v2 v2.3.0 github.com/fsnotify/fsnotify v1.7.0 github.com/gofiber/fiber/v2 v2.52.5 github.com/gofiber/storage/redis/v3 v3.1.2 github.com/google/uuid v1.6.0 github.com/sashabaranov/go-openai v1.27.0 + github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 github.com/tiktoken-go/tokenizer v0.1.1 github.com/valyala/fasthttp v1.55.0 + golang.org/x/sync v0.7.0 + gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.9 gorm.io/gorm v1.25.11 ) @@ -23,6 +27,7 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect @@ -54,9 +59,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c3045a3..98c54ed 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,11 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bxcodec/faker/v4 v4.0.0-beta.3 h1:gqYNBvN72QtzKkYohNDKQlm+pg+uwBDVMN28nWHS18k= +github.com/bxcodec/faker/v4 v4.0.0-beta.3/go.mod h1:m6+Ch1Lj3fqW/unZmvkXIdxWS5+XQWPWxcbbQW2X+Ho= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -28,6 +31,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -53,8 +58,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -66,8 +69,6 @@ github.com/philhofer/fwd v1.1.3-0.20240612014219-fbbf4953d986/go.mod h1:RqIHx9QI github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.5.4 h1:vOFYDKKVgrI5u++QvnMT7DksSMYg7Aw/Np4vLJLKLwY= -github.com/redis/go-redis/v9 v9.5.4/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/redis/go-redis/v9 v9.6.0 h1:NLck+Rab3AOTHw21CGRpvQpgTrAU4sgdCswqGtlhGRA= github.com/redis/go-redis/v9 v9.6.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -75,6 +76,7 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -87,6 +89,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= @@ -118,8 +122,6 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 h1:wDLEX9a7YQoKdKNQt88rtydkqDxeGaBUTnIYc3iG/mA= -golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= diff --git a/main.go b/main.go index c6e9d02..dd16241 100644 --- a/main.go +++ b/main.go @@ -1,90 +1,15 @@ package main import ( - "strconv" - "time" + "fmt" + "os" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/limiter" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/requestid" - "github.com/openshieldai/openshield/lib" - "github.com/openshieldai/openshield/lib/openai" + "github.com/openshieldai/openshield/cmd" ) -func setupRoute(app *fiber.App, path string, routesSettings lib.RouteSettings, keyGenerator ...func(c *fiber.Ctx) string) { - config := limiter.Config{ - Max: routesSettings.RateLimit.Max, - Expiration: time.Duration(routesSettings.RateLimit.Expiration) * time.Second * time.Duration(routesSettings.RateLimit.Window), - Storage: routesSettings.Storage, - } - - if len(keyGenerator) > 0 { - config.KeyGenerator = keyGenerator[0] - } - - app.Use(path, limiter.New(config)) -} - -func setupOpenAIRoutes(app *fiber.App) { - config := lib.GetRouteSettings() - routes := map[string]lib.RouteSettings{ - "/openai/v1/models": config, - "/openai/v1/models/:model": config, - "/openai/v1/chat/completions": config, - } - - for path, routeSettings := range routes { - setupRoute(app, path, routeSettings) - } - - app.Get("/openai/v1/models", lib.AuthOpenShieldMiddleware(), openai.ListModelsHandler) - app.Get("/openai/v1/models/:model", lib.AuthOpenShieldMiddleware(), openai.GetModelHandler) - app.Post("/openai/v1/chat/completions", lib.AuthOpenShieldMiddleware(), openai.ChatCompletionHandler) -} - -//func setupOpenShieldRoutes(app *fiber.App) { -// config := lib.GetConfig() -// routes := map[string]lib.Route{ -// "/tokenizer/:model": settings.Routes.Tokenizer, -// } -// -// for path := range routes { -// setupRoute(app, path, lib.GetRouteSettings()) -// } -// -// app.Post("/tokenizer/:model", lib.AuthOpenShieldMiddleware(), lib.TokenizerHandler) -//} - func main() { - config := lib.GetConfig() - - app := fiber.New(fiber.Config{ - Prefork: false, - CaseSensitive: false, - StrictRouting: true, - StreamRequestBody: true, - ServerHeader: "openshield", - AppName: "OpenShield", - }) - app.Use(requestid.New()) - app.Use(logger.New()) - - app.Use(logger.New(logger.Config{ - Format: "${pid} ${locals:requestid} ${status} - ${method} ${path}\n", - })) - - app.Use(func(c *fiber.Ctx) error { - c.Set("Content-Type", "application/json") - c.Set("Accept", "application/json") - return c.Next() - }) - - setupOpenAIRoutes(app) - //setupOpenShieldRoutes(app) - - err := app.Listen(":" + strconv.Itoa(config.Settings.Network.Port)) - if err != nil { - panic(err.Error()) + if err := cmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) } } diff --git a/models/aimodels.go b/models/aimodels.go index bb52c58..cc5d23c 100644 --- a/models/aimodels.go +++ b/models/aimodels.go @@ -10,11 +10,11 @@ const ( type AiModels struct { Base `gorm:"embedded"` - Family AiFamily `sql:"family;<-:false;not null;type:enum('openai')"` - ModelType string `gorm:"model_type;<-:false;not null"` - Model string `gorm:"model;<-:false;not null"` - Encoding string `gorm:"encoding;<-:false;not null"` - Size string `gorm:"size;<-:false"` - Quality string `gorm:"quality;<-:false"` - Status Status `sql:"status;<-:false;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Family AiFamily `sql:"family;not null;type:enum('openai')"` + ModelType string `gorm:"model_type;not null"` + Model string `gorm:"model;not null"` + Encoding string `gorm:"encoding;not null"` + Size string `gorm:"size;"` + Quality string `gorm:"quality;"` + Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` } diff --git a/models/api_keys.go b/models/api_keys.go index bbd0441..efca50d 100644 --- a/models/api_keys.go +++ b/models/api_keys.go @@ -6,9 +6,9 @@ import ( type ApiKeys struct { Base `gorm:"embedded"` - ProductID uuid.UUID `gorm:"product_id;<-:false;not null"` - ApiKey string `gorm:"api_key;<-:false;not null;uniqueIndex;index:idx_api_keys_status,unique"` - Status Status `sql:"status;<-:false;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` + ProductID uuid.UUID `gorm:"product_id;not null"` + ApiKey string `gorm:"api_key;not null;uniqueIndex;index:idx_api_keys_status,unique"` + Status Status `sql:"status;not null;index:idx_api_keys_status,unique;type:enum('active', 'inactive', 'archived')"` Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;<-:false;not null"` + CreatedBy string `gorm:"created_by;not null"` } diff --git a/models/base.go b/models/base.go index fefc002..9b04d76 100644 --- a/models/base.go +++ b/models/base.go @@ -8,9 +8,9 @@ import ( ) type Base struct { - Id uuid.UUID `gorm:"id;<-:false;type:uuid;default:gen_random_uuid();primaryKey;not null"` - CreatedAt time.Time `gorm:"created_at;<-:false;default:now();not null"` - UpdatedAt time.Time `gorm:"updated_at;<-:false;default:now();not null"` + Id uuid.UUID `gorm:"id;type:uuid;default:gen_random_uuid();primaryKey;not null"` + CreatedAt time.Time `gorm:"created_at;default:now();not null"` + UpdatedAt time.Time `gorm:"updated_at;default:now();not null"` DeletedAt *gorm.DeletedAt `gorm:"deleted_at;index;<-:false"` } diff --git a/models/products.go b/models/products.go index 590812c..409bac0 100644 --- a/models/products.go +++ b/models/products.go @@ -6,9 +6,9 @@ import ( type Products struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;<-:false;not null;type:enum('active', 'inactive', 'archived');default:'active';default:'active'"` - Name string `gorm:"name;<-:false;not null"` - WorkspaceID uuid.UUID `gorm:"workspace_id;type:uuid;<-:false;not null"` + Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active';default:'active'"` + Name string `gorm:"name;not null"` + WorkspaceID uuid.UUID `gorm:"workspace_id;type:uuid;not null"` Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;<-:false;not null"` + CreatedBy string `gorm:"created_by;not null"` } diff --git a/models/tags.go b/models/tags.go index 3a3f1b7..09a997a 100644 --- a/models/tags.go +++ b/models/tags.go @@ -2,7 +2,7 @@ package models type Tags struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;<-:false;not null;type:enum('active', 'inactive', 'archived');default:'active'"` - Name string `gorm:"name;<-:false;not null"` - CreatedBy string `gorm:"created_by;<-:false;not null"` + Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Name string `gorm:"name;not null"` + CreatedBy string `gorm:"created_by;not null"` } diff --git a/models/workspace.go b/models/workspace.go index e8fe801..877e1ac 100644 --- a/models/workspace.go +++ b/models/workspace.go @@ -2,8 +2,8 @@ package models type Workspaces struct { Base Base `gorm:"embedded"` - Status Status `sql:"status;<-:false;not null;type:enum('active', 'inactive', 'archived');default:'active'"` - Name string `gorm:"name;<-:false;not null"` + Status Status `sql:"status;not null;type:enum('active', 'inactive', 'archived');default:'active'"` + Name string `gorm:"name;not null"` Tags string `gorm:"tags;<-:false"` - CreatedBy string `gorm:"created_by;<-:false;not null"` + CreatedBy string `gorm:"created_by;not null"` } diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..5679ca1 --- /dev/null +++ b/server/server.go @@ -0,0 +1,134 @@ +package server + +import ( + "context" + "fmt" + "github.com/gofiber/fiber/v2/middleware/limiter" + "github.com/openshieldai/openshield/lib/openai" + "golang.org/x/sync/errgroup" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/requestid" + "github.com/openshieldai/openshield/lib" +) + +var ( + app *fiber.App + config lib.Configuration +) + +func StartServer() error { + config = lib.GetConfig() + + app = fiber.New(fiber.Config{ + Prefork: false, + CaseSensitive: false, + StrictRouting: true, + StreamRequestBody: true, + ServerHeader: "openshield", + AppName: "OpenShield", + }) + app.Use(requestid.New()) + app.Use(logger.New()) + + app.Use(logger.New(logger.Config{ + Format: "${pid} ${locals:requestid} ${status} - ${method} ${path}\n", + })) + + app.Use(func(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Accept", "application/json") + return c.Next() + }) + + setupOpenAIRoutes(app) + //setupOpenShieldRoutes(app) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g, ctx := errgroup.WithContext(ctx) + + // Start the server + g.Go(func() error { + addr := fmt.Sprintf(":%d", config.Settings.Network.Port) + fmt.Printf("Server is starting on %s...\n", addr) + return app.Listen(addr) + }) + + // Handle graceful shutdown + g.Go(func() error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + + select { + case <-quit: + fmt.Println("Shutting down server...") + return app.Shutdown() + case <-ctx.Done(): + return ctx.Err() + } + }) + + if err := g.Wait(); err != nil { + fmt.Printf("Server error: %v\n", err) + return err + } + + return nil +} + +func StopServer() error { + if app != nil { + fmt.Println("Stopping the server...") + return app.Shutdown() + } + return fmt.Errorf("server is not running") +} +func setupRoute(app *fiber.App, path string, routesSettings lib.RouteSettings, keyGenerator ...func(c *fiber.Ctx) string) { + config := limiter.Config{ + Max: routesSettings.RateLimit.Max, + Expiration: time.Duration(routesSettings.RateLimit.Expiration) * time.Second * time.Duration(routesSettings.RateLimit.Window), + Storage: routesSettings.Storage, + } + + if len(keyGenerator) > 0 { + config.KeyGenerator = keyGenerator[0] + } + + app.Use(path, limiter.New(config)) +} +func setupOpenAIRoutes(app *fiber.App) { + config := lib.GetRouteSettings() + routes := map[string]lib.RouteSettings{ + "/openai/v1/models": config, + "/openai/v1/models/:model": config, + "/openai/v1/chat/completions": config, + } + + for path, routeSettings := range routes { + setupRoute(app, path, routeSettings) + } + + app.Get("/openai/v1/models", lib.AuthOpenShieldMiddleware(), openai.ListModelsHandler) + app.Get("/openai/v1/models/:model", lib.AuthOpenShieldMiddleware(), openai.GetModelHandler) + app.Post("/openai/v1/chat/completions", lib.AuthOpenShieldMiddleware(), openai.ChatCompletionHandler) +} + +//func setupOpenShieldRoutes(app *fiber.App) { +// config := lib.GetConfig() +// routes := map[string]lib.Route{ +// "/tokenizer/:model": settings.Routes.Tokenizer, +// } +// +// for path := range routes { +// setupRoute(app, path, lib.GetRouteSettings()) +// } +// +// app.Post("/tokenizer/:model", lib.AuthOpenShieldMiddleware(), lib.TokenizerHandler) +//}