diff --git a/lib/openai/handlers.go b/lib/openai/handlers.go index 0233b9a..4b9d2d4 100644 --- a/lib/openai/handlers.go +++ b/lib/openai/handlers.go @@ -95,24 +95,6 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) { performAuditLogging(r, "openai_create_thread", "input", body) - filtered, message, errorMessage := rules.Input(r, req) - if errorMessage != nil { - handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest) - return - } - - logMessage, err := json.Marshal(message) - if err != nil { - handleError(w, fmt.Errorf("error marshalling message: %v", err), http.StatusBadRequest) - return - } - - if filtered { - performAuditLogging(r, "rule", "filtered", logMessage) - handleError(w, fmt.Errorf(message), http.StatusBadRequest) - return - } - c := openai.DefaultConfig(openAIAPIKey) c.BaseURL = openAIBaseURL client := openai.NewClientWithConfig(c) @@ -123,7 +105,7 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) { return } - if config.Settings.Cache.Enabled { + if config.Settings.ContextCache.Enabled { w.Header().Set(OSCacheStatusHeader, "MISS") resJson, err := json.Marshal(resp) if err != nil { @@ -138,6 +120,24 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set(OSCacheStatusHeader, "BYPASS") } + filtered, message, errorMessage := rules.Input(r, req) + if errorMessage != nil { + handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest) + return + } + + logMessage, err := json.Marshal(message) + if err != nil { + handleError(w, fmt.Errorf("error marshalling message: %v", err), http.StatusBadRequest) + return + } + + if filtered { + performAuditLogging(r, "rule", "filtered", logMessage) + handleError(w, fmt.Errorf(message), http.StatusBadRequest) + return + } + performThreadAuditLogging(r, resp) json.NewEncoder(w).Encode(resp) } @@ -242,7 +242,31 @@ func CreateMessageHandler(w http.ResponseWriter, r *http.Request) { return } - // Perform input validation and filtering + c := openai.DefaultConfig(openAIAPIKey) + c.BaseURL = openAIBaseURL + client := openai.NewClientWithConfig(c) + + resp, err := client.CreateMessage(r.Context(), threadID, req) + if err != nil { + handleError(w, fmt.Errorf("failed to create message: %v", err), http.StatusInternalServerError) + return + } + + if config.Settings.ContextCache.Enabled { + w.Header().Set(OSCacheStatusHeader, "MISS") + resJson, err := json.Marshal(resp) + if err != nil { + log.Printf("Error marshalling response to JSON: %v", err) + } else { + err = lib.SetCache(string(body), resJson) + if err != nil { + log.Printf("Error setting cache: %v", err) + } + } + } else { + w.Header().Set(OSCacheStatusHeader, "BYPASS") + } + filtered, message, errorMessage := rules.Input(r, req) if errorMessage != nil { handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest) @@ -261,16 +285,6 @@ func CreateMessageHandler(w http.ResponseWriter, r *http.Request) { return } - c := openai.DefaultConfig(openAIAPIKey) - c.BaseURL = openAIBaseURL - client := openai.NewClientWithConfig(c) - - resp, err := client.CreateMessage(r.Context(), threadID, req) - if err != nil { - handleError(w, fmt.Errorf("failed to create message: %v", err), http.StatusInternalServerError) - return - } - json.NewEncoder(w).Encode(resp) } func ListMessagesHandler(w http.ResponseWriter, r *http.Request) { diff --git a/services/cache/main.py b/services/cache/main.py index 418418e..adea41f 100644 --- a/services/cache/main.py +++ b/services/cache/main.py @@ -1,6 +1,7 @@ import argparse import logging import os +import json from typing import Optional, Dict from urllib.parse import urlparse @@ -21,8 +22,26 @@ import uvicorn from pydantic import BaseModel -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +class JSONFormatter(logging.Formatter): + def format(self, record): + log_record = { + "timestamp": self.formatTime(record, self.datefmt), + "level": record.levelname, + "message": record.getMessage(), + "name": record.name, + "filename": record.filename, + "lineno": record.lineno, + } + return json.dumps(log_record) + +# Set up JSON logging +json_formatter = JSONFormatter() +handler = logging.StreamHandler() +handler.setFormatter(json_formatter) logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) + app = FastAPI() openai_caches: Dict[str, Cache] = {} redis_url = os.getenv("REDIS_URL") @@ -54,21 +73,21 @@ def openshield_check_hit_func(cur_session_id, cache_session_ids, cache_questions async def put_cache(cache_data: CacheData) -> str: session = Session(name=cache_data.product_id, check_hit_func=openshield_check_hit_func) put(cache_data.prompt, cache_data.answer, session=session) - logger.info(f"Setting cache data: %s", cache_data.prompt) + logger.info(f"Setting cache data: {cache_data.prompt}") return "successfully update the cache" @app.post("/get") async def get_cache(cache_data: CacheData) -> CacheData: session = Session(name=cache_data.product_id, check_hit_func=openshield_check_hit_func) - logger.info(f"Getting cache data: %s", cache_data.prompt) + logger.info(f"Getting cache data: {cache_data.prompt}") result = get(cache_data.prompt, session=session) if result is None: - logger.info(f"Cache miss for prompt: %s", cache_data.prompt) + logger.info(f"Cache miss for prompt: {cache_data.prompt}") raise HTTPException(status_code=404, detail="Cache miss") else: - logger.info(f"Cache hit for prompt: %s", cache_data.prompt) + logger.info(f"Cache hit for prompt: {cache_data.prompt}") return CacheData(prompt=cache_data.prompt, answer=result, product_id=cache_data.product_id) diff --git a/services/rule/src/main.py b/services/rule/src/main.py index ae4d437..789f67f 100644 --- a/services/rule/src/main.py +++ b/services/rule/src/main.py @@ -5,9 +5,42 @@ import importlib import rule_engine import logging +import os +import json + +class JSONFormatter(logging.Formatter): + def format(self, record): + log_record = { + "timestamp": self.formatTime(record, self.datefmt), + "level": record.levelname, + "message": record.getMessage(), + "name": record.name, + "filename": record.filename, + "lineno": record.lineno, + } + return json.dumps(log_record) + +def setup_logging(): + # Get the log level from the environment variable, default to 'INFO' + log_level = os.getenv('LOG_LEVEL', 'INFO').upper() + + # Validate and set the log level + numeric_level = getattr(logging, log_level, None) + if not isinstance(numeric_level, int): + raise ValueError(f'Invalid log level: {log_level}') + + # Configure the logging + json_formatter = JSONFormatter() + handler = logging.StreamHandler() + handler.setFormatter(json_formatter) + logger = logging.getLogger(__name__) + logger.setLevel(numeric_level) + logger.addHandler(handler) # Configure logging -logging.basicConfig(level=logging.DEBUG) +setup_logging() + +# Example usage of logging logger = logging.getLogger(__name__) # Define plugin_name at the module level @@ -17,11 +50,9 @@ class Message(BaseModel): role: str content: str - class Thread(BaseModel): messages: List[Message] - class Prompt(BaseModel): model: Optional[str] = None assistant_id: Optional[str] = None @@ -30,7 +61,6 @@ class Prompt(BaseModel): role: Optional[str] = None content: Optional[str] = None - class Config(BaseModel): PluginName: str Threshold: float @@ -40,13 +70,12 @@ class Config(BaseModel): class Config: extra = "allow" - class Rule(BaseModel): prompt: Prompt config: Config - app = FastAPI() + @app.middleware("http") async def log_request(request: Request, call_next): logger.debug(f"Incoming request: {request.method} {request.url}") @@ -55,6 +84,9 @@ async def log_request(request: Request, call_next): response = await call_next(request) return response +@app.get("/status/healthz") +async def health_check(): + return {"status": "healthy"} @app.post("/rule/execute") async def execute_plugin(rule: Rule): @@ -126,15 +158,28 @@ async def execute_plugin(rule: Rule): # Create and evaluate the rule relation = rule.config.Relation - rule_obj = rule_engine.Rule(f"score {relation} threshold", context=context) - match = rule_obj.matches(data) - logger.debug(f"Rule engine result: match={match}") + if relation is None: + raise ValueError("The 'relation' variable is undefined in the context.") - response = {"match": match, "inspection": plugin_result} - logger.debug(f"Plugin Name: {plugin_name} API response: {response}") + try: + rule_obj = rule_engine.Rule(f"score {relation} threshold", context=context) + match = rule_obj.matches(data) + logger.debug(f"Rule engine result: match={match}") + response = {"match": match, "inspection": plugin_result} + logger.debug(f"Plugin Name: {plugin_name} API response: {response}") - return response + return response + except Exception as e: + logger.error(f"Error executing rule engine: {e}") + raise HTTPException(status_code=500, detail="Error executing rule engine") + +def main(): + # Get host and port from environment variables, with defaults + host = os.getenv('HOST', '0.0.0.0') + port = int(os.getenv('PORT', 8000)) + logger.info(f"Starting server on {host}:{port}") + uvicorn.run(app, host=host, port=port) if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + main()