Skip to content

Commit

Permalink
Merge branch 'main' into llamaguard
Browse files Browse the repository at this point in the history
Signed-off-by: David Papp <[email protected]>
  • Loading branch information
pigri authored Nov 13, 2024
2 parents 35a95e4 + 682ac73 commit 0d74b17
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 31 deletions.
3 changes: 3 additions & 0 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Services struct {
LlamaGuard *ServiceLlamaGuard `mapstructure:"llamaguard"`
}
type ServiceLlamaGuard struct {
PromptGuard *ServicePromptGuard `mapstructure:"promptguard"`
}
type ServicePromptGuard struct {
Enabled bool `mapstructure:"enabled"`
BaseUrl string `mapstructure:"url"`
}
Expand Down
123 changes: 123 additions & 0 deletions lib/promptguard/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package promptguard

import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"log"
"net/http"

"github.com/go-chi/chi/v5"
"github.com/openshieldai/openshield/lib"
)

type AnalyzeRequest struct {
Text string `json:"text"`
Threshold float64 `json:"threshold"`
}

type AnalyzeResponse struct {
Score float64 `json:"score"`
Details struct {
BenignProbability float64 `json:"benign_probability"`
InjectionProbability float64 `json:"injection_probability"`
JailbreakProbability float64 `json:"jailbreak_probability"`
} `json:"details"`
}

func SetupRoutes(r chi.Router) {
r.Post("/promptguard/analyze", lib.AuthOpenShieldMiddleware(AnalyzeHandler))
}

func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
var req AnalyzeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
lib.ErrorResponse(w, fmt.Errorf("error reading request body: %v", err))
return
}

if req.Text == "" {
lib.ErrorResponse(w, fmt.Errorf("text field is required"))
return
}

performAuditLogging(r, "promptguard_analyze", "input", []byte(req.Text))

resp, err := callPromptGuardService(r.Context(), req)
if err != nil {
lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err))
return
}

respBytes, _ := json.Marshal(resp)
performAuditLogging(r, "promptguard_analyze", "output", respBytes)

json.NewEncoder(w).Encode(resp)
}

func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeResponse, error) {
config := lib.GetConfig()
promptGuardURL := config.Services.PromptGuard.BaseUrl + "/analyze"

reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %v", err)
}

httpReq, err := http.NewRequestWithContext(ctx, "POST", promptGuardURL, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}

httpReq.Header.Set("Content-Type", "application/json")

client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("service returned status %d", resp.StatusCode)
}

var result AnalyzeResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %v", err)
}

return &result, nil
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
}

func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
var productIDStr string
err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error
if err != nil {
return uuid.Nil, err
}

productID, err := uuid.Parse(productIDStr)
if err != nil {
return uuid.Nil, errors.New("failed to parse product_id as UUID")
}

return productID, nil
}
80 changes: 49 additions & 31 deletions lib/provider/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,6 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa

log.Printf("Received request: %+v", req)

PerformAuditLogging(r, providerName+"_create_message", "input", body)

filtered, err := ProcessInput(w, r, req)
if err != nil || filtered {
log.Printf("Request filtered or error occurred: %v", err)
return ChatCompletionRequest{}, nil, uuid.Nil, false
}

apiKeyID, ok := r.Context().Value("apiKeyId").(uuid.UUID)
if !ok {
HandleError(w, fmt.Errorf("apiKeyId not found in context"), http.StatusInternalServerError)
Expand All @@ -367,6 +359,37 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa

ctx := context.WithValue(r.Context(), "productID", productID)

// Check cache before running rules
if !req.Stream {
cachedResponse, cacheHit, err := HandleContextCache(ctx, req, productID)
if err != nil {
log.Printf("Error handling context cache: %v", err)
}

if cacheHit {
log.Println("Cache hit, using cached response")
resp, err := CreateChatCompletionResponseFromCache(cachedResponse, req.Model)
if err != nil {
log.Printf("Error creating response from cache: %v", err)
} else {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding cached response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
}
return ChatCompletionRequest{}, nil, uuid.Nil, true
}
}
}

PerformAuditLogging(r, providerName+"_create_message", "input", body)

filtered, err := ProcessInput(w, r, req)
if err != nil || filtered {
log.Printf("Request filtered or error occurred: %v", err)
return ChatCompletionRequest{}, nil, uuid.Nil, false
}

return req, ctx, productID, true
}

Expand All @@ -387,33 +410,28 @@ func HandleCacheLogic(ctx context.Context, req ChatCompletionRequest, productID

func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, req ChatCompletionRequest, productID uuid.UUID, provider Provider) {
if req.Stream {
handleStreamingRequest(w, r, provider, req)
} else {
resp, cacheHit, err := HandleCacheLogic(ctx, req, productID)
if err != nil {
log.Printf("Error handling cache logic: %v", err)
if err := handleStreamingRequest(w, r, provider, req); err != nil {
HandleError(w, err, http.StatusInternalServerError)
}
return
}

if !cacheHit {
log.Printf("Cache miss, making API call to provider")
resp, err = provider.CreateChatCompletion(ctx, req)
if err != nil {
HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError)
return
}
resp, err := provider.CreateChatCompletion(ctx, req)
if err != nil {
HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError)
return
}

if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil {
log.Printf("Error setting context cache: %v", err)
}
if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil {
log.Printf("Error setting context cache: %v", err)
}

PerformResponseAuditLogging(r, resp)
}
PerformResponseAuditLogging(r, resp)

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
return
}
}
2 changes: 2 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/openshieldai/openshield/lib/llamaguard"
"github.com/openshieldai/openshield/lib/promptguard"
"net/http"
"time"

Expand Down Expand Up @@ -145,6 +146,7 @@ func setupOpenAIRoutes(r chi.Router) {
r.Post("/chat/completions", lib.AuthOpenShieldMiddleware(huggingface.ChatCompletionHandler))
})
r.Post("/v1/llamaguard/analyze", lib.AuthOpenShieldMiddleware(llamaguard.AnalyzeHandler))
r.Post("/v1/promptguard/analyze", lib.AuthOpenShieldMiddleware(promptguard.AnalyzeHandler))
}

var redisClient *redis.Client
Expand Down
Loading

0 comments on commit 0d74b17

Please sign in to comment.