Skip to content

Commit

Permalink
move prompt and llamaguard to ruleservice
Browse files Browse the repository at this point in the history
  • Loading branch information
krichard1212 committed Nov 27, 2024
1 parent cfa193d commit 361625d
Show file tree
Hide file tree
Showing 5 changed files with 466 additions and 48 deletions.
101 changes: 54 additions & 47 deletions lib/rules/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/openshieldai/openshield/lib/types"
"io"
"log"
"net/http"
"sort"
"strings"
"sync"

"github.com/openshieldai/openshield/lib/types"

"github.com/openshieldai/go-openai"
"github.com/openshieldai/openshield/lib"
Expand All @@ -23,6 +21,18 @@ type InputTypes struct {
PIIFilter string
InvisibleChars string
Moderation string
LlamaGuard string
PromptGuard string
}

// inputTypes holds the canonical rule-type identifier strings that are
// compared against inputConfig.Type when dispatching input rules
// (see handleRuleAction). Keep these values in sync with the rule
// service configuration.
var inputTypes = InputTypes{
	LanguageDetection: "language_detection",
	PromptInjection:   "prompt_injection",
	PIIFilter:         "pii_filter",
	InvisibleChars:    "invisible_chars",
	Moderation:        "moderation",
	LlamaGuard:        "llama_guard",
	PromptGuard:       "prompt_guard",
}

type Rule struct {
Expand All @@ -31,9 +41,10 @@ type Rule struct {
}

type RuleInspection struct {
CheckResult bool `json:"check_result"`
Score float64 `json:"score"`
AnonymizedContent string `json:"anonymized_content"`
CheckResult bool `json:"check_result"`
Score float64 `json:"score"`
AnonymizedContent string `json:"anonymized_content"`
Details map[string]interface{} `json:"details"`
}

type RuleResult struct {
Expand All @@ -46,14 +57,6 @@ type LanguageScore struct {
Score float64 `json:"score"`
}

var inputTypes = InputTypes{
LanguageDetection: "language_detection",
PromptInjection: "prompt_injection",
PIIFilter: "pii_filter",
InvisibleChars: "invisible_chars",
Moderation: "moderation",
}

func sendRequest(data Rule) (RuleResult, error) {
jsonify, err := json.Marshal(data)
if err != nil {
Expand Down Expand Up @@ -105,7 +108,33 @@ func genericHandler(inputConfig lib.Rule, rule RuleResult) (bool, string, error)
log.Println("Invalid Characters Rule Not Matched")
return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}
// handleLlamaGuardAction decides what to do with a LlamaGuard rule result.
// It returns (blocked, jsonStatus, error): the request is blocked only when
// the rule matched AND the configured action type is "block"; otherwise the
// match is logged for monitoring and the request passes through.
func handleLlamaGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) {
	log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score)

	// Pre-build the pass-through payload; two of the three outcomes use it.
	nonBlocked := fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type)

	if !rule.Match {
		log.Println("LlamaGuard Rule Not Matched")
		return false, nonBlocked, nil
	}
	if inputConfig.Action.Type != "block" {
		log.Println("Monitoring request due to LlamaGuard detection.")
		return false, nonBlocked, nil
	}
	log.Println("Blocking request due to LlamaGuard detection.")
	return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}

// handlePromptGuardAction decides what to do with a PromptGuard rule result.
// It returns (blocked, jsonStatus, error): the request is blocked only when
// the rule matched AND the configured action type is "block"; otherwise the
// match is logged for monitoring and the request passes through.
func handlePromptGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) {
	log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score)

	// Pre-build the pass-through payload; two of the three outcomes use it.
	nonBlocked := fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type)

	if !rule.Match {
		log.Println("PromptGuard Rule Not Matched")
		return false, nonBlocked, nil
	}
	if inputConfig.Action.Type != "block" {
		log.Println("Monitoring request due to PromptGuard detection.")
		return false, nonBlocked, nil
	}
	log.Println("Blocking request due to PromptGuard detection.")
	return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil
}
func handlePIIFilterAction(inputConfig lib.Rule, rule RuleResult, messages interface{}, userMessageIndex int) (bool, string, error) {
if rule.Inspection.CheckResult {
log.Println("PII detected, anonymizing content")
Expand Down Expand Up @@ -133,9 +162,9 @@ func handlePIIFilterAction(inputConfig lib.Rule, rule RuleResult, messages inter

func Input(_ *http.Request, request interface{}) (bool, string, error) {
config := lib.GetConfig()

log.Println("Starting Input function")

// Sort rules by order number
sort.Slice(config.Rules.Input, func(i, j int) bool {
return config.Rules.Input[i].OrderNumber < config.Rules.Input[j].OrderNumber
})
Expand All @@ -158,47 +187,21 @@ func Input(_ *http.Request, request interface{}) (bool, string, error) {
return true, "Invalid request type", fmt.Errorf("unsupported request type")
}

var (
wg sync.WaitGroup
mu sync.Mutex
blocked bool
message string
firstErr error
)

// Process rules sequentially instead of in parallel
for _, inputConfig := range config.Rules.Input {
log.Printf("Processing input rule: %s (Order: %d)", inputConfig.Type, inputConfig.OrderNumber)

if !inputConfig.Enabled {
log.Printf("Rule %s is disabled, skipping", inputConfig.Type)
continue
}

log.Printf("Processing input rule: %s (Order: %d)", inputConfig.Type, inputConfig.OrderNumber)
blocked, message, err := handleRule(inputConfig, messages, model, maxTokens, inputConfig.Type)

if err != nil {
return false, "", err
}
if blocked {
return blocked, message, err
return true, message, nil
}
wg.Add(1)
go func(ic lib.Rule) {
defer wg.Done()
blk, msg, err := handleRule(ic, messages, model, maxTokens, ic.Type)
if blk {
mu.Lock()
if !blocked { // Capture the first block
blocked = true
message = msg
firstErr = err
}
mu.Unlock()
}
}(inputConfig)
}

wg.Wait()

if blocked {
return blocked, message, firstErr
}

log.Println("Final result: No rules matched, request is not blocked")
Expand Down Expand Up @@ -274,6 +277,10 @@ func handleRuleAction(inputConfig lib.Rule, rule RuleResult, ruleType string, me
return genericHandler(inputConfig, rule)
case inputTypes.Moderation:
return genericHandler(inputConfig, rule)
case inputTypes.LlamaGuard:
return handleLlamaGuardAction(inputConfig, rule)
case inputTypes.PromptGuard:
return handlePromptGuardAction(inputConfig, rule)
default:
log.Printf("%s Rule Not Matched", ruleType)
return false, "", nil
Expand Down
63 changes: 62 additions & 1 deletion services/rule/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions services/rule/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ numpy = "^1.21.0"
spacy = "^3.7.5"
thinc = "^8.0.10"
openai = "^1.51.2"
accelerate = ">=0.26.0"


[build-system]
requires = ["poetry-core"]
Expand Down
Loading

0 comments on commit 361625d

Please sign in to comment.