From 361625d1736b46d2656cda8a1cce0cb7a701b7be Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Wed, 27 Nov 2024 16:41:33 +0100 Subject: [PATCH 1/4] move prompt and llamagaurd to ruleservice --- lib/rules/input.go | 101 ++++++------ services/rule/poetry.lock | 63 ++++++- services/rule/pyproject.toml | 2 + services/rule/src/plugins/llama_guard.py | 191 ++++++++++++++++++++++ services/rule/src/plugins/prompt_guard.py | 157 ++++++++++++++++++ 5 files changed, 466 insertions(+), 48 deletions(-) create mode 100644 services/rule/src/plugins/llama_guard.py create mode 100644 services/rule/src/plugins/prompt_guard.py diff --git a/lib/rules/input.go b/lib/rules/input.go index cd699f6..f4b2734 100644 --- a/lib/rules/input.go +++ b/lib/rules/input.go @@ -4,14 +4,12 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/openshieldai/openshield/lib/types" "io" "log" "net/http" "sort" "strings" - "sync" - - "github.com/openshieldai/openshield/lib/types" "github.com/openshieldai/go-openai" "github.com/openshieldai/openshield/lib" @@ -23,6 +21,18 @@ type InputTypes struct { PIIFilter string InvisibleChars string Moderation string + LlamaGuard string + PromptGuard string +} + +var inputTypes = InputTypes{ + LanguageDetection: "language_detection", + PromptInjection: "prompt_injection", + PIIFilter: "pii_filter", + InvisibleChars: "invisible_chars", + Moderation: "moderation", + LlamaGuard: "llama_guard", + PromptGuard: "prompt_guard", } type Rule struct { @@ -31,9 +41,10 @@ type Rule struct { } type RuleInspection struct { - CheckResult bool `json:"check_result"` - Score float64 `json:"score"` - AnonymizedContent string `json:"anonymized_content"` + CheckResult bool `json:"check_result"` + Score float64 `json:"score"` + AnonymizedContent string `json:"anonymized_content"` + Details map[string]interface{} `json:"details"` } type RuleResult struct { @@ -46,14 +57,6 @@ type LanguageScore struct { Score float64 `json:"score"` } -var inputTypes = InputTypes{ - LanguageDetection: "language_detection", - PromptInjection: "prompt_injection", - PIIFilter: "pii_filter", - InvisibleChars: "invisible_chars", - Moderation: "moderation", -} - func sendRequest(data Rule) (RuleResult, error) { jsonify, err := json.Marshal(data) if err != nil { @@ -105,7 +108,33 @@ func genericHandler(inputConfig lib.Rule, rule RuleResult) (bool, string, error) log.Println("Invalid Characters Rule Not Matched") return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil } +func handleLlamaGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) { + log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score) + if rule.Match { + if inputConfig.Action.Type == "block" { + log.Println("Blocking request due to LlamaGuard detection.") + return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil + } + log.Println("Monitoring request due to LlamaGuard detection.") + return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil + } + log.Println("LlamaGuard Rule Not Matched") + return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil +} +func handlePromptGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) { + log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score) + if rule.Match { + if inputConfig.Action.Type == "block" { + log.Println("Blocking request due to PromptGuard detection.") + return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil + } + log.Println("Monitoring request due to PromptGuard detection.") + return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil + } + log.Println("PromptGuard Rule Not Matched") + return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil +} func handlePIIFilterAction(inputConfig lib.Rule, rule RuleResult, messages interface{}, userMessageIndex int) (bool, string, error) { if rule.Inspection.CheckResult { log.Println("PII detected, anonymizing content") @@ -133,9 +162,9 @@ func handlePIIFilterAction(inputConfig lib.Rule, rule RuleResult, messages inter func Input(_ *http.Request, request interface{}) (bool, string, error) { config := lib.GetConfig() - log.Println("Starting Input function") + // Sort rules by order number sort.Slice(config.Rules.Input, func(i, j int) bool { return config.Rules.Input[i].OrderNumber < config.Rules.Input[j].OrderNumber }) @@ -158,47 +187,21 @@ func Input(_ *http.Request, request interface{}) (bool, string, error) { return true, "Invalid request type", fmt.Errorf("unsupported request type") } - var ( - wg sync.WaitGroup - mu sync.Mutex - blocked bool - message string - firstErr error - ) - + // Process rules sequentially instead of in parallel for _, inputConfig := range config.Rules.Input { - log.Printf("Processing input rule: %s (Order: %d)", inputConfig.Type, inputConfig.OrderNumber) - if !inputConfig.Enabled { log.Printf("Rule %s is disabled, skipping", inputConfig.Type) continue } + log.Printf("Processing input rule: %s (Order: %d)", inputConfig.Type, inputConfig.OrderNumber) blocked, message, err := handleRule(inputConfig, messages, model, maxTokens, inputConfig.Type) - + if err != nil { + return false, "", err + } if blocked { - return blocked, message, err + return true, message, nil } - wg.Add(1) - go func(ic lib.Rule) { - defer wg.Done() - blk, msg, err := handleRule(ic, messages, model, maxTokens, ic.Type) - if blk { - mu.Lock() - if !blocked { // Capture the first block - blocked = true - message = msg - firstErr = err - } - mu.Unlock() - } - }(inputConfig) - } - - wg.Wait() - - if blocked { - return blocked, message, firstErr } log.Println("Final result: No rules matched, request is not blocked") @@ -274,6 +277,10 @@ func handleRuleAction(inputConfig lib.Rule, rule RuleResult, ruleType string, me return genericHandler(inputConfig, rule) case inputTypes.Moderation: return genericHandler(inputConfig, rule) + case inputTypes.LlamaGuard: + return handleLlamaGuardAction(inputConfig, rule) + case inputTypes.PromptGuard: + return handlePromptGuardAction(inputConfig, rule) default: log.Printf("%s Rule Not Matched", ruleType) return false, "", nil diff --git a/services/rule/poetry.lock b/services/rule/poetry.lock index 529b142..2e7a899 100644 --- a/services/rule/poetry.lock +++ b/services/rule/poetry.lock @@ -1,5 +1,36 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "accelerate" +version = "1.1.1" +description = "Accelerate" +optional = false +python-versions = ">=3.9.0" +files = [ + {file = "accelerate-1.1.1-py3-none-any.whl", hash = "sha256:61edd81762131b8d4bede008643fa1e1f3bf59bec710ebda9771443e24feae02"}, + {file = "accelerate-1.1.1.tar.gz", hash = "sha256:0d39dfac557052bc735eb2703a0e87742879e1e40b88af8a2f9a93233d4cd7db"}, +] + +[package.dependencies] +huggingface-hub = ">=0.21.0" +numpy = ">=1.17,<3.0.0" +packaging = ">=20.0" +psutil = "*" +pyyaml = "*" +safetensors = ">=0.4.3" +torch = ">=1.10.0" + +[package.extras] +deepspeed = ["deepspeed"] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "diffusers", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.6.4,<0.7.0)", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.6.4,<0.7.0)"] +rich = ["rich"] +sagemaker = ["sagemaker"] +test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] +testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1270,6 +1301,36 @@ pycryptodome = ">=3.10.1" [package.extras] server = ["flask (>=1.1)"] +[[package]] +name = "psutil" +version = "6.1.0" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = "sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, +] + +[package.extras] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] + [[package]] name = "pycryptodome" version = "3.21.0" @@ -2659,4 +2720,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "bc549196d825bbf5f10c4cbc349655c25d38598ebaaf8c8aecccc271942e6469" +content-hash = "23dfe18fd5b360f660cef840e53df72a97df81a8a88b8bb8e0a89d303856e550" diff --git a/services/rule/pyproject.toml b/services/rule/pyproject.toml index b17d859..3c27a03 100644 --- a/services/rule/pyproject.toml +++ b/services/rule/pyproject.toml @@ -21,6 +21,8 @@ numpy = "^1.21.0" spacy = "^3.7.5" thinc = "^8.0.10" openai = "^1.51.2" +accelerate = ">=0.26.0" + [build-system] requires = ["poetry-core"] diff --git a/services/rule/src/plugins/llama_guard.py b/services/rule/src/plugins/llama_guard.py new file mode 100644 index 0000000..aef018e --- /dev/null +++ b/services/rule/src/plugins/llama_guard.py @@ -0,0 +1,191 @@ +""" +LlamaGuard plugin for content safety analysis using Meta's Llama Guard 3-1B model. +""" + +import os +import logging +from typing import Dict, Any, List, Optional +import torch +import accelerate +from transformers import AutoModelForCausalLM, AutoTokenizer +from huggingface_hub import login, HfApi + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + + +def get_huggingface_token(): + token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") + if not token: + logger.error("HUGGINGFACE_TOKEN or HUGGINGFACE_API_KEY environment variable not set") + return None + return token + + +class LlamaGuardAnalyzer: + def __init__(self): + self.token = get_huggingface_token() + if not self.token: + raise ValueError("HuggingFace token not set") + + try: + login(token=self.token, write_permission=True) + self.api = HfApi() + logger.info("Successfully logged into HuggingFace") + except Exception as e: + logger.error(f"HuggingFace authentication error: {str(e)}") + raise + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {self.device}") + + try: + logger.info("Loading Llama Guard model...") + model_id = "meta-llama/Llama-Guard-3-1B" + + if torch.cuda.is_available(): + logger.info("Using GPU for model loading") + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + device_map="auto", + token=self.token + ) + else: + logger.info("Using CPU for model loading") + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + token=self.token + ) + self.model.to(self.device) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, + token=self.token + ) + logger.info("LlamaGuard model loaded successfully") + + except Exception as e: + logger.error(f"Error loading LlamaGuard model: {str(e)}") + raise + + def clean_analysis_output(self, text: str) -> str: + text = text.replace("<|eot_id|>", "").replace("<|endoftext|>", "") + text = text.replace("\n", " ") + text = " ".join(text.split()) + text = text.strip() + if text.startswith("unsafe"): + text = text.replace("unsafe ", "unsafe,") + return text.strip() + + def analyze_content( + self, + text: str, + categories: Optional[List[str]] = None, + excluded_categories: Optional[List[str]] = None + ) -> str: + try: + logger.info(f"Analyzing text: '{text[:100]}{'...' if len(text) > 100 else ''}'") + + conversation = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": text + } + ] + } + ] + + kwargs = {"return_tensors": "pt"} + if categories: + cats_dict = {cat: cat for cat in categories} + kwargs["categories"] = cats_dict + if excluded_categories: + kwargs["excluded_category_keys"] = excluded_categories + + input_ids = self.tokenizer.apply_chat_template( + conversation, + **kwargs + ).to(self.device) + + with torch.inference_mode(): + prompt_len = input_ids.shape[-1] + output = self.model.generate( + input_ids, + max_new_tokens=256, + pad_token_id=0, + ) + + analysis = self.tokenizer.decode( + output[0][prompt_len:], + skip_special_tokens=True, + clean_up_tokenization_spaces=True + ) + clean_analysis = self.clean_analysis_output(analysis) + logger.info(f"Analysis completed. Result: {clean_analysis}") + return clean_analysis + + except Exception as e: + logger.error(f"Error during analysis: {e}") + raise + + +analyzer = None +try: + logger.info("Initializing LlamaGuard analyzer...") + analyzer = LlamaGuardAnalyzer() + logger.info("LlamaGuard analyzer initialized successfully") +except Exception as e: + logger.error(f"Failed to initialize LlamaGuard analyzer: {str(e)}") + + +def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, Any]: + try: + if analyzer is None: + raise RuntimeError("LlamaGuard analyzer not initialized") + + categories = config.get('categories', []) + excluded_categories = config.get('excluded_categories', []) + + analysis = analyzer.analyze_content( + text, + categories=categories, + excluded_categories=excluded_categories + ) + + is_unsafe = not analysis.lower().startswith('safe') + score = 1.0 if is_unsafe else 0.0 + + violated_categories = [] + if is_unsafe: + for category in ["S1", "S2", "S3", "S4", "S5", "S6", "S7", + "S8", "S9", "S10", "S11", "S12", "S13"]: + if category in analysis: + violated_categories.append(category) + + return { + "check_result": score > threshold, + "score": score, + "details": { + "is_safe": not is_unsafe, + "violated_categories": violated_categories, + "raw_analysis": analysis + } + } + + except Exception as e: + logger.error(f"Error in LlamaGuard analysis: {str(e)}") + return { + "check_result": False, + "score": 0.0, + "details": {"error": str(e)} + } diff --git a/services/rule/src/plugins/prompt_guard.py b/services/rule/src/plugins/prompt_guard.py new file mode 100644 index 0000000..c36a70d --- /dev/null +++ b/services/rule/src/plugins/prompt_guard.py @@ -0,0 +1,157 @@ +""" +PromptGuard plugin for detecting prompt injection attacks using the PromptGuard model. +""" + +import os +import logging +from typing import Dict, Any +import torch +from transformers import AutoTokenizer, AutoModelForSequenceClassification +from huggingface_hub import login, HfApi + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def get_huggingface_token(): + """Get token from environment with proper error handling.""" + token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") + if not token: + logger.error("HUGGINGFACE_TOKEN or HUGGINGFACE_API_KEY environment variable not set") + return None + return token + +class PromptGuardAnalyzer: + def __init__(self): + self.token = get_huggingface_token() + if not self.token: + raise ValueError("HuggingFace token not set") + + try: + login(token=self.token, write_permission=True) + self.api = HfApi() + logger.info("Successfully logged into HuggingFace") + except Exception as e: + logger.error(f"HuggingFace authentication error: {str(e)}") + raise + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {self.device}") + + try: + logger.info("Loading PromptGuard model...") + self.model = AutoModelForSequenceClassification.from_pretrained( + "meta-llama/Prompt-Guard-86M", + use_auth_token=self.token, + trust_remote_code=True + ) + self.tokenizer = AutoTokenizer.from_pretrained( + "meta-llama/Prompt-Guard-86M", + use_auth_token=self.token, + trust_remote_code=True + ) + self.model.to(self.device) + self.model.eval() + logger.info("PromptGuard model loaded successfully") + except Exception as e: + logger.error(f"Error loading PromptGuard model: {str(e)}") + raise + + def get_class_probabilities(self, text: str, temperature: float = 3.0) -> torch.Tensor: + inputs = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ).to(self.device) + + with torch.no_grad(): + logits = self.model(**inputs).logits + scaled_logits = logits / temperature + probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1) + + return probabilities[0] + + def analyze_text(self, text: str, temperature: float = 3.0) -> Dict[str, Any]: + try: + probabilities = self.get_class_probabilities(text, temperature) + + scores = { + "benign_probability": probabilities[0].item(), + "injection_probability": probabilities[1].item(), + "jailbreak_probability": probabilities[2].item() + } + + if scores["jailbreak_probability"] > scores["injection_probability"]: + risk_score = scores["jailbreak_probability"] + classification = "jailbreak" + else: + risk_score = scores["injection_probability"] + classification = "injection" + + logger.info(f"\nAnalyzing text: {text[:100]}...") + logger.info(f"Probabilities: {scores}") + logger.info(f"Classification: {classification}") + + return { + "score": risk_score, + "details": scores, + "classification": classification + } + + except Exception as e: + logger.error(f"Error during analysis: {e}") + raise + +# Initialize global analyzer +analyzer = None +try: + logger.info("Initializing PromptGuard analyzer...") + analyzer = PromptGuardAnalyzer() + logger.info("PromptGuard analyzer initialized successfully") +except Exception as e: + logger.error(f"Failed to initialize PromptGuard analyzer: {str(e)}") + +def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, Any]: + """ + Analyzes text for prompt injection using PromptGuard. + + Args: + text: The text to analyze + threshold: Score threshold for detecting injection + config: Configuration parameters including temperature + + Returns: + dict containing: + - check_result: bool if score exceeds threshold + - score: risk score from 0-1 + - details: probabilities and classification details + """ + try: + if analyzer is None: + raise RuntimeError("PromptGuard analyzer not initialized") + + temperature = config.get('temperature', 3.0) + results = analyzer.analyze_text(text, temperature) + + return { + "check_result": results["score"] > threshold, + "score": results["score"], + "details": { + "benign_probability": results["details"]["benign_probability"], + "injection_probability": results["details"]["injection_probability"], + "jailbreak_probability": results["details"]["jailbreak_probability"], + "classification": results["classification"] + } + } + + except Exception as e: + logger.error(f"Error in PromptGuard analysis: {str(e)}") + return { + "check_result": False, + "score": 0.0, + "details": {"error": str(e)} + } \ No newline at end of file From 5a71bf5cce377f6604eebe0332ee127cd2ba2382 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Wed, 27 Nov 2024 16:42:36 +0100 Subject: [PATCH 2/4] cleanup --- services/llmguard/main.py | 190 ---------------------------- services/llmguard/pyproject.toml | 20 --- services/promptguard/main.py | 170 ------------------------- services/promptguard/pyproject.toml | 31 ----- 4 files changed, 411 deletions(-) delete mode 100644 services/llmguard/main.py delete mode 100644 services/llmguard/pyproject.toml delete mode 100644 services/promptguard/main.py delete mode 100644 services/promptguard/pyproject.toml diff --git a/services/llmguard/main.py b/services/llmguard/main.py deleted file mode 100644 index 63cc12e..0000000 --- a/services/llmguard/main.py +++ /dev/null @@ -1,190 +0,0 @@ -# llamaguard_service.py -import os -import logging -from typing import Dict, List, Optional -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from huggingface_hub import login, HfApi -import uvicorn -from dotenv import load_dotenv - -load_dotenv() - - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' -) -logger = logging.getLogger(__name__) - - -class AnalyzeRequest(BaseModel): - text: str - categories: List[str] = Field(default_factory=list) - excluded_categories: List[str] = Field(default_factory=list) - - -class AnalyzeResponse(BaseModel): - response: str - - -class LlamaGuard: - def __init__(self): - self.token = os.getenv("HUGGINGFACE_API_KEY") - if not self.token: - raise ValueError("HuggingFace API token not set") - - try: - login(token=self.token, write_permission=True) - self.api = HfApi() - except Exception as e: - logger.error(f"Authentication error: {str(e)}") - raise - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Using device: {self.device}") - - try: - logger.info("Loading Llama Guard model...") - model_id = "meta-llama/Llama-Guard-3-1B" - - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto", - token=self.token - ) - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, - token=self.token - ) - logger.info("Model loaded successfully") - - except Exception as e: - logger.error(f"Error loading model: {e}") - raise - - def clean_analysis_output(self, text: str) -> str: - - text = text.replace("<|eot_id|>", "").replace("<|endoftext|>", "") - - text = text.replace("\n", " ") - - text = " ".join(text.split()) - text = text.strip() - - if text.startswith("unsafe"): - text = text.replace("unsafe ", "unsafe,") - return text.strip() - - def analyze_content( - self, - text: str, - categories: Optional[List[str]] = None, - excluded_categories: Optional[List[str]] = None - ) -> str: - try: - logger.info(f"Analyzing text: '{text[:100]}{'...' if len(text) > 100 else ''}'") - - conversation = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": text - } - ] - } - ] - - kwargs = {"return_tensors": "pt"} - - if categories: - cats_dict = {cat: cat for cat in categories} - kwargs["categories"] = cats_dict - - if excluded_categories: - kwargs["excluded_category_keys"] = excluded_categories - - input_ids = self.tokenizer.apply_chat_template( - conversation, - **kwargs - ).to(self.device) - - with torch.inference_mode(): - prompt_len = input_ids.shape[-1] - output = self.model.generate( - input_ids, - max_new_tokens=256, - pad_token_id=0, - ) - - - analysis = self.tokenizer.decode( - output[0][prompt_len:], - skip_special_tokens=True, - clean_up_tokenization_spaces=True - ) - clean_analysis = self.clean_analysis_output(analysis) - - logger.info(f"Analysis completed. Result: {clean_analysis}") - return clean_analysis - - except Exception as e: - logger.error(f"Error during analysis: {e}") - raise - - -app = FastAPI( - title="LlamaGuard ", - description="Meta's Llama Guard model" -) - -llama_guard: Optional[LlamaGuard] = None - - -@app.on_event("startup") -async def startup_event(): - global llama_guard - try: - llama_guard = LlamaGuard() - except Exception as e: - logger.error(f"Failed to initialize LlamaGuard: {e}") - raise - - -@app.post("/analyze", response_model=AnalyzeResponse) -async def analyze_content(request: AnalyzeRequest): - try: - if not llama_guard: - raise HTTPException(status_code=500, detail="LlamaGuard not initialized") - - response = llama_guard.analyze_content( - request.text, - request.categories, - request.excluded_categories - ) - - return AnalyzeResponse(response=response) - - except Exception as e: - logger.error(f"Error processing request: {e}") - raise HTTPException( - status_code=500, - detail=f"Error analyzing content: {str(e)}" - ) - - -@app.get("/health") -async def health_check(): - if not llama_guard: - raise HTTPException(status_code=503, detail="Service not ready") - return {"status": "healthy"} - - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/services/llmguard/pyproject.toml b/services/llmguard/pyproject.toml deleted file mode 100644 index 1763d27..0000000 --- a/services/llmguard/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[tool.poetry] -name = "llamaguard-service" -version = "0.1.0" -description = "Content safety analysis service using Llama Guard" -authors = ["Openshield"] - -[tool.poetry.dependencies] -python = "^3.9" -fastapi = "^0.109.0" -uvicorn = "^0.27.0" -torch = "^2.1.0" -transformers = "^4.43.2" -pydantic = "^2.5.0" -huggingface-hub = "^0.23.2" -python-dotenv = "^1.0.0" -accelerate = "^1.1.0" - -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" diff --git a/services/promptguard/main.py b/services/promptguard/main.py deleted file mode 100644 index 346236e..0000000 --- a/services/promptguard/main.py +++ /dev/null @@ -1,170 +0,0 @@ -import os - -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -import torch -from transformers import AutoTokenizer, AutoModelForSequenceClassification -import logging -from typing import Dict, Optional -from huggingface_hub import login, HfApi - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -app = FastAPI(title="PromptGuard Service") - - -class AnalyzeRequest(BaseModel): - text: str - threshold: float = 0.5 - temperature: float = 3.0 - - -class AnalyzeResponse(BaseModel): - score: float - details: Dict[str, float] - classification: str - - -class PromptGuard: - def __init__(self): - self.token = os.getenv("HUGGINGFACE_API_KEY") #NEED TO REQUEST ACCESS FOR THE MODEL! - if not self.token: - raise ValueError("Token not set") - - try: - login(token=self.token, write_permission=True) - api = HfApi() - except Exception as e: - logger.error(f"Authentication error: {str(e)}") - raise - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Using device: {self.device}") - - try: - logger.info("Loading model") - self.model = AutoModelForSequenceClassification.from_pretrained( - "meta-llama/Prompt-Guard-86M", - use_auth_token=self.token, - trust_remote_code=True - ) - self.tokenizer = AutoTokenizer.from_pretrained( - "meta-llama/Prompt-Guard-86M", - use_auth_token=self.token, - trust_remote_code=True - ) - self.model.to(self.device) - self.model.eval() - logger.info("Model loaded successfully") - except Exception as e: - logger.error(f"Error loading model: {e}") - raise - - def get_class_probabilities(self, text: str, temperature: float = 3.0) -> torch.Tensor: - inputs = self.tokenizer( - text, - return_tensors="pt", - padding=True, - truncation=True, - max_length=512 - ).to(self.device) - - with torch.no_grad(): - logits = self.model(**inputs).logits - - scaled_logits = logits / temperature - - probabilities = torch.nn.functional.softmax(scaled_logits, dim=-1) - - return probabilities[0] - - def get_indirect_injection_score(self, text: str, temperature: float = 3.0) -> float: - - probabilities = self.get_class_probabilities(text, temperature) - return (probabilities[1] + probabilities[2]).item() - - def analyze_text(self, text: str, temperature: float = 3.0) -> Dict[str, any]: - try: - probabilities = self.get_class_probabilities(text, temperature) - - - scores = { - "benign_probability": probabilities[0].item(), - "injection_probability": probabilities[1].item(), - "jailbreak_probability": probabilities[2].item() - } - - - if scores["jailbreak_probability"] > scores["injection_probability"]: - risk_score = scores["jailbreak_probability"] - classification = "jailbreak" - else: - risk_score = scores["injection_probability"] - classification = "injection" - - - logger.info(f"\nAnalyzing text: {text[:100]}...") - logger.info(f"Probabilities: {scores}") - logger.info(f"Classification: {classification}") - - return { - "score": risk_score, - "details": scores, - "classification": classification - } - - except Exception as e: - logger.error(f"Error during analysis: {e}") - raise - - -prompt_guard: Optional[PromptGuard] = None - - -@app.on_event("startup") -async def startup_event(): - global prompt_guard - try: - prompt_guard = PromptGuard() - except Exception as e: - logger.error(f"Failed to initialize PromptGuard: {e}") - raise - - -@app.post("/analyze", response_model=AnalyzeResponse) -async def analyze_prompt(request: AnalyzeRequest): - try: - if not prompt_guard: - raise HTTPException(status_code=500, detail="PromptGuard not initialized") - - results = prompt_guard.analyze_text(request.text, request.temperature) - - return AnalyzeResponse( - score=results["score"], - details=results["details"], - classification=results["classification"] - ) - - except Exception as e: - logger.error(f"Error processing request: {e}") - raise HTTPException( - status_code=500, - detail=f"Error analyzing prompt: {str(e)}" - ) - - -@app.get("/health") -async def health_check(): - if not prompt_guard: - raise HTTPException(status_code=503, detail="Service not ready") - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/services/promptguard/pyproject.toml b/services/promptguard/pyproject.toml deleted file mode 100644 index dafc1bd..0000000 --- a/services/promptguard/pyproject.toml +++ /dev/null @@ -1,31 +0,0 @@ -[tool.poetry] -name = "promptguard-service" -version = "0.1.0" -description = "PromptGuard analysis service for OpenShield" -authors = ["Your Name "] - -[tool.poetry.dependencies] -python = "^3.9" -fastapi = "^0.104.1" -uvicorn = "^0.24.0" -transformers = "^4.35.2" -torch = "^2.1.1" -python-multipart = "^0.0.6" -pydantic = "^2.5.1" - -[tool.poetry.dev-dependencies] -pytest = "^7.4.3" -black = "^24.0.0" -isort = "^5.12.0" - -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" - -[tool.black] -line-length = 88 -target-version = ['py39'] - -[tool.isort] -profile = "black" -multi_line_output = 3 \ No newline at end of file From 63fad1cefec4b0c4bac608a23b15675555d50f16 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Thu, 28 Nov 2024 10:38:22 +0100 Subject: [PATCH 3/4] cleanup --- lib/llamaguard/handlers.go | 135 ------------------------------------ lib/promptguard/handlers.go | 103 --------------------------- server/server.go | 4 -- 3 files changed, 242 deletions(-) delete mode 100644 lib/llamaguard/handlers.go delete mode 100644 lib/promptguard/handlers.go diff --git a/lib/llamaguard/handlers.go b/lib/llamaguard/handlers.go deleted file mode 100644 index 7786bd9..0000000 --- a/lib/llamaguard/handlers.go +++ /dev/null @@ -1,135 +0,0 @@ -package llamaguard - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "github.com/google/uuid" - "github.com/openshieldai/openshield/lib/provider" - "net/http" - - "github.com/go-chi/chi/v5" - "github.com/openshieldai/openshield/lib" -) - -type AnalyzeRequest struct { - Text string `json:"text"` - Categories []string `json:"categories,omitempty"` - ExcludedCategories []string `json:"excluded_categories,omitempty"` -} - -type LlamaGuardResponse struct { - Response string `json:"response"` -} - -type AnalyzeResponse struct { - IsSafe bool `json:"is_safe"` - Categories []string `json:"violated_categories,omitempty"` - Analysis string `json:"analysis"` -} - -func SetupRoutes(r chi.Router) { - r.Post("/llamaguard/analyze", lib.AuthOpenShieldMiddleware(AnalyzeHandler)) -} - -func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { - var req AnalyzeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - lib.ErrorResponse(w, fmt.Errorf("error reading request body: %v", err)) - return - } - - if req.Text == "" { - lib.ErrorResponse(w, fmt.Errorf("text field is required")) - return - } - - provider.LogProviderInput(r, "llamaguard", req.Text) - - resp, err := callLlamaGuardService(r.Context(), req) - if err != nil { - lib.ErrorResponse(w, fmt.Errorf("error calling LlamaGuard service: %v", err)) - return - } - - respBytes, _ := json.Marshal(resp) - provider.LogProviderOutput(r, "llamaguard", respBytes) - - json.NewEncoder(w).Encode(resp) -} - -func callLlamaGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeResponse, error) { - config := lib.GetConfig() - llamaGuardURL := config.Services.LlamaGuard.BaseUrl + "/analyze" - - reqBody, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, "POST", llamaGuardURL, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - httpReq.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("service returned status %d", resp.StatusCode) - } - - var llamaGuardResp LlamaGuardResponse - if err := json.NewDecoder(resp.Body).Decode(&llamaGuardResp); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) - } - - return parseLlamaGuardResponse(llamaGuardResp.Response), nil -} - -func parseLlamaGuardResponse(response string) *AnalyzeResponse { - - result := &AnalyzeResponse{ - Analysis: response, - IsSafe: response == "safe", - } - - if !result.IsSafe { - - for _, category := range []string{"S1", "S2", "S3", "S4", "S5", "S6", "S7", - "S8", "S9", "S10", "S11", "S12", "S13"} { - if bytes.Contains([]byte(response), []byte(category)) { - result.Categories = append(result.Categories, category) - } - } - } - - return result -} - -func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) { - provider.LogProviderInput(r, "llamaguard", body) -} - -func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { - var productIDStr string - err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error - if err != nil { - return uuid.Nil, err - } - - productID, err := uuid.Parse(productIDStr) - if err != nil { - return uuid.Nil, errors.New("failed to parse product_id as UUID") - } - - return productID, nil -} diff --git a/lib/promptguard/handlers.go b/lib/promptguard/handlers.go deleted file mode 100644 index c0e1f51..0000000 --- a/lib/promptguard/handlers.go +++ /dev/null @@ -1,103 +0,0 @@ -package promptguard - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "github.com/google/uuid" - "net/http" - - "github.com/go-chi/chi/v5" - "github.com/openshieldai/openshield/lib" -) - -type AnalyzeRequest struct { - Text string `json:"text"` - Threshold float64 `json:"threshold"` -} - -type AnalyzeResponse struct { - Score float64 `json:"score"` - Details struct { - BenignProbability float64 `json:"benign_probability"` - InjectionProbability float64 `json:"injection_probability"` - JailbreakProbability float64 `json:"jailbreak_probability"` - } `json:"details"` -} - -func SetupRoutes(r chi.Router) { - r.Post("/promptguard/analyze", lib.AuthOpenShieldMiddleware(AnalyzeHandler)) -} - -func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { - var req AnalyzeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - lib.ErrorResponse(w, fmt.Errorf("error reading request body: %v", err)) - return - } - - if req.Text == "" { - lib.ErrorResponse(w, fmt.Errorf("text field is required")) - return - } - - resp, err := callPromptGuardService(r.Context(), req) - if err != nil { - lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err)) - return - } - - json.NewEncoder(w).Encode(resp) -} - -func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeResponse, error) { - config := lib.GetConfig() - promptGuardURL := config.Services.PromptGuard.BaseUrl + "/analyze" - - reqBody, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, "POST", promptGuardURL, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - httpReq.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("service returned status %d", resp.StatusCode) - } - - var result AnalyzeResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) - } - - return &result, nil -} - -func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { - var productIDStr string - err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error - if err != nil { - return uuid.Nil, err - } - - productID, err := uuid.Parse(productIDStr) - if err != nil { - return uuid.Nil, errors.New("failed to parse product_id as UUID") - } - - return productID, nil -} diff --git a/server/server.go b/server/server.go index 9a1cf23..914f23b 100644 --- a/server/server.go +++ b/server/server.go @@ -3,8 +3,6 @@ package server import ( "context" "fmt" - "github.com/openshieldai/openshield/lib/llamaguard" - "github.com/openshieldai/openshield/lib/promptguard" "net/http" "time" @@ -145,8 +143,6 @@ func setupOpenAIRoutes(r chi.Router) { r.Route("/huggingface/v1", func(r chi.Router) { r.Post("/chat/completions", lib.AuthOpenShieldMiddleware(huggingface.ChatCompletionHandler)) }) - r.Post("/v1/llamaguard/analyze", lib.AuthOpenShieldMiddleware(llamaguard.AnalyzeHandler)) - r.Post("/v1/promptguard/analyze", lib.AuthOpenShieldMiddleware(promptguard.AnalyzeHandler)) } var redisClient *redis.Client From c1630ebd3c0e875a29e45b97f55772d8f36961a2 Mon Sep 17 00:00:00 2001 From: krichard1212 Date: Mon, 2 Dec 2024 14:04:14 +0100 Subject: [PATCH 4/4] config --- config_example.yaml | 29 ++++++++++++ lib/config.go | 1 + lib/rules/input.go | 58 +++++++++++++++++++++--- services/rule/src/plugins/llama_guard.py | 33 +++++++------- 4 files changed, 99 insertions(+), 22 deletions(-) diff --git a/config_example.yaml b/config_example.yaml index 8333667..b973587 100644 --- a/config_example.yaml +++ b/config_example.yaml @@ -1,4 +1,8 @@ rules: + + + + input: - name: "language_detection" type: "language_detection" @@ -62,6 +66,31 @@ rules: action: type: "block" # - type: "monitoring" # logging + - name: "llamaguard_check" + type: "llama_guard" + enabled: true + order_number: 4 + config: + plugin_name: "llama_guard" + threshold: 0.5 + relation: ">" + # can be left empty, in that case every categgory is included. + #categories: ["S1","S7"] + + action: + type: "block" + + - name: "PromptGuard Injection Detection" + type: "prompt_guard" + enabled: true + order_number: 5 + config: + plugin_name: "prompt_guard" + threshold: 0.7 + relation: ">" + temperature: 3.0 + action: + type: "block" output: - name: "pii_example" type: "pii_filter" diff --git a/lib/config.go b/lib/config.go index 09fc910..2044c47 100644 --- a/lib/config.go +++ b/lib/config.go @@ -151,6 +151,7 @@ type Config struct { Url string `mapstructure:"url,omitempty"` ApiKey string `mapstructure:"api_key,omitempty"` PIIService interface{} `mapstructure:"piiservice,omitempty"` + Categories []string `mapstructure:"categories,omitempty"` } // Add this function to set default values diff --git a/lib/rules/input.go b/lib/rules/input.go index f4b2734..293d64c 100644 --- a/lib/rules/input.go +++ b/lib/rules/input.go @@ -109,15 +109,61 @@ func genericHandler(inputConfig lib.Rule, rule RuleResult) (bool, string, error) return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil } func handleLlamaGuardAction(inputConfig lib.Rule, rule RuleResult) (bool, string, error) { - log.Printf("%s detection result: Match=%v, Score=%f", inputConfig.Type, rule.Match, rule.Inspection.Score) + log.Printf("LlamaGuard detection result: Match=%v, Score=%f", rule.Match, rule.Inspection.Score) + + // Log which categories we're checking + if len(inputConfig.Config.Categories) > 0 { + log.Printf("Checking specific categories: %v", inputConfig.Config.Categories) + } else { + log.Println("Checking all default categories") + } + if rule.Match { - if inputConfig.Action.Type == "block" { - log.Println("Blocking request due to LlamaGuard detection.") - return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s"}`, inputConfig.Type), nil + details := rule.Inspection.Details + if details != nil { + if rawAnalysis, ok := details["raw_analysis"].(string); ok { + log.Printf("LlamaGuard analysis: %s", rawAnalysis) + } + + if violatedCategories, ok := details["violated_categories"].([]interface{}); ok { + categories := make([]string, len(violatedCategories)) + for i, v := range violatedCategories { + categories[i] = v.(string) + } + + relevantViolations := []string{} + configuredCategories := inputConfig.Config.Categories + + if len(configuredCategories) > 0 { + + for _, violation := range categories { + for _, configured := range configuredCategories { + if violation == configured { + relevantViolations = append(relevantViolations, violation) + break + } + } + } + } else { + + relevantViolations = categories + } + + if len(relevantViolations) > 0 { + log.Printf("Violated categories (after filtering): %v", relevantViolations) + if inputConfig.Action.Type == "block" { + log.Printf("Blocking request due to LlamaGuard detection in categories: %v", relevantViolations) + return true, fmt.Sprintf(`{"status": "blocked", "rule_type": "%s", "violated_categories": %v}`, + inputConfig.Type, relevantViolations), nil + } + log.Printf("Monitoring request due to LlamaGuard detection in categories: %v", relevantViolations) + return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s", "violated_categories": %v}`, + inputConfig.Type, relevantViolations), nil + } + } } - log.Println("Monitoring request due to LlamaGuard detection.") - return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil } + log.Println("LlamaGuard Rule Not Matched") return false, fmt.Sprintf(`{"status": "non_blocked", "rule_type": "%s"}`, inputConfig.Type), nil } diff --git a/services/rule/src/plugins/llama_guard.py b/services/rule/src/plugins/llama_guard.py index aef018e..91c5fc3 100644 --- a/services/rule/src/plugins/llama_guard.py +++ b/services/rule/src/plugins/llama_guard.py @@ -6,7 +6,6 @@ import logging from typing import Dict, Any, List, Optional import torch -import accelerate from transformers import AutoModelForCausalLM, AutoTokenizer from huggingface_hub import login, HfApi @@ -17,6 +16,8 @@ ) logger = logging.getLogger(__name__) +DEFAULT_CATEGORIES = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", + "S8", "S9", "S10", "S11", "S12", "S13"] def get_huggingface_token(): token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") @@ -25,7 +26,6 @@ def get_huggingface_token(): return None return token - class LlamaGuardAnalyzer: def __init__(self): self.token = get_huggingface_token() @@ -48,7 +48,6 @@ def __init__(self): model_id = "meta-llama/Llama-Guard-3-1B" if torch.cuda.is_available(): - logger.info("Using GPU for model loading") self.model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float16, @@ -56,7 +55,6 @@ def __init__(self): token=self.token ) else: - logger.info("Using CPU for model loading") self.model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float32, @@ -87,8 +85,7 @@ def clean_analysis_output(self, text: str) -> str: def analyze_content( self, text: str, - categories: Optional[List[str]] = None, - excluded_categories: Optional[List[str]] = None + categories: Optional[List[str]] = None ) -> str: try: logger.info(f"Analyzing text: '{text[:100]}{'...' if len(text) > 100 else ''}'") @@ -106,11 +103,17 @@ def analyze_content( ] kwargs = {"return_tensors": "pt"} + if categories: + # Convert categories to the format expected by the model cats_dict = {cat: cat for cat in categories} kwargs["categories"] = cats_dict - if excluded_categories: - kwargs["excluded_category_keys"] = excluded_categories + logger.info(f"Using specified categories: {cats_dict}") + else: + # Use all default categories if none specified + cats_dict = {cat: cat for cat in DEFAULT_CATEGORIES} + kwargs["categories"] = cats_dict + logger.info("Using all default categories") input_ids = self.tokenizer.apply_chat_template( conversation, @@ -138,7 +141,6 @@ def analyze_content( logger.error(f"Error during analysis: {e}") raise - analyzer = None try: logger.info("Initializing LlamaGuard analyzer...") @@ -147,19 +149,17 @@ def analyze_content( except Exception as e: logger.error(f"Failed to initialize LlamaGuard analyzer: {str(e)}") - def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, Any]: try: if analyzer is None: raise RuntimeError("LlamaGuard analyzer not initialized") + # Extract categories from config categories = config.get('categories', []) - excluded_categories = config.get('excluded_categories', []) analysis = analyzer.analyze_content( text, - categories=categories, - excluded_categories=excluded_categories + categories=categories if categories else None ) is_unsafe = not analysis.lower().startswith('safe') @@ -167,8 +167,9 @@ def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, An violated_categories = [] if is_unsafe: - for category in ["S1", "S2", "S3", "S4", "S5", "S6", "S7", - "S8", "S9", "S10", "S11", "S12", "S13"]: + # Look for category violations in the analysis text + check_categories = categories if categories else DEFAULT_CATEGORIES + for category in check_categories: if category in analysis: violated_categories.append(category) @@ -188,4 +189,4 @@ def handler(text: str, threshold: float, config: Dict[str, Any]) -> Dict[str, An "check_result": False, "score": 0.0, "details": {"error": str(e)} - } + } \ No newline at end of file