Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

audit logs #251

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions lib/llamaguard/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package llamaguard
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"log"
"github.com/openshieldai/openshield/lib/provider"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -47,7 +46,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
return
}

performAuditLogging(r, "llamaguard_analyze", "input", []byte(req.Text))
provider.LogProviderInput(r, "llamaguard", req.Text)

resp, err := callLlamaGuardService(r.Context(), req)
if err != nil {
Expand All @@ -56,7 +55,7 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
}

respBytes, _ := json.Marshal(resp)
performAuditLogging(r, "llamaguard_analyze", "output", respBytes)
provider.LogProviderOutput(r, "llamaguard", respBytes)

json.NewEncoder(w).Encode(resp)
}
Expand Down Expand Up @@ -117,16 +116,7 @@ func parseLlamaGuardResponse(response string) *AnalyzeResponse {
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
provider.LogProviderInput(r, "llamaguard", body)
}

func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
Expand Down
11 changes: 1 addition & 10 deletions lib/openai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,16 +794,7 @@ func ChatCompletionHandler(w http.ResponseWriter, r *http.Request) {
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}
lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)

provider.LogProviderInput(r, "openai", body)
}
func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
var productIDStr string
Expand Down
20 changes: 0 additions & 20 deletions lib/promptguard/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package promptguard
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"log"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -45,17 +43,12 @@ func AnalyzeHandler(w http.ResponseWriter, r *http.Request) {
return
}

performAuditLogging(r, "promptguard_analyze", "input", []byte(req.Text))

resp, err := callPromptGuardService(r.Context(), req)
if err != nil {
lib.ErrorResponse(w, fmt.Errorf("error calling PromptGuard service: %v", err))
return
}

respBytes, _ := json.Marshal(resp)
performAuditLogging(r, "promptguard_analyze", "output", respBytes)

json.NewEncoder(w).Encode(resp)
}

Expand Down Expand Up @@ -94,19 +87,6 @@ func callPromptGuardService(ctx context.Context, req AnalyzeRequest) (*AnalyzeRe
return &result, nil
}

func performAuditLogging(r *http.Request, logType string, messageType string, body []byte) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyId)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyId.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
}

func getProductIDFromAPIKey(apiKeyId uuid.UUID) (uuid.UUID, error) {
var productIDStr string
err := lib.DB().Table("api_keys").Where("id = ?", apiKeyId).Pluck("product_id", &productIDStr).Error
Expand Down
117 changes: 91 additions & 26 deletions lib/provider/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package provider

import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -145,10 +146,6 @@ func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, provider
if err := json.NewEncoder(w).Encode(resp); err != nil {
return fmt.Errorf("error encoding response: %v", err)
}

// Perform response audit logging
PerformResponseAuditLogging(r, resp)

return nil
}

Expand Down Expand Up @@ -178,38 +175,73 @@ func PerformAuditLogging(r *http.Request, logType string, messageType string, bo
lib.AuditLogs(string(body), logType, apiKeyId, messageType, productID, r)
}

func PerformResponseAuditLogging(r *http.Request, resp *ChatCompletionResponse) {
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)
productID, err := GetProductIDFromAPIKey(r.Context(), apiKeyId)
func performProviderAuditLog(r *http.Request, logPrefix string, messageType string, content interface{}) {
apiKeyID := r.Context().Value("apiKeyId").(uuid.UUID)

productID, err := getProductIDFromAPIKey(apiKeyID)
if err != nil {
hashedApiKeyId := sha256.Sum256([]byte(apiKeyID.String()))
log.Printf("Failed to retrieve ProductID for apiKeyId %x: %v", hashedApiKeyId, err)
return
}

responseJSON, err := json.Marshal(resp)
if err != nil {
log.Printf("Failed to marshal response: %v", err)
return
var messageStr string
switch v := content.(type) {
case string:
messageStr = v
case []byte:
messageStr = string(v)
default:
jsonBytes, err := json.Marshal(v)
if err != nil {
log.Printf("Failed to marshal content: %v", err)
return
}
messageStr = string(jsonBytes)
}

auditLog := lib.AuditLogs(string(responseJSON), "chat_completion", apiKeyId, "output", productID, r)
auditLog := lib.AuditLogs(
messageStr,
logPrefix+"_"+messageType,
apiKeyID,
messageType,
productID,
r,
)

if auditLog == nil {
log.Printf("Failed to create audit log")
return
log.Printf("Failed to create audit log for %s", logPrefix)
}
}

lib.LogUsage(
resp.Model,
0,
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens,
resp.Choices[0].FinishReason,
"chat_completion",
productID,
auditLog.Id,
)
func LogProviderInput(r *http.Request, provider string, content interface{}) {
performProviderAuditLog(r, provider, "input", content)
}

func LogProviderOutput(r *http.Request, provider string, content interface{}) {
performProviderAuditLog(r, provider, "output", content)
}

func LogProviderError(r *http.Request, provider string, err error) {
performProviderAuditLog(r, provider, "error", err.Error())
}

// getProductIDFromAPIKey centralizes the product ID lookup logic
func getProductIDFromAPIKey(apiKeyID uuid.UUID) (uuid.UUID, error) {
var productIDStr string
err := lib.DB().Table("api_keys").Where("id = ?", apiKeyID).Pluck("product_id", &productIDStr).Error
if err != nil {
return uuid.Nil, err
}

productID, err := uuid.Parse(productIDStr)
if err != nil {
return uuid.Nil, fmt.Errorf("failed to parse product_id as UUID")
}

return productID, nil
}

func HandleContextCache(ctx context.Context, req ChatCompletionRequest, productID uuid.UUID) (string, bool, error) {
config := lib.GetConfig()
if !config.Settings.ContextCache.Enabled {
Expand Down Expand Up @@ -426,7 +458,40 @@ func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx contex
log.Printf("Error setting context cache: %v", err)
}

PerformResponseAuditLogging(r, resp)
apiKeyId := r.Context().Value("apiKeyId").(uuid.UUID)
responseJSON, err := json.Marshal(resp)
if err != nil {
log.Printf("Failed to marshal response: %v", err)
HandleError(w, fmt.Errorf("error encoding response: %v", err), http.StatusInternalServerError)
return
}

auditLog := lib.AuditLogs(
string(responseJSON),
"chat_completion",
apiKeyId,
"output",
productID,
r,
)

if auditLog == nil {
log.Printf("Failed to create audit log")
HandleError(w, fmt.Errorf("failed to create audit log"), http.StatusInternalServerError)
return
}

lib.LogUsage(
resp.Model,
0,
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens,
resp.Choices[0].FinishReason,
"chat_completion",
productID,
auditLog.Id,
)

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
Expand Down