Skip to content

Commit

Permalink
Order Fix (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
pigri authored Nov 12, 2024
2 parents 1df50ec + 91835f2 commit be0eb6e
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions lib/provider/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,6 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa

log.Printf("Received request: %+v", req)

PerformAuditLogging(r, providerName+"_create_message", "input", body)

filtered, err := ProcessInput(w, r, req)
if err != nil || filtered {
log.Printf("Request filtered or error occurred: %v", err)
return ChatCompletionRequest{}, nil, uuid.Nil, false
}

apiKeyID, ok := r.Context().Value("apiKeyId").(uuid.UUID)
if !ok {
HandleError(w, fmt.Errorf("apiKeyId not found in context"), http.StatusInternalServerError)
Expand All @@ -367,6 +359,37 @@ func HandleCommonRequestLogic(w http.ResponseWriter, r *http.Request, providerNa

ctx := context.WithValue(r.Context(), "productID", productID)

// Check cache before running rules
if !req.Stream {
cachedResponse, cacheHit, err := HandleContextCache(ctx, req, productID)
if err != nil {
log.Printf("Error handling context cache: %v", err)
}

if cacheHit {
log.Println("Cache hit, using cached response")
resp, err := CreateChatCompletionResponseFromCache(cachedResponse, req.Model)
if err != nil {
log.Printf("Error creating response from cache: %v", err)
} else {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding cached response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
}
return ChatCompletionRequest{}, nil, uuid.Nil, true
}
}
}

PerformAuditLogging(r, providerName+"_create_message", "input", body)

filtered, err := ProcessInput(w, r, req)
if err != nil || filtered {
log.Printf("Request filtered or error occurred: %v", err)
return ChatCompletionRequest{}, nil, uuid.Nil, false
}

return req, ctx, productID, true
}

Expand All @@ -387,33 +410,28 @@ func HandleCacheLogic(ctx context.Context, req ChatCompletionRequest, productID

func HandleAPICallAndResponse(w http.ResponseWriter, r *http.Request, ctx context.Context, req ChatCompletionRequest, productID uuid.UUID, provider Provider) {
if req.Stream {
handleStreamingRequest(w, r, provider, req)
} else {
resp, cacheHit, err := HandleCacheLogic(ctx, req, productID)
if err != nil {
log.Printf("Error handling cache logic: %v", err)
if err := handleStreamingRequest(w, r, provider, req); err != nil {
HandleError(w, err, http.StatusInternalServerError)
}
return
}

if !cacheHit {
log.Printf("Cache miss, making API call to provider")
resp, err = provider.CreateChatCompletion(ctx, req)
if err != nil {
HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError)
return
}
resp, err := provider.CreateChatCompletion(ctx, req)
if err != nil {
HandleError(w, fmt.Errorf("error creating chat completion: %v", err), http.StatusInternalServerError)
return
}

if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil {
log.Printf("Error setting context cache: %v", err)
}
if err := SetContextCacheResponse(ctx, req, resp, productID); err != nil {
log.Printf("Error setting context cache: %v", err)
}

PerformResponseAuditLogging(r, resp)
}
PerformResponseAuditLogging(r, resp)

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("Error encoding response: %v", err)
http.Error(w, "Error encoding response", http.StatusInternalServerError)
return
}
}

0 comments on commit be0eb6e

Please sign in to comment.