From ef20d51350caf97ae1a4c01388d41807cc29949e Mon Sep 17 00:00:00 2001 From: hitesh-1997 Date: Sun, 18 Aug 2024 01:50:37 +0530 Subject: [PATCH] direct route for fireworks models --- .../internal/httpapi/completions/anthropic.go | 2 +- .../httpapi/completions/anthropicmessages.go | 2 +- .../internal/httpapi/completions/fireworks.go | 80 +++++++++++++++++-- .../internal/httpapi/completions/google.go | 2 +- .../internal/httpapi/completions/openai.go | 2 +- .../internal/httpapi/completions/upstream.go | 6 +- cmd/cody-gateway/shared/config/config.go | 13 +++ 7 files changed, 93 insertions(+), 14 deletions(-) diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go index b42af51aee2de..62ffd4d1e982a 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/anthropic.go +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropic.go @@ -155,7 +155,7 @@ func (a *AnthropicHandlerMethods) getRequestMetadata(body anthropicRequest) (mod } } -func (a *AnthropicHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) { +func (a *AnthropicHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *anthropicRequest) { // Mimic headers set by the official Anthropic client: // https://sourcegraph.com/github.com/anthropics/anthropic-sdk-typescript@493075d70f50f1568a276ed0cb177e297f5fef9f/-/blob/src/index.ts upstreamRequest.Header.Set("Cache-Control", "no-cache") diff --git a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go index 7e78566e6331b..4b8e02e627952 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go +++ b/cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go @@ -209,7 +209,7 @@ func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessa } } -func (a *AnthropicMessagesHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) { +func (a *AnthropicMessagesHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *anthropicMessagesRequest) { upstreamRequest.Header.Set("Content-Type", "application/json") upstreamRequest.Header.Set("X-API-Key", a.config.AccessToken) upstreamRequest.Header.Set("anthropic-version", "2023-06-01") diff --git a/cmd/cody-gateway/internal/httpapi/completions/fireworks.go b/cmd/cody-gateway/internal/httpapi/completions/fireworks.go index b91faf99c3259..dbd43475a0590 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/fireworks.go +++ b/cmd/cody-gateway/internal/httpapi/completions/fireworks.go @@ -24,6 +24,11 @@ import ( "github.com/sourcegraph/sourcegraph/internal/httpcli" ) +type ModelDirectRouteSpec struct { + Url string + AccessToken string +} + func NewFireworksHandler(baseLogger log.Logger, eventLogger events.Logger, rs limiter.RedisStore, rateLimitNotifier notify.RateLimitNotifier, httpClient httpcli.Doer, config config.FireworksConfig, promptRecorder PromptRecorder, upstreamConfig UpstreamHandlerConfig, tracedRequestsCounter metric.Int64Counter) http.Handler { // Setting to a valuer higher than SRC_HTTP_CLI_EXTERNAL_RETRY_AFTER_MAX_DURATION to not // do any retries @@ -61,7 +66,13 @@ type fireworksRequest struct { Stream bool `json:"stream,omitempty"` Echo bool `json:"echo,omitempty"` Stop []string `json:"stop,omitempty"` - LanguageID string `json:"languageId,omitempty"` + User string `json:"user,omitempty"` + + // These are the extra fields, that are used for experimentation purpose + // and deleted before sending request to upstream. + LanguageID string `json:"languageId,omitempty"` + AnonymousUserID string `json:"anonymousUserID,omitempty"` + ShouldUseDirectRoute bool `json:"shouldUseDirectRoute,omitempty" default:"false"` } func (fr fireworksRequest) ShouldStream() bool { @@ -108,10 +119,15 @@ type FireworksHandlerMethods struct { tracedRequestsCounter metric.Int64Counter } -func (f *FireworksHandlerMethods) getAPIURL(feature codygateway.Feature, _ fireworksRequest) string { +func (f *FireworksHandlerMethods) getAPIURL(feature codygateway.Feature, body fireworksRequest) string { if feature == codygateway.FeatureChatCompletions { return "https://api.fireworks.ai/inference/v1/chat/completions" } else { + directRouteSpec, ok := f.GetDirectRouteSpec(&body) + if ok && directRouteSpec != nil { + // Use Direct Route if specified. + return directRouteSpec.Url + } return "https://api.fireworks.ai/inference/v1/completions" } } @@ -133,27 +149,75 @@ func (f *FireworksHandlerMethods) transformBody(body *fireworksRequest, _ string body.N = 1 } modelLanguageId := body.LanguageID - // Delete the fields that are not supported by the Fireworks API. - if body.LanguageID != "" { - body.LanguageID = "" - } body.Model = pickStarCoderModel(body.Model, f.config) body.Model = pickFineTunedModel(body.Model, modelLanguageId) + + directRouteSpec, ok := f.GetDirectRouteSpec(body) + if directRouteSpec != nil && ok && body.AnonymousUserID != "" { + body.User = body.AnonymousUserID + } + // Delete ExtraFields from the body + body.LanguageID = "" + body.AnonymousUserID = "" +} + +func (f *FireworksHandlerMethods) GetDirectRouteSpec(body *fireworksRequest) (*ModelDirectRouteSpec, bool) { + if !body.ShouldUseDirectRoute { + return nil, false + } + + directRouteUrlMappings := map[string]string{ + fireworks.DeepseekCoderV2LiteBase: "https://sourcegraph-7ca5ec0c.direct.fireworks.ai/v1/completions", + } + + modelURL, exists := directRouteUrlMappings[body.Model] + if !exists || modelURL == "" { + return nil, false + } + + token := f.getDirectAccessToken(body.Model) + if token == "" { + return nil, false + } + + return &ModelDirectRouteSpec{ + Url: modelURL, + AccessToken: token, + }, true +} + +func (f *FireworksHandlerMethods) getDirectAccessToken(model string) string { + switch model { + case fireworks.DeepseekCoderV2LiteBase: + return f.config.DirectRouteConfig.DeepSeekCoderV2LiteBaseAccessToken + default: + return "" + } } func (f *FireworksHandlerMethods) getRequestMetadata(body fireworksRequest) (model string, additionalMetadata map[string]any) { return body.Model, map[string]any{"stream": body.Stream} } -func (f *FireworksHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) { +func (f *FireworksHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, body *fireworksRequest) { // Enable tracing if the client requests it, see https://readme.fireworks.ai/docs/enabling-tracing if downstreamRequest.Header.Get("X-Fireworks-Genie") == "true" { upstreamRequest.Header.Set("X-Fireworks-Genie", "true") f.tracedRequestsCounter.Add(downstreamRequest.Context(), 1) } upstreamRequest.Header.Set("Content-Type", "application/json") - upstreamRequest.Header.Set("Authorization", "Bearer "+f.config.AccessToken) + + directRouteSpec, ok := f.GetDirectRouteSpec(body) + if ok && directRouteSpec != nil { + if body.AnonymousUserID != "" { + upstreamRequest.Header.Set("X-Session-Affinity", body.AnonymousUserID) + } + upstreamRequest.Header.Set("Authorization", "Bearer "+directRouteSpec.AccessToken) + } else { + upstreamRequest.Header.Set("Authorization", "Bearer "+f.config.AccessToken) + } + body.ShouldUseDirectRoute = false } func (f *FireworksHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody fireworksRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) { diff --git a/cmd/cody-gateway/internal/httpapi/completions/google.go b/cmd/cody-gateway/internal/httpapi/completions/google.go index faf55c19f512b..385d4463c10f1 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/google.go +++ b/cmd/cody-gateway/internal/httpapi/completions/google.go @@ -103,7 +103,7 @@ func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model strin return body.Model, map[string]any{"stream": body.ShouldStream()} } -func (o *GoogleHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) { +func (o *GoogleHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *googleRequest) { upstreamRequest.Header.Set("Content-Type", "application/json") } diff --git a/cmd/cody-gateway/internal/httpapi/completions/openai.go b/cmd/cody-gateway/internal/httpapi/completions/openai.go index 1fb037f9f6efa..b37c2c3299cb8 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/openai.go +++ b/cmd/cody-gateway/internal/httpapi/completions/openai.go @@ -145,7 +145,7 @@ func (*OpenAIHandlerMethods) getRequestMetadata(body openaiRequest) (model strin return body.Model, map[string]any{"stream": body.Stream} } -func (o *OpenAIHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) { +func (o *OpenAIHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *openaiRequest) { upstreamRequest.Header.Set("Content-Type", "application/json") upstreamRequest.Header.Set("Authorization", "Bearer "+o.config.AccessToken) if o.config.OrgID != "" { diff --git a/cmd/cody-gateway/internal/httpapi/completions/upstream.go b/cmd/cody-gateway/internal/httpapi/completions/upstream.go index 394d3c3b95433..7dd5b7b701d03 100644 --- a/cmd/cody-gateway/internal/httpapi/completions/upstream.go +++ b/cmd/cody-gateway/internal/httpapi/completions/upstream.go @@ -94,7 +94,7 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface { // transformRequest can be used to modify the HTTP request before it is sent // upstream. The downstreamRequest parameter is the request sent from the Gateway client. // To manipulate the body, use transformBody. - transformRequest(downstreamRequest, upstreamRequest *http.Request) + transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *ReqT) // getRequestMetadata should extract details about the request we are sending // upstream for validation and tracking purposes. Usage data does not need // to be reported here - instead, use parseResponseAndUsage to extract usage, @@ -326,7 +326,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest]( } // Run the request transformer. - methods.transformRequest(downstreamRequest, upstreamRequest) + methods.transformRequest(downstreamRequest, upstreamRequest, &body) // Retrieve metadata from the initial request. model, requestMetadata := methods.getRequestMetadata(body) @@ -424,6 +424,8 @@ func makeUpstreamHandler[ReqT UpstreamRequest]( return } + fmt.Println("-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* upstreamRequest\n", upstreamRequest) + resp, err := httpClient.Do(upstreamRequest) defer modelAvailabilityTracker.record(gatewayModel, resp, err) diff --git a/cmd/cody-gateway/shared/config/config.go b/cmd/cody-gateway/shared/config/config.go index 9b588db4d2d12..14587fa18dbc9 100644 --- a/cmd/cody-gateway/shared/config/config.go +++ b/cmd/cody-gateway/shared/config/config.go @@ -102,6 +102,11 @@ type AnthropicConfig struct { FlaggingConfig FlaggingConfig } +type FireworksDirectRouteConfig struct { + // direct route token for deepseek model + DeepSeekCoderV2LiteBaseAccessToken string +} + type FireworksConfig struct { // Non-prefixed model names AllowedModels []string @@ -109,6 +114,7 @@ type FireworksConfig struct { StarcoderCommunitySingleTenantPercent int StarcoderEnterpriseSingleTenantPercent int FlaggingConfig FlaggingConfig + DirectRouteConfig FireworksDirectRouteConfig } type OpenAIConfig struct { @@ -307,6 +313,7 @@ func (c *Config) Load() { } c.Fireworks.StarcoderCommunitySingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_COMMUNITY_SINGLE_TENANT_PERCENT", "0", "The percentage of community traffic for Starcoder to be redirected to the single-tenant deployment.") c.Fireworks.StarcoderEnterpriseSingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_ENTERPRISE_SINGLE_TENANT_PERCENT", "100", "The percentage of Enterprise traffic for Starcoder to be redirected to the single-tenant deployment.") + c.Fireworks.DirectRouteConfig = c.GetFireworksDirectRouteConfig() // Configurations for Google Gemini models. c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.") @@ -425,6 +432,12 @@ func (c *Config) loadFlaggingConfig(cfg *FlaggingConfig, envVarPrefix string) { cfg.FlaggedModelNames = maybeLoadLowercaseSlice("FLAGGED_MODEL_NAMES", "LLM models that will always lead to the request getting flagged.") } +func (c *Config) GetFireworksDirectRouteConfig() FireworksDirectRouteConfig { + return FireworksDirectRouteConfig{ + DeepSeekCoderV2LiteBaseAccessToken: c.Get("CODY_GATEWAY_FIREWORKS_DIRECT_ROUTE_DEEPSEEK_CODER_V2_LITE_BASE_ACCESS_TOKEN", "", "DeepseekCoderV2LiteBaseAccessToken"), + } +} + // splitMaybe splits the provided string on commas, but returns nil if given the empty string. func splitMaybe(input string) []string { if input == "" {