diff --git a/lib/config.go b/lib/config.go
index e4edb37..fa628ed 100644
--- a/lib/config.go
+++ b/lib/config.go
@@ -33,6 +33,10 @@ type Providers struct {
 }
 
 type Services struct {
+	LlamaGuard  *ServiceLlamaGuard  `mapstructure:"llamaguard"`
 	PromptGuard *ServicePromptGuard `mapstructure:"promptguard"`
 }
+type ServiceLlamaGuard struct {
+	BaseUrl string `mapstructure:"base_url"` // assumed key; read as Services.LlamaGuard.BaseUrl in handlers.go
+}
 type ServicePromptGuard struct {
diff --git a/lib/llamaguard/handlers.go b/lib/llamaguard/handlers.go
new file mode 100644
index 0000000..37adb37
--- /dev/null
+++ b/lib/llamaguard/handlers.go
@@ -0,0 +1,161 @@
+package llamaguard
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net/http"
+	"strings"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/google/uuid"
+	"github.com/openshieldai/openshield/lib"
+)
+
+type AnalyzeRequest struct {
+	Text               string   `json:"text"`
+	Categories         []string `json:"categories,omitempty"`
+	ExcludedCategories []string `json:"excluded_categories,omitempty"`
+}
+
+type LlamaGuardResponse struct {
+	Response string `json:"response"`
+}
+
+type AnalyzeResponse struct {
+	IsSafe     bool     `json:"is_safe"`
+	Categories []string `json:"violated_categories,omitempty"`
+	Analysis   string   `json:"analysis"`
+}
+
+func SetupRoutes(r chi.Router) {
+	r.Post("/llamaguard/analyze", lib.AuthOpenShieldMiddleware(AnalyzeHandler))
+}
+
+func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
+	var req AnalyzeRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		lib.ErrorResponse(w, fmt.Errorf("error reading request body: %v", err))
+		return
+	}
+
+	if req.Text == "" {
+		lib.ErrorResponse(w, fmt.Errorf("text field is required"))
+		return
+	}
+
+	performAuditLogging(r, "llamaguard_analyze", "input", []byte(req.Text))
+
+	resp, err := callLlamaGuardService(r.Context(), req)
+	if err != nil {
+		lib.ErrorResponse(w, fmt.Errorf("error calling LlamaGuard service: %v", err))
+		return
+	}
+
+	respBytes, err := json.Marshal(resp)
+	if err != nil {
+		lib.ErrorResponse(w, fmt.Errorf("error marshaling response: %v", err))
+		return
+	}
+	performAuditLogging(r, "llamaguard_analyze", "output", respBytes)
+
+	w.Header().Set("Content-Type", "application/json")
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
+		log.Printf("error writing response: %v", err)
+	}
+}
+
+func callLlamaGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeResponse, error) {
+	config := lib.GetConfig()
+	llamaGuardURL := config.Services.LlamaGuard.BaseUrl + "/analyze"
+
+	reqBody, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("error marshaling request: %v", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, "POST", llamaGuardURL, bytes.NewBuffer(reqBody))
+	if err != nil {
+		return nil, fmt.Errorf("error creating request: %v", err)
+	}
+
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{}
+	resp, err := client.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("error making request: %v", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("service returned status %d", resp.StatusCode)
+	}
+
+	var llamaGuardResp LlamaGuardResponse
+	if err := json.NewDecoder(resp.Body).Decode(&llamaGuardResp); err != nil {
+		return nil, fmt.Errorf("error decoding response: %v", err)
+	}
+
+	return parseLlamaGuardResponse(llamaGuardResp.Response), nil
+}
+
+func parseLlamaGuardResponse(response string) *AnalyzeResponse {
+	result := &AnalyzeResponse{
+		Analysis: response,
+		IsSafe:   response == "safe",
+	}
+
+	if !result.IsSafe {
+		// The cleaned service output is comma-separated, e.g. "unsafe,S1,S10".
+		// Compare whole tokens so "S1" is not also reported for "S10"-"S13".
+		tokens := strings.Split(response, ",")
+		for _, category := range []string{"S1", "S2", "S3", "S4", "S5", "S6", "S7",
+			"S8", "S9", "S10", "S11", "S12", "S13"} {
+			for _, token := range tokens {
+				if strings.TrimSpace(token) == category {
+					result.Categories = append(result.Categories, category)
+					break
+				}
+			}
+		}
+	}
+
+	return result
+}
+
+func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
+	apiKeyId, ok := r.Context().Value("apiKeyId").(uuid.UUID)
+	if !ok {
+		log.Printf("apiKeyId missing from request context; skipping audit log")
+		return
+	}
+
+	productID, err := getProductIDFromAPIKey(apiKeyId)
+	if err != nil {
+		hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
+		log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
+		return
+	}
+
+	lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
+}
+
+func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
+	var productIDStr string
+	err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error
+	if err != nil {
+		return uuid.Nil, err
+	}
+
+	productID, err := uuid.Parse(productIDStr)
+	if err != nil {
+		return uuid.Nil, errors.New("failed to parse product_id as UUID")
+	}
+
+	return productID, nil
+}
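A quick sanity check for the parser above. This is a minimal sketch, not part of the PR: the test file and cases are hypothetical, and it assumes the comma-separated cleaned output format produced by the Python service below. It exercises the exact-token matching so that "S1" is not reported for a response that only flags "S10":

```go
package llamaguard

import (
	"reflect"
	"testing"
)

func TestParseLlamaGuardResponse(t *testing.T) {
	// A bare "safe" verdict yields IsSafe with no violated categories.
	if got := parseLlamaGuardResponse("safe"); !got.IsSafe || len(got.Categories) != 0 {
		t.Errorf("expected safe result, got %+v", got)
	}

	// Only exact category tokens are reported: S1 must not match S10.
	got := parseLlamaGuardResponse("unsafe,S10")
	if got.IsSafe || !reflect.DeepEqual(got.Categories, []string{"S10"}) {
		t.Errorf("expected unsafe with [S10], got %+v", got)
	}
}
```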
diff --git a/server/server.go b/server/server.go
index b7fc746..9a1cf23 100644
--- a/server/server.go
+++ b/server/server.go
@@ -3,6 +3,7 @@ package server
 import (
 	"context"
 	"fmt"
+	"github.com/openshieldai/openshield/lib/llamaguard"
 	"github.com/openshieldai/openshield/lib/promptguard"
 	"net/http"
 	"time"
@@ -144,6 +145,7 @@ func setupOpenAIRoutes(r chi.Router) {
 	r.Route("/huggingface/v1", func(r chi.Router) {
 		r.Post("/chat/completions", lib.AuthOpenShieldMiddleware(huggingface.ChatCompletionHandler))
 	})
+	r.Post("/v1/llamaguard/analyze", lib.AuthOpenShieldMiddleware(llamaguard.AnalyzeHandler))
 	r.Post("/v1/promptguard/analyze", lib.AuthOpenShieldMiddleware(promptguard.AnalyzeHandler))
 }
 
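With the route wired up, a request flows gateway -> Go handler -> Python service and back. A minimal client sketch follows; the host, port, and the Authorization header are assumptions for illustration, so substitute whatever address and auth scheme your OpenShield deployment is configured with:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{"text": "How do I pick a lock?"})

	// Assumed gateway address and bearer-style API key header.
	req, _ := http.NewRequest("POST", "http://localhost:8080/v1/llamaguard/analyze", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <your-openshield-api-key>")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Mirrors the AnalyzeResponse JSON shape from lib/llamaguard/handlers.go.
	var out struct {
		IsSafe     bool     `json:"is_safe"`
		Categories []string `json:"violated_categories,omitempty"`
		Analysis   string   `json:"analysis"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("safe=%v categories=%v analysis=%q\n", out.IsSafe, out.Categories, out.Analysis)
}
```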
diff --git a/services/llmguard/main.py b/services/llmguard/main.py
new file mode 100644
index 0000000..63cc12e
--- /dev/null
+++ b/services/llmguard/main.py
@@ -0,0 +1,182 @@
+import os
+import logging
+from typing import List, Optional
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, Field
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from huggingface_hub import login
+import uvicorn
+from dotenv import load_dotenv
+
+load_dotenv()
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s [%(levelname)s] %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)
+
+
+class AnalyzeRequest(BaseModel):
+    text: str
+    categories: List[str] = Field(default_factory=list)
+    excluded_categories: List[str] = Field(default_factory=list)
+
+
+class AnalyzeResponse(BaseModel):
+    response: str
+
+
+class LlamaGuard:
+    def __init__(self):
+        self.token = os.getenv("HUGGINGFACE_API_KEY")
+        if not self.token:
+            raise ValueError("HuggingFace API token not set")
+
+        try:
+            # Read-only access is enough to download the gated model.
+            login(token=self.token)
+        except Exception as e:
+            logger.error(f"Authentication error: {e}")
+            raise
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Using device: {self.device}")
+
+        try:
+            logger.info("Loading Llama Guard model...")
+            model_id = "meta-llama/Llama-Guard-3-1B"
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                device_map="auto",
+                token=self.token
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                token=self.token
+            )
+            logger.info("Model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Error loading model: {e}")
+            raise
+
+    def clean_analysis_output(self, text: str) -> str:
+        # Strip special tokens and collapse all whitespace into single spaces.
+        text = text.replace("<|eot_id|>", "").replace("<|endoftext|>", "")
+        text = " ".join(text.split())
+
+        # Normalize "unsafe S1,S4" into "unsafe,S1,S4" for easier parsing downstream.
+        if text.startswith("unsafe "):
+            text = text.replace("unsafe ", "unsafe,", 1)
+        return text.strip()
+
+    def analyze_content(
+        self,
+        text: str,
+        categories: Optional[List[str]] = None,
+        excluded_categories: Optional[List[str]] = None
+    ) -> str:
+        try:
+            logger.info(f"Analyzing text: '{text[:100]}{'...' if len(text) > 100 else ''}'")
+
+            conversation = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": text
+                        }
+                    ]
+                }
+            ]
+
+            kwargs = {"return_tensors": "pt"}
+
+            if categories:
+                kwargs["categories"] = {cat: cat for cat in categories}
+
+            if excluded_categories:
+                kwargs["excluded_category_keys"] = excluded_categories
+
+            input_ids = self.tokenizer.apply_chat_template(
+                conversation,
+                **kwargs
+            ).to(self.device)
+
+            with torch.inference_mode():
+                prompt_len = input_ids.shape[-1]
+                output = self.model.generate(
+                    input_ids,
+                    max_new_tokens=256,
+                    pad_token_id=0,
+                )
+
+            # Decode only the newly generated tokens, skipping the prompt.
+            analysis = self.tokenizer.decode(
+                output[0][prompt_len:],
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True
+            )
+            clean_analysis = self.clean_analysis_output(analysis)
+
+            logger.info(f"Analysis completed. Result: {clean_analysis}")
+            return clean_analysis
+
+        except Exception as e:
+            logger.error(f"Error during analysis: {e}")
+            raise
+
+
+app = FastAPI(
+    title="LlamaGuard",
+    description="Content safety analysis service using Meta's Llama Guard"
+)
+
+llama_guard: Optional[LlamaGuard] = None
+
+
+@app.on_event("startup")
+async def startup_event():
+    global llama_guard
+    try:
+        llama_guard = LlamaGuard()
+    except Exception as e:
+        logger.error(f"Failed to initialize LlamaGuard: {e}")
+        raise
+
+
+@app.post("/analyze", response_model=AnalyzeResponse)
+async def analyze_content(request: AnalyzeRequest):
+    if not llama_guard:
+        raise HTTPException(status_code=503, detail="LlamaGuard not initialized")
+
+    try:
+        response = llama_guard.analyze_content(
+            request.text,
+            request.categories,
+            request.excluded_categories
+        )
+        return AnalyzeResponse(response=response)
+
+    except Exception as e:
+        logger.error(f"Error processing request: {e}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error analyzing content: {str(e)}"
+        )
+
+
+@app.get("/health")
+async def health_check():
+    if not llama_guard:
+        raise HTTPException(status_code=503, detail="Service not ready")
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/services/llmguard/pyproject.toml b/services/llmguard/pyproject.toml
new file mode 100644
index 0000000..1763d27
--- /dev/null
+++ b/services/llmguard/pyproject.toml
@@ -0,0 +1,20 @@
+[tool.poetry]
+name = "llamaguard-service"
+version = "0.1.0"
+description = "Content safety analysis service using Llama Guard"
+authors = ["Openshield"]
+
+[tool.poetry.dependencies]
+python = "^3.9"
+fastapi = "^0.109.0"
+uvicorn = "^0.27.0"
+torch = "^2.1.0"
+transformers = "^4.43.2"
+pydantic = "^2.5.0"
+huggingface-hub = "^0.23.2"
+python-dotenv = "^1.0.0"
+accelerate = "^1.1.0"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
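Finally, the Go handler only needs the Python service's base URL from configuration. A hypothetical config excerpt, with key names taken from the mapstructure tags in lib/config.go above (the base_url key is an assumption introduced in this PR, not an existing OpenShield setting):

```yaml
services:
  llamaguard:
    base_url: "http://localhost:8000"  # where the FastAPI service above listens
```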