diff --git a/lib/config.go b/lib/config.go index 677df40..fa628ed 100644 --- a/lib/config.go +++ b/lib/config.go @@ -36,6 +36,9 @@ type Services struct { LlamaGuard *ServiceLlamaGuard `mapstructure:"llamaguard"` } type ServiceLlamaGuard struct { + PromptGuard *ServicePromptGuard `mapstructure:"promptguard"` +} +type ServicePromptGuard struct { Enabled bool `mapstructure:"enabled"` BaseUrl string `mapstructure:"url"` } diff --git a/lib/promptguard/handlers.go b/lib/promptguard/handlers.go new file mode 100644 index 0000000..900f04f --- /dev/null +++ b/lib/promptguard/handlers.go @@ -0,0 +1,123 @@ +package promptguard + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "github.com/google/uuid" + "log" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/openshieldai/openshield/lib" +) + +type AnalyzeRequest struct { + Text string `json:"text"` + Threshold float64 `json:"threshold"` +} + +type AnalyzeResponse struct { + Score float64 `json:"score"` + Details struct { + BenignProbability float64 `json:"benign_probability"` + InjectionProbability float64 `json:"injection_probability"` + JailbreakProbability float64 `json:"jailbreak_probability"` + } `json:"details"` +} + +func SetupRoutes(r chi.Router) { + r.Post("/promptguard/analyze", lib.AuthOpenShieldMiddleware(AnalyzeHandler)) +} + +func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { + var req AnalyzeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + lib.ErrorResponse(w, fmt.Errorf("error reading request body: %v", err)) + return + } + + if req.Text == "" { + lib.ErrorResponse(w, fmt.Errorf("text field is required")) + return + } + + performAuditLogging(r, "promptguard_analyze", "input", []byte(req.Text)) + + resp, err := callPromptGuardService(r.Context(), req) + if err != nil { + lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err)) + return + } + + respBytes, _ := json.Marshal(resp) + performAuditLogging(r, "promptguard_analyze", "output", respBytes) + + json.NewEncoder(w).Encode(resp) +} + +func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeResponse, error) { + config := lib.GetConfig() + promptGuardURL := config.Services.PromptGuard.BaseUrl + "/analyze" + + reqBody, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", promptGuardURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("service returned status %d", resp.StatusCode) + } + + var result AnalyzeResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %v", err) + } + + return &result, nil +} + +func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) { + apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) + + productID, err := getProductIDFromAPIKey(apiKeyId) + if err != nil { + hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String())) + log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err) + return + } + + lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r) +} + +func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { + var productIDStr string + err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error + if err != nil { + return uuid.Nil, err + } + + productID, err := uuid.Parse(productIDStr) + if err != nil { + return uuid.Nil, errors.New("failed to parse product_id as UUID") + } + + return productID, nil +} diff --git a/lib/provider/common.go b/lib/provider/common.go index b3e3e60..db90134 100644 --- a/lib/provider/common.go +++ b/lib/provider/common.go @@ -345,14 +345,6 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa log.Printf("Received request: %+v", req) - PerformAuditLogging(r, providerName+"_create_message", "input", body) - - filtered, err := ProcessInput(w, r, req) - if err != nil || filtered { - log.Printf("Request filtered or error occurred: %v", err) - return ChatCompletionRequest{}, nil, uuid.Nil, false - } - apiKeyID, ok := r.Context().Value("apiKeyId").(uuid.UUID) if !ok { HandleError(w, fmt.Errorf("apiKeyId not found in context"), http.StatusInternalServerError) @@ -367,6 +359,37 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa ctx := context.WithValue(r.Context(), "productID", productID) + // Check cache before running rules + if !req.Stream { + cachedResponse, cacheHit, err := HandleContextCache(ctx, req, productID) + if err != nil { + log.Printf("Error handling context cache: %v", err) + } + + if cacheHit { + log.Println("Cache hit, using cached response") + resp, err := CreateChatCompletionResponseFromCache(cachedResponse, req.Model) + if err != nil { + log.Printf("Error creating response from cache: %v", err) + } else { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Error encoding cached response: %v", err) + http.Error(w, "Error encoding response", http.StatusInternalServerError) + } + return ChatCompletionRequest{}, nil, uuid.Nil, true + } + } + } + + PerformAuditLogging(r, providerName+"_create_message", "input", body) + + filtered, err := ProcessInput(w, r, req) + if err != nil || filtered { + log.Printf("Request filtered or error occurred: %v", err) + return ChatCompletionRequest{}, nil, uuid.Nil, false + } + return req, ctx, productID, true } @@ -387,33 +410,28 @@ func HandleCacheLogic(ctx context.Context, req ChatCompletionRequest, productID func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, req ChatCompletionRequest, productID uuid.UUID, provider Provider) { if req.Stream { - handleStreamingRequest(w, r, provider, req) - } else { - resp, cacheHit, err := HandleCacheLogic(ctx, req, productID) - if err != nil { - log.Printf("Error handling cache logic: %v", err) + if err := handleStreamingRequest(w, r, provider, req); err != nil { + HandleError(w, err, http.StatusInternalServerError) } + return + } - if !cacheHit { - log.Printf("Cache miss, making API call to provider") - resp, err = provider.CreateChatCompletion(ctx, req) - if err != nil { - HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError) - return - } + resp, err := provider.CreateChatCompletion(ctx, req) + if err != nil { + HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError) + return + } - if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil { - log.Printf("Error setting context cache: %v", err) - } + if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil { + log.Printf("Error setting context cache: %v", err) + } - PerformResponseAuditLogging(r, resp) - } + PerformResponseAuditLogging(r, resp) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("Error encoding response: %v", err) - http.Error(w, "Error encoding response", http.StatusInternalServerError) - return - } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Error encoding response: %v", err) + http.Error(w, "Error encoding response", http.StatusInternalServerError) + return } } diff --git a/server/server.go b/server/server.go index 43e59eb..9a1cf23 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/openshieldai/openshield/lib/llamaguard" + "github.com/openshieldai/openshield/lib/promptguard" "net/http" "time" @@ -145,6 +146,7 @@ func setupOpenAIRoutes(r chi.Router) { r.Post("/chat/completions", lib.AuthOpenShieldMiddleware(huggingface.ChatCompletionHandler)) }) r.Post("/v1/llamaguard/analyze", lib.AuthOpenShieldMiddleware(llamaguard.AnalyzeHandler)) + r.Post("/v1/promptguard/analyze", lib.AuthOpenShieldMiddleware(promptguard.AnalyzeHandler)) } var redisClient *redis.Client diff --git a/services/promptguard/main.py b/services/promptguard/main.py new file mode 100644 index 0000000..346236e --- /dev/null +++ b/services/promptguard/main.py @@ -0,0 +1,170 @@ +import os + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import torch +from transformers import AutoTokenizer, AutoModelForSequenceClassification +import logging +from typing import Dict, Optional +from huggingface_hub import login, HfApi + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +app = FastAPI(title="PromptGuard Service") + + +class AnalyzeRequest(BaseModel): + text: str + threshold: float = 0.5 + temperature: float = 3.0 + + +class AnalyzeResponse(BaseModel): + score: float + details: Dict[str, float] + classification: str + + +class PromptGuard: + def __init__(self): + self.token = os.getenv("HUGGINGFACE_API_KEY") #NEED TO REQUEST ACCESS FOR THE MODEL! + if not self.token: + raise ValueError("Token not set") + + try: + login(token=self.token, write_permission=True) + api = HfApi() + except Exception as e: + logger.error(f"Authentication error: {str(e)}") + raise + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {self.device}") + + try: + logger.info("Loading model") + self.model = AutoModelForSequenceClassification.from_pretrained( + "meta-llama/Prompt-Guard-86M", + use_auth_token=self.token, + trust_remote_code=True + ) + self.tokenizer = AutoTokenizer.from_pretrained( + "meta-llama/Prompt-Guard-86M", + use_auth_token=self.token, + trust_remote_code=True + ) + self.model.to(self.device) + self.model.eval() + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Error loading model: {e}") + raise + + def get_class_probabilities(self, text: str, temperature: float = 3.0) -> torch.Tensor: + inputs = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ).to(self.device) + + with torch.no_grad(): + logits = self.model(**inputs).logits + + scaled_logits = logits / temperature + + probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1) + + return probabilities[0] + + def get_indirect_injection_score(self, text: str, temperature: float = 3.0) -> float: + + probabilities = self.get_class_probabilities(text, temperature) + return (probabilities[1] + probabilities[2]).item() + + def analyze_text(self, text: str, temperature: float = 3.0) -> Dict[str, any]: + try: + probabilities = self.get_class_probabilities(text, temperature) + + + scores = { + "benign_probability": probabilities[0].item(), + "injection_probability": probabilities[1].item(), + "jailbreak_probability": probabilities[2].item() + } + + + if scores["jailbreak_probability"] > scores["injection_probability"]: + risk_score = scores["jailbreak_probability"] + classification = "jailbreak" + else: + risk_score = scores["injection_probability"] + classification = "injection" + + + logger.info(f"\nAnalyzing text: {text[:100]}...") + logger.info(f"Probabilities: {scores}") + logger.info(f"Classification: {classification}") + + return { + "score": risk_score, + "details": scores, + "classification": classification + } + + except Exception as e: + logger.error(f"Error during analysis: {e}") + raise + + +prompt_guard: Optional[PromptGuard] = None + + +@app.on_event("startup") +async def startup_event(): + global prompt_guard + try: + prompt_guard = PromptGuard() + except Exception as e: + logger.error(f"Failed to initialize PromptGuard: {e}") + raise + + +@app.post("/analyze", response_model=AnalyzeResponse) +async def analyze_prompt(request: AnalyzeRequest): + try: + if not prompt_guard: + raise HTTPException(status_code=500, detail="PromptGuard not initialized") + + results = prompt_guard.analyze_text(request.text, request.temperature) + + return AnalyzeResponse( + score=results["score"], + details=results["details"], + classification=results["classification"] + ) + + except Exception as e: + logger.error(f"Error processing request: {e}") + raise HTTPException( + status_code=500, + detail=f"Error analyzing prompt: {str(e)}" + ) + + +@app.get("/health") +async def health_check(): + if not prompt_guard: + raise HTTPException(status_code=503, detail="Service not ready") + return {"status": "healthy"} + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/services/promptguard/pyproject.toml b/services/promptguard/pyproject.toml new file mode 100644 index 0000000..dafc1bd --- /dev/null +++ b/services/promptguard/pyproject.toml @@ -0,0 +1,31 @@ +[tool.poetry] +name = "promptguard-service" +version = "0.1.0" +description = "PromptGuard analysis service for OpenShield" +authors = ["Your Name "] + +[tool.poetry.dependencies] +python = "^3.9" +fastapi = "^0.104.1" +uvicorn = "^0.24.0" +transformers = "^4.35.2" +torch = "^2.1.1" +python-multipart = "^0.0.6" +pydantic = "^2.5.1" + +[tool.poetry.dev-dependencies] +pytest = "^7.4.3" +black = "^24.0.0" +isort = "^5.12.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 88 +target-version = ['py39'] + +[tool.isort] +profile = "black" +multi_line_output = 3 \ No newline at end of file