Skip to content

Commit

Permalink
audit logs
Browse files Browse the repository at this point in the history
  • Loading branch information
krichard1212 committed Nov 25, 2024
1 parent 2b24db2 commit c20ed0f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 70 deletions.
18 changes: 4 additions & 14 deletions lib/llamaguard/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package llamaguard
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"log"
"github.com/openshieldai/openshield/lib/provider"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -47,7 +46,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
return
}

performAuditLogging(r, "llamaguard_analyze", "input", []byte(req.Text))
provider.LogProviderInput(r, "llamaguard", req.Text)

resp, err := callLlamaGuardService(r.Context(), req)
if err != nil {
Expand All @@ -56,7 +55,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
}

respBytes, _ := json.Marshal(resp)
performAuditLogging(r, "llamaguard_analyze", "output", respBytes)
provider.LogProviderOutput(r, "llamaguard", respBytes)

json.NewEncoder(w).Encode(resp)
}
Expand Down Expand Up @@ -117,16 +116,7 @@ func parseLlamaGuardResponse(response string) *AnalyzeResponse {
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
provider.LogProviderInput(r, "llamaguard", body)
}

func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
Expand Down
11 changes: 1 addition & 10 deletions lib/openai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,16 +794,7 @@ func ChatCompletionHandler(w http.ResponseWriter, r *http.Request) {
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}
lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)

provider.LogProviderInput(r, "openai", body)
}
func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
var productIDStr string
Expand Down
20 changes: 0 additions & 20 deletions lib/promptguard/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package promptguard
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"log"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -45,17 +43,12 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
return
}

performAuditLogging(r, "promptguard_analyze", "input", []byte(req.Text))

resp, err := callPromptGuardService(r.Context(), req)
if err != nil {
lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err))
return
}

respBytes, _ := json.Marshal(resp)
performAuditLogging(r, "promptguard_analyze", "output", respBytes)

json.NewEncoder(w).Encode(resp)
}

Expand Down Expand Up @@ -94,19 +87,6 @@ func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeRe
return &result, nil
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
}

func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
var productIDStr string
err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error
Expand Down
117 changes: 91 additions & 26 deletions lib/provider/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package provider

import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -145,10 +146,6 @@ func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, provider
if err := json.NewEncoder(w).Encode(resp); err != nil {
return fmt.Errorf("error encoding response: %v", err)
}

// Perform response audit logging
PerformResponseAuditLogging(r, resp)

return nil
}

Expand Down Expand Up @@ -178,38 +175,73 @@ func PerformAuditLogging(r *http.Request, logType string, messageType string, bo
lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
}

func PerformResponseAuditLogging(r *http.Request, resp *ChatCompletionResponse) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)
productID, err := GetProductIDFromAPIKey(r.Context(), apiKeyId)
func performProviderAuditLog(r *http.Request, logPrefix string, messageType string, content interface{}) {
apiKeyID := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyID)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyID.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

responseJSON, err := json.Marshal(resp)
if err != nil {
log.Printf("Failed to marshal response: %v", err)
return
var messageStr string
switch v := content.(type) {
case string:
messageStr = v
case []byte:
messageStr = string(v)
default:
jsonBytes, err := json.Marshal(v)
if err != nil {
log.Printf("Failed to marshal content: %v", err)
return
}
messageStr = string(jsonBytes)
}

auditLog := lib.AuditLogs(string(responseJSON), "chat_completion", apiKeyId, "output", productID, r)
auditLog := lib.AuditLogs(
messageStr,
logPrefix+"_"+messageType,
apiKeyID,
messageType,
productID,
r,
)

if auditLog == nil {
log.Printf("Failed to create audit log")
return
log.Printf("Failed to create audit log for %s", logPrefix)
}
}

lib.LogUsage(
resp.Model,
0,
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens,
resp.Choices[0].FinishReason,
"chat_completion",
productID,
auditLog.Id,
)
func LogProviderInput(r *http.Request, provider string, content interface{}) {
performProviderAuditLog(r, provider, "input", content)
}

func LogProviderOutput(r *http.Request, provider string, content interface{}) {
performProviderAuditLog(r, provider, "output", content)
}

func LogProviderError(r *http.Request, provider string, err error) {
performProviderAuditLog(r, provider, "error", err.Error())
}

// getProductIDFromAPIKey centralizes the product ID lookup logic
func getProductIDFromAPIKey(apiKeyID uuid.UUID) (uuid.UUID, error) {
var productIDStr string
err := lib.DB().Table("api_keys").Where("id = ?", apiKeyID).Pluck("product_id", &productIDStr).Error
if err != nil {
return uuid.Nil, err
}

productID, err := uuid.Parse(productIDStr)
if err != nil {
return uuid.Nil, fmt.Errorf("failed to parse product_id as UUID")
}

return productID, nil
}

func HandleContextCache(ctx context.Context, req ChatCompletionRequest, productID uuid.UUID) (string, bool, error) {
config := lib.GetConfig()
if !config.Settings.ContextCache.Enabled {
Expand Down Expand Up @@ -426,7 +458,40 @@ func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx contex
log.Printf("Error setting context cache: %v", err)
}

PerformResponseAuditLogging(r, resp)
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)
responseJSON, err := json.Marshal(resp)
if err != nil {
log.Printf("Failed to marshal response: %v", err)
HandleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError)
return
}

auditLog := lib.AuditLogs(
string(responseJSON),
"chat_completion",
apiKeyId,
"output",
productID,
r,
)

if auditLog == nil {
log.Printf("Failed to create audit log")
HandleError(w, fmt.Errorf("failed to create audit log"), http.StatusInternalServerError)
return
}

lib.LogUsage(
resp.Model,
0,
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens,
resp.Choices[0].FinishReason,
"chat_completion",
productID,
auditLog.Id,
)

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
Expand Down

0 comments on commit c20ed0f

Please sign in to comment.