Skip to content

Commit

Permalink
162 use rules order from config (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
pigri authored Sep 12, 2024
2 parents 92dda27 + c66214d commit 7fdd981
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
3 changes: 3 additions & 0 deletions config_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ rules:
- name: "language_detection"
type: "language_detection"
enabled: true
order_number: 1
config:
url: "https://api-inference.huggingface.co/models/papluca/xlm-roberta-base-language-detection"
apikey: ""
Expand All @@ -11,6 +12,7 @@ rules:
- name: "pii_example"
type: "pii_filter"
enabled: true
order_number: 2
config:
model_name: "value1"
model_url: "value2"
Expand Down Expand Up @@ -53,6 +55,7 @@ rules:
plugin_name: "prompt_injection_llm"
threshold: 0.85
enabled: true
order_number: 3
config:
plugin_name: "prompt_injection_llm"
threshold: 0.85
Expand Down
11 changes: 6 additions & 5 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ type Rules struct {

// Rule defines a rule configuration
type Rule struct {
Enabled bool `mapstructure:"enabled,default=false"`
Name string `mapstructure:"name"`
Type string `mapstructure:"type"`
Config Config `mapstructure:"config"`
Action Action `mapstructure:"action"`
Enabled bool `mapstructure:"enabled,default=false"`
Name string `mapstructure:"name"`
Type string `mapstructure:"type"`
Config Config `mapstructure:"config"`
Action Action `mapstructure:"action"`
OrderNumber int `mapstructure:"order_number"`
}

// Config holds the configuration specifics of a filter
Expand Down
21 changes: 13 additions & 8 deletions rules/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"sort"
"strings"

"github.com/openshieldai/openshield/lib"
Expand Down Expand Up @@ -148,9 +149,17 @@ func Input(_ *http.Request, request interface{}) (bool, string, error) {

log.Println("Starting Input function")

for input := range config.Rules.Input {
inputConfig := config.Rules.Input[input]
log.Printf("Processing input rule: %s", inputConfig.Type)
sort.Slice(config.Rules.Input, func(i, j int) bool {
return config.Rules.Input[i].OrderNumber < config.Rules.Input[j].OrderNumber
})

for _, inputConfig := range config.Rules.Input {
log.Printf("Processing input rule: %s (Order: %d)", inputConfig.Type, inputConfig.OrderNumber)

if !inputConfig.Enabled {
log.Printf("Rule %s is disabled, skipping", inputConfig.Type)
continue
}

blocked, message, err := handleRule(inputConfig, request, inputConfig.Type)
if blocked {
Expand All @@ -162,11 +171,7 @@ func Input(_ *http.Request, request interface{}) (bool, string, error) {
return false, "request is not blocked", nil
}
func handleRule(inputConfig lib.Rule, request interface{}, ruleType string) (bool, string, error) {
if !inputConfig.Enabled {
return false, "", nil
}

log.Printf("%s check enabled", ruleType)
log.Printf("%s check enabled (Order: %d)", ruleType, inputConfig.OrderNumber)

var extractedPrompt string
var userMessageIndex int
Expand Down

0 comments on commit 7fdd981

Please sign in to comment.