diff --git a/lib/llamaguard/handlers.go b/lib/llamaguard/handlers.go index 37adb37..7786bd9 100644 --- a/lib/llamaguard/handlers.go +++ b/lib/llamaguard/handlers.go @@ -3,12 +3,11 @@ package llamaguard import ( "bytes" "context" - "crypto/sha256" "encoding/json" "errors" "fmt" "github.com/google/uuid" - "log" + "github.com/openshieldai/openshield/lib/provider" "net/http" "github.com/go-chi/chi/v5" @@ -47,7 +46,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { return } - performAuditLogging(r, "llamaguard_analyze", "input", []byte(req.Text)) + provider.LogProviderInput(r, "llamaguard", req.Text) resp, err := callLlamaGuardService(r.Context(), req) if err != nil { @@ -56,7 +55,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { } respBytes, _ := json.Marshal(resp) - performAuditLogging(r, "llamaguard_analyze", "output", respBytes) + provider.LogProviderOutput(r, "llamaguard", respBytes) json.NewEncoder(w).Encode(resp) } @@ -117,16 +116,7 @@ func parseLlamaGuardResponse(response string) *AnalyzeResponse { } func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) { - apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) - - productID, err := getProductIDFromAPIKey(apiKeyId) - if err != nil { - hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String())) - log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err) - return - } - - lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r) + provider.LogProviderInput(r, "llamaguard", body) } func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { diff --git a/lib/openai/handlers.go b/lib/openai/handlers.go index 4b9d2d4..36ec52f 100644 --- a/lib/openai/handlers.go +++ b/lib/openai/handlers.go @@ -794,16 +794,7 @@ func ChatCompletionHandler(w http.ResponseWriter, r *http.Request) { } func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) { - apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) - - productID, err := getProductIDFromAPIKey(apiKeyId) - if err != nil { - hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String())) - log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err) - return - } - lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r) - + provider.LogProviderInput(r, "openai", body) } func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { var productIDStr string diff --git a/lib/promptguard/handlers.go b/lib/promptguard/handlers.go index 900f04f..c0e1f51 100644 --- a/lib/promptguard/handlers.go +++ b/lib/promptguard/handlers.go @@ -3,12 +3,10 @@ package promptguard import ( "bytes" "context" - "crypto/sha256" "encoding/json" "errors" "fmt" "github.com/google/uuid" - "log" "net/http" "github.com/go-chi/chi/v5" @@ -45,17 +43,12 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) { return } - performAuditLogging(r, "promptguard_analyze", "input", []byte(req.Text)) - resp, err := callPromptGuardService(r.Context(), req) if err != nil { lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err)) return } - respBytes, _ := json.Marshal(resp) - performAuditLogging(r, "promptguard_analyze", "output", respBytes) - json.NewEncoder(w).Encode(resp) } @@ -94,19 +87,6 @@ func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeRe return &result, nil } -func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) { - apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) - - productID, err := getProductIDFromAPIKey(apiKeyId) - if err != nil { - hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String())) - log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err) - return - } - - lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r) -} - func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) { var productIDStr string err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error diff --git a/lib/provider/common.go b/lib/provider/common.go index db90134..3c87736 100644 --- a/lib/provider/common.go +++ b/lib/provider/common.go @@ -2,6 +2,7 @@ package provider import ( "context" + "crypto/sha256" "encoding/json" "fmt" "io" @@ -145,10 +146,6 @@ func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, provider if err := json.NewEncoder(w).Encode(resp); err != nil { return fmt.Errorf("error encoding response: %v", err) } - - // Perform response audit logging - PerformResponseAuditLogging(r, resp) - return nil } @@ -178,38 +175,73 @@ func PerformAuditLogging(r *http.Request, logType string, messageType string, bo lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r) } -func PerformResponseAuditLogging(r *http.Request, resp *ChatCompletionResponse) { - apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) - productID, err := GetProductIDFromAPIKey(r.Context(), apiKeyId) +func performProviderAuditLog(r *http.Request, logPrefix string, messageType string, content interface{}) { + apiKeyID := r.Context().Value("apiKeyId").(uuid.UUID) + + productID, err := getProductIDFromAPIKey(apiKeyID) if err != nil { + hashedApiKeyId := sha256.Sum256([]byte(apiKeyID.String())) + log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err) return } - responseJSON, err := json.Marshal(resp) - if err != nil { - log.Printf("Failed to marshal response: %v", err) - return + var messageStr string + switch v := content.(type) { + case string: + messageStr = v + case []byte: + messageStr = string(v) + default: + jsonBytes, err := json.Marshal(v) + if err != nil { + log.Printf("Failed to marshal content: %v", err) + return + } + messageStr = string(jsonBytes) } - auditLog := lib.AuditLogs(string(responseJSON), "chat_completion", apiKeyId, "output", productID, r) + auditLog := lib.AuditLogs( + messageStr, + logPrefix+"_"+messageType, + apiKeyID, + messageType, + productID, + r, + ) if auditLog == nil { - log.Printf("Failed to create audit log") - return + log.Printf("Failed to create audit log for %s", logPrefix) } +} - lib.LogUsage( - resp.Model, - 0, - resp.Usage.PromptTokens, - resp.Usage.CompletionTokens, - resp.Usage.TotalTokens, - resp.Choices[0].FinishReason, - "chat_completion", - productID, - auditLog.Id, - ) +func LogProviderInput(r *http.Request, provider string, content interface{}) { + performProviderAuditLog(r, provider, "input", content) } + +func LogProviderOutput(r *http.Request, provider string, content interface{}) { + performProviderAuditLog(r, provider, "output", content) +} + +func LogProviderError(r *http.Request, provider string, err error) { + performProviderAuditLog(r, provider, "error", err.Error()) +} + +// getProductIDFromAPIKey centralizes the product ID lookup logic +func getProductIDFromAPIKey(apiKeyID uuid.UUID) (uuid.UUID, error) { + var productIDStr string + err := lib.DB().Table("api_keys").Where("id = ?", apiKeyID).Pluck("product_id", &productIDStr).Error + if err != nil { + return uuid.Nil, err + } + + productID, err := uuid.Parse(productIDStr) + if err != nil { + return uuid.Nil, fmt.Errorf("failed to parse product_id as UUID") + } + + return productID, nil +} + func HandleContextCache(ctx context.Context, req ChatCompletionRequest, productID uuid.UUID) (string, bool, error) { config := lib.GetConfig() if !config.Settings.ContextCache.Enabled { @@ -426,7 +458,40 @@ func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx contex log.Printf("Error setting context cache: %v", err) } - PerformResponseAuditLogging(r, resp) + apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID) + responseJSON, err := json.Marshal(resp) + if err != nil { + log.Printf("Failed to marshal response: %v", err) + HandleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError) + return + } + + auditLog := lib.AuditLogs( + string(responseJSON), + "chat_completion", + apiKeyId, + "output", + productID, + r, + ) + + if auditLog == nil { + log.Printf("Failed to create audit log") + HandleError(w, fmt.Errorf("failed to create audit log"), http.StatusInternalServerError) + return + } + + lib.LogUsage( + resp.Model, + 0, + resp.Usage.PromptTokens, + resp.Usage.CompletionTokens, + resp.Usage.TotalTokens, + resp.Choices[0].FinishReason, + "chat_completion", + productID, + auditLog.Id, + ) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil {