diff --git a/lib/provider/common.go b/lib/provider/common.go index b3e3e60..db90134 100644 --- a/lib/provider/common.go +++ b/lib/provider/common.go @@ -345,14 +345,6 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa log.Printf("Received request: %+v", req) - PerformAuditLogging(r, providerName+"_create_message", "input", body) - - filtered, err := ProcessInput(w, r, req) - if err != nil || filtered { - log.Printf("Request filtered or error occurred: %v", err) - return ChatCompletionRequest{}, nil, uuid.Nil, false - } - apiKeyID, ok := r.Context().Value("apiKeyId").(uuid.UUID) if !ok { HandleError(w, fmt.Errorf("apiKeyId not found in context"), http.StatusInternalServerError) @@ -367,6 +359,37 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa ctx := context.WithValue(r.Context(), "productID", productID) + // Check cache before running rules + if !req.Stream { + cachedResponse, cacheHit, err := HandleContextCache(ctx, req, productID) + if err != nil { + log.Printf("Error handling context cache: %v", err) + } + + if cacheHit { + log.Println("Cache hit, using cached response") + resp, err := CreateChatCompletionResponseFromCache(cachedResponse, req.Model) + if err != nil { + log.Printf("Error creating response from cache: %v", err) + } else { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Error encoding cached response: %v", err) + http.Error(w, "Error encoding response", http.StatusInternalServerError) + } + return ChatCompletionRequest{}, nil, uuid.Nil, true + } + } + } + + PerformAuditLogging(r, providerName+"_create_message", "input", body) + + filtered, err := ProcessInput(w, r, req) + if err != nil || filtered { + log.Printf("Request filtered or error occurred: %v", err) + return ChatCompletionRequest{}, nil, uuid.Nil, false + } + return req, ctx, productID, true } @@ -387,33 +410,28 @@ func HandleCacheLogic(ctx context.Context, req ChatCompletionRequest, productID func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, req ChatCompletionRequest, productID uuid.UUID, provider Provider) { if req.Stream { - handleStreamingRequest(w, r, provider, req) - } else { - resp, cacheHit, err := HandleCacheLogic(ctx, req, productID) - if err != nil { - log.Printf("Error handling cache logic: %v", err) + if err := handleStreamingRequest(w, r, provider, req); err != nil { + HandleError(w, err, http.StatusInternalServerError) } + return + } - if !cacheHit { - log.Printf("Cache miss, making API call to provider") - resp, err = provider.CreateChatCompletion(ctx, req) - if err != nil { - HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError) - return - } + resp, err := provider.CreateChatCompletion(ctx, req) + if err != nil { + HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError) + return + } - if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil { - log.Printf("Error setting context cache: %v", err) - } + if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil { + log.Printf("Error setting context cache: %v", err) + } - PerformResponseAuditLogging(r, resp) - } + PerformResponseAuditLogging(r, resp) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("Error encoding response: %v", err) - http.Error(w, "Error encoding response", http.StatusInternalServerError) - return - } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Error encoding response: %v", err) + http.Error(w, "Error encoding response", http.StatusInternalServerError) + return } }