Skip to content

Commit

Permalink
Improvements and fixes (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
pigri authored Nov 7, 2024
2 parents c06dc2c + d237ec0 commit 7f5bc39
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 48 deletions.
74 changes: 44 additions & 30 deletions lib/openai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,6 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) {

performAuditLogging(r, "openai_create_thread", "input", body)

filtered, message, errorMessage := rules.Input(r, req)
if errorMessage != nil {
handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest)
return
}

logMessage, err := json.Marshal(message)
if err != nil {
handleError(w, fmt.Errorf("error marshalling message: %v", err), http.StatusBadRequest)
return
}

if filtered {
performAuditLogging(r, "rule", "filtered", logMessage)
handleError(w, fmt.Errorf(message), http.StatusBadRequest)
return
}

c := openai.DefaultConfig(openAIAPIKey)
c.BaseURL = openAIBaseURL
client := openai.NewClientWithConfig(c)
Expand All @@ -123,7 +105,7 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) {
return
}

if config.Settings.Cache.Enabled {
if config.Settings.ContextCache.Enabled {
w.Header().Set(OSCacheStatusHeader, "MISS")
resJson, err := json.Marshal(resp)
if err != nil {
Expand All @@ -138,6 +120,24 @@ func CreateThreadHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set(OSCacheStatusHeader, "BYPASS")
}

filtered, message, errorMessage := rules.Input(r, req)
if errorMessage != nil {
handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest)
return
}

logMessage, err := json.Marshal(message)
if err != nil {
handleError(w, fmt.Errorf("error marshalling message: %v", err), http.StatusBadRequest)
return
}

if filtered {
performAuditLogging(r, "rule", "filtered", logMessage)
handleError(w, fmt.Errorf(message), http.StatusBadRequest)
return
}

performThreadAuditLogging(r, resp)
json.NewEncoder(w).Encode(resp)
}
Expand Down Expand Up @@ -242,7 +242,31 @@ func CreateMessageHandler(w http.ResponseWriter, r *http.Request) {
return
}

// Perform input validation and filtering
c := openai.DefaultConfig(openAIAPIKey)
c.BaseURL = openAIBaseURL
client := openai.NewClientWithConfig(c)

resp, err := client.CreateMessage(r.Context(), threadID, req)
if err != nil {
handleError(w, fmt.Errorf("failed to create message: %v", err), http.StatusInternalServerError)
return
}

if config.Settings.ContextCache.Enabled {
w.Header().Set(OSCacheStatusHeader, "MISS")
resJson, err := json.Marshal(resp)
if err != nil {
log.Printf("Error marshalling response to JSON: %v", err)
} else {
err = lib.SetCache(string(body), resJson)
if err != nil {
log.Printf("Error setting cache: %v", err)
}
}
} else {
w.Header().Set(OSCacheStatusHeader, "BYPASS")
}

filtered, message, errorMessage := rules.Input(r, req)
if errorMessage != nil {
handleError(w, fmt.Errorf("error processing input: %v", errorMessage), http.StatusBadRequest)
Expand All @@ -261,16 +285,6 @@ func CreateMessageHandler(w http.ResponseWriter, r *http.Request) {
return
}

c := openai.DefaultConfig(openAIAPIKey)
c.BaseURL = openAIBaseURL
client := openai.NewClientWithConfig(c)

resp, err := client.CreateMessage(r.Context(), threadID, req)
if err != nil {
handleError(w, fmt.Errorf("failed to create message: %v", err), http.StatusInternalServerError)
return
}

json.NewEncoder(w).Encode(resp)
}
func ListMessagesHandler(w http.ResponseWriter, r *http.Request) {
Expand Down
29 changes: 24 additions & 5 deletions services/cache/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import os
import json
from typing import Optional, Dict
from urllib.parse import urlparse

Expand All @@ -21,8 +22,26 @@
import uvicorn
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class JSONFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object for structured logging."""

    def format(self, record):
        """Serialize `record` to a JSON string.

        Emits timestamp, level, message, logger name, filename and line number.
        Exception and stack information are appended when present.
        """
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "name": record.name,
            "filename": record.filename,
            "lineno": record.lineno,
        }
        # Fix: the original silently dropped tracebacks recorded via
        # logger.exception()/exc_info, making errors undebuggable in the
        # JSON output. Serialize them when present.
        if record.exc_info:
            log_record["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record["stack_info"] = self.formatStack(record.stack_info)
        return json.dumps(log_record)

# Set up JSON logging on this module's logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
json_formatter = JSONFormatter()
handler = logging.StreamHandler()
handler.setFormatter(json_formatter)
# Fix: guard the attachment — the original called addHandler unconditionally
# at module level, so a re-import or repeated execution of this module would
# attach a second handler and duplicate every emitted log line.
if not logger.handlers:
    logger.addHandler(handler)

app = FastAPI()
openai_caches: Dict[str, Cache] = {}
redis_url = os.getenv("REDIS_URL")
Expand Down Expand Up @@ -54,21 +73,21 @@ def openshield_check_hit_func(cur_session_id, cache_session_ids, cache_questions
async def put_cache(cache_data: CacheData) -> str:
    """Store a prompt/answer pair in the product-scoped cache session.

    Returns a human-readable confirmation string.
    """
    # Sessions partition the cache per product id; hit validity is decided by
    # openshield_check_hit_func.
    session = Session(name=cache_data.product_id, check_hit_func=openshield_check_hit_func)
    put(cache_data.prompt, cache_data.answer, session=session)
    # NOTE(review): the two logger calls below are diff residue (the old
    # %s-style line and its f-string replacement) — only one should remain.
    # Also consider whether logging raw prompts at INFO leaks user data.
    logger.info(f"Setting cache data: %s", cache_data.prompt)
    logger.info(f"Setting cache data: {cache_data.prompt}")
    # NOTE(review): grammar — should read "successfully updated the cache";
    # left unchanged here in case callers match on the exact string.
    return "successfully update the cache"


@app.post("/get")
async def get_cache(cache_data: CacheData) -> CacheData:
    """Look up a cached answer for the given prompt within the product session.

    Raises HTTPException 404 on a cache miss; on a hit, returns the cached
    answer wrapped in a CacheData payload echoing the prompt and product id.
    """
    session = Session(name=cache_data.product_id, check_hit_func=openshield_check_hit_func)
    # NOTE(review): each duplicated logger pair below is diff residue (old
    # %s-style line plus its f-string replacement) — keep only one of each.
    logger.info(f"Getting cache data: %s", cache_data.prompt)
    logger.info(f"Getting cache data: {cache_data.prompt}")
    result = get(cache_data.prompt, session=session)

    if result is None:
        logger.info(f"Cache miss for prompt: %s", cache_data.prompt)
        logger.info(f"Cache miss for prompt: {cache_data.prompt}")
        # 404 signals "miss" to the caller — presumably it then falls back to
        # the upstream model; verify against the Go client.
        raise HTTPException(status_code=404, detail="Cache miss")
    else:
        logger.info(f"Cache hit for prompt: %s", cache_data.prompt)
        logger.info(f"Cache hit for prompt: {cache_data.prompt}")
        return CacheData(prompt=cache_data.prompt, answer=result, product_id=cache_data.product_id)


Expand Down
71 changes: 58 additions & 13 deletions services/rule/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,42 @@
import importlib
import rule_engine
import logging
import os
import json

class JSONFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object for structured logging."""

    def format(self, record):
        """Serialize `record` to a JSON string.

        Emits timestamp, level, message, logger name, filename and line number.
        Exception and stack information are appended when present.
        """
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "message": record.getMessage(),
            "name": record.name,
            "filename": record.filename,
            "lineno": record.lineno,
        }
        # Fix: the original silently dropped tracebacks recorded via
        # logger.exception()/exc_info, making errors undebuggable in the
        # JSON output. Serialize them when present.
        if record.exc_info:
            log_record["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record["stack_info"] = self.formatStack(record.stack_info)
        return json.dumps(log_record)

def setup_logging():
    """Configure this module's logger for JSON output.

    Reads LOG_LEVEL from the environment (default 'INFO').
    Raises ValueError when the level name is not a recognized logging level.
    """
    # Get the log level from the environment variable, default to 'INFO'
    log_level = os.getenv('LOG_LEVEL', 'INFO').upper()

    # getattr maps e.g. 'DEBUG' -> logging.DEBUG; unknown names yield None,
    # which fails the isinstance check below.
    numeric_level = getattr(logging, log_level, None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {log_level}')

    logger = logging.getLogger(__name__)
    logger.setLevel(numeric_level)
    # Fix: guard against duplicate handlers — the original appended a fresh
    # StreamHandler on every call, so invoking setup_logging() more than once
    # (reload, repeated import) duplicated every log line.
    if not logger.handlers:
        json_formatter = JSONFormatter()
        handler = logging.StreamHandler()
        handler.setFormatter(json_formatter)
        logger.addHandler(handler)

# Configure logging
logging.basicConfig(level=logging.DEBUG)
setup_logging()

# Example usage of logging
logger = logging.getLogger(__name__)

# Define plugin_name at the module level
Expand All @@ -17,11 +50,9 @@ class Message(BaseModel):
role: str
content: str


class Thread(BaseModel):
    # Request body holding an ordered conversation: each entry is a Message
    # with a role and its content.
    messages: List[Message]


class Prompt(BaseModel):
model: Optional[str] = None
assistant_id: Optional[str] = None
Expand All @@ -30,7 +61,6 @@ class Prompt(BaseModel):
role: Optional[str] = None
content: Optional[str] = None


class Config(BaseModel):
PluginName: str
Threshold: float
Expand All @@ -40,13 +70,12 @@ class Config(BaseModel):
class Config:
extra = "allow"


class Rule(BaseModel):
    # Payload for /rule/execute: the prompt under inspection plus the plugin
    # configuration (PluginName, Threshold, Relation) that controls the check.
    prompt: Prompt
    config: Config


app = FastAPI()

@app.middleware("http")
async def log_request(request: Request, call_next):
logger.debug(f"Incoming request: {request.method} {request.url}")
Expand All @@ -55,6 +84,9 @@ async def log_request(request: Request, call_next):
response = await call_next(request)
return response

@app.get("/status/healthz")
async def health_check():
    # Liveness probe endpoint; unconditionally reports the service as healthy.
    status_payload = {"status": "healthy"}
    return status_payload

@app.post("/rule/execute")
async def execute_plugin(rule: Rule):
Expand Down Expand Up @@ -126,15 +158,28 @@ async def execute_plugin(rule: Rule):

# Create and evaluate the rule
relation = rule.config.Relation
rule_obj = rule_engine.Rule(f"score {relation} threshold", context=context)
match = rule_obj.matches(data)
logger.debug(f"Rule engine result: match={match}")
if relation is None:
raise ValueError("The 'relation' variable is undefined in the context.")

response = {"match": match, "inspection": plugin_result}
logger.debug(f"Plugin Name: {plugin_name} API response: {response}")
try:
rule_obj = rule_engine.Rule(f"score {relation} threshold", context=context)
match = rule_obj.matches(data)
logger.debug(f"Rule engine result: match={match}")
response = {"match": match, "inspection": plugin_result}
logger.debug(f"Plugin Name: {plugin_name} API response: {response}")

return response
return response
except Exception as e:
logger.error(f"Error executing rule engine: {e}")
raise HTTPException(status_code=500, detail="Error executing rule engine")

def main():
    """Entry point: bind address comes from HOST/PORT env vars, then serve."""
    bind_host = os.getenv('HOST', '0.0.0.0')
    bind_port = int(os.getenv('PORT', 8000))

    logger.info(f"Starting server on {bind_host}:{bind_port}")
    uvicorn.run(app, host=bind_host, port=bind_port)

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
main()

0 comments on commit 7f5bc39

Please sign in to comment.