Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

direct route for fireworks models #64521

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/cody-gateway/internal/httpapi/completions/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (a *AnthropicHandlerMethods) getRequestMetadata(body anthropicRequest) (mod
}
}

func (a *AnthropicHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
func (a *AnthropicHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *anthropicRequest) {
// Mimic headers set by the official Anthropic client:
// https://sourcegraph.com/github.com/anthropics/anthropic-sdk-typescript@493075d70f50f1568a276ed0cb177e297f5fef9f/-/blob/src/index.ts
upstreamRequest.Header.Set("Cache-Control", "no-cache")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessa
}
}

func (a *AnthropicMessagesHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
func (a *AnthropicMessagesHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *anthropicMessagesRequest) {
upstreamRequest.Header.Set("Content-Type", "application/json")
upstreamRequest.Header.Set("X-API-Key", a.config.AccessToken)
upstreamRequest.Header.Set("anthropic-version", "2023-06-01")
Expand Down
80 changes: 72 additions & 8 deletions cmd/cody-gateway/internal/httpapi/completions/fireworks.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ import (
"github.com/sourcegraph/sourcegraph/internal/httpcli"
)

// ModelDirectRouteSpec describes the dedicated ("direct route") upstream used
// for a model instead of the shared Fireworks API: where to send the request
// and which credential to send with it.
type ModelDirectRouteSpec struct {
// Url is the endpoint getAPIURL returns for a direct-routed request.
Url string
// AccessToken is the bearer token transformRequest places in the
// Authorization header for a direct-routed request.
AccessToken string
}

func NewFireworksHandler(baseLogger log.Logger, eventLogger events.Logger, rs limiter.RedisStore, rateLimitNotifier notify.RateLimitNotifier, httpClient httpcli.Doer, config config.FireworksConfig, promptRecorder PromptRecorder, upstreamConfig UpstreamHandlerConfig, tracedRequestsCounter metric.Int64Counter) http.Handler {
// Setting to a value higher than SRC_HTTP_CLI_EXTERNAL_RETRY_AFTER_MAX_DURATION to not
// do any retries
Expand Down Expand Up @@ -61,7 +66,13 @@ type fireworksRequest struct {
Stream bool `json:"stream,omitempty"`
Echo bool `json:"echo,omitempty"`
Stop []string `json:"stop,omitempty"`
LanguageID string `json:"languageId,omitempty"`
User string `json:"user,omitempty"`

// These extra fields are used for experimentation purposes and are
// deleted before the request is sent upstream.
LanguageID string `json:"languageId,omitempty"`
AnonymousUserID string `json:"anonymousUserID,omitempty"`
ShouldUseDirectRoute bool `json:"shouldUseDirectRoute,omitempty" default:"false"`
}

func (fr fireworksRequest) ShouldStream() bool {
Expand Down Expand Up @@ -108,10 +119,15 @@ type FireworksHandlerMethods struct {
tracedRequestsCounter metric.Int64Counter
}

func (f *FireworksHandlerMethods) getAPIURL(feature codygateway.Feature, _ fireworksRequest) string {
// getAPIURL picks the upstream endpoint for a request. Chat completions always
// go to the shared Fireworks chat endpoint. Any other feature goes to the
// model's direct route when GetDirectRouteSpec yields one for body, and to the
// shared Fireworks completions endpoint otherwise.
func (f *FireworksHandlerMethods) getAPIURL(feature codygateway.Feature, body fireworksRequest) string {
	if feature == codygateway.FeatureChatCompletions {
		return "https://api.fireworks.ai/inference/v1/chat/completions"
	}

	// Use the direct route if one is specified for this request.
	if spec, found := f.GetDirectRouteSpec(&body); found && spec != nil {
		return spec.Url
	}

	return "https://api.fireworks.ai/inference/v1/completions"
}
Expand All @@ -133,27 +149,75 @@ func (f *FireworksHandlerMethods) transformBody(body *fireworksRequest, _ string
body.N = 1
}
modelLanguageId := body.LanguageID
// Delete the fields that are not supported by the Fireworks API.
if body.LanguageID != "" {
body.LanguageID = ""
}

body.Model = pickStarCoderModel(body.Model, f.config)
body.Model = pickFineTunedModel(body.Model, modelLanguageId)

directRouteSpec, ok := f.GetDirectRouteSpec(body)
if directRouteSpec != nil && ok && body.AnonymousUserID != "" {
body.User = body.AnonymousUserID
}
// Delete ExtraFields from the body
body.LanguageID = ""
body.AnonymousUserID = ""
}

// GetDirectRouteSpec resolves the direct route for the model named in body.
//
// It returns (nil, false) when the request did not opt in through
// ShouldUseDirectRoute, when the model has no direct-route endpoint, or when no
// access token is configured for that model. Otherwise it returns the endpoint
// and token to use, together with true.
func (f *FireworksHandlerMethods) GetDirectRouteSpec(body *fireworksRequest) (*ModelDirectRouteSpec, bool) {
	if !body.ShouldUseDirectRoute {
		return nil, false
	}

	endpoints := map[string]string{
		fireworks.DeepseekCoderV2LiteBase: "https://sourcegraph-7ca5ec0c.direct.fireworks.ai/v1/completions",
	}

	// A model without an entry yields the zero value, so a single emptiness
	// check covers both "not listed" and "listed with an empty URL".
	endpoint := endpoints[body.Model]
	if endpoint == "" {
		return nil, false
	}

	accessToken := f.getDirectAccessToken(body.Model)
	if accessToken == "" {
		return nil, false
	}

	spec := &ModelDirectRouteSpec{
		Url:         endpoint,
		AccessToken: accessToken,
	}
	return spec, true
}

// getDirectAccessToken returns the configured direct-route access token for
// model, or the empty string when the model has no direct-route token field.
func (f *FireworksHandlerMethods) getDirectAccessToken(model string) string {
	if model == fireworks.DeepseekCoderV2LiteBase {
		return f.config.DirectRouteConfig.DeepSeekCoderV2LiteBaseAccessToken
	}
	return ""
}

// getRequestMetadata reports the model named in the request body, plus a
// metadata map whose "stream" entry is the body's Stream flag.
func (f *FireworksHandlerMethods) getRequestMetadata(body fireworksRequest) (model string, additionalMetadata map[string]any) {
	additionalMetadata = map[string]any{"stream": body.Stream}
	model = body.Model
	return model, additionalMetadata
}

func (f *FireworksHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
// transformRequest sets the headers of the request sent upstream: it forwards
// the client's Fireworks tracing opt-in, sets the JSON content type, and picks
// the Authorization bearer token (the direct-route token when
// GetDirectRouteSpec yields a spec for body, the shared Fireworks token
// otherwise). It also resets body.ShouldUseDirectRoute before returning.
func (f *FireworksHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, body *fireworksRequest) {
// Enable tracing if the client requests it, see https://readme.fireworks.ai/docs/enabling-tracing
if downstreamRequest.Header.Get("X-Fireworks-Genie") == "true" {
upstreamRequest.Header.Set("X-Fireworks-Genie", "true")
// Count traced requests.
f.tracedRequestsCounter.Add(downstreamRequest.Context(), 1)
}
upstreamRequest.Header.Set("Content-Type", "application/json")
// NOTE(review): both branches below set Authorization again, so this
// assignment is always overwritten — looks like a leftover; confirm and drop.
upstreamRequest.Header.Set("Authorization", "Bearer "+f.config.AccessToken)

directRouteSpec, ok := f.GetDirectRouteSpec(body)
if ok && directRouteSpec != nil {
// NOTE(review): transformBody copies AnonymousUserID into body.User for
// direct routes and then clears AnonymousUserID. If transformBody runs on
// the same body before this method, this condition is never true and
// X-Session-Affinity is never sent — confirm the call order; body.User may
// be the intended source.
if body.AnonymousUserID != "" {
upstreamRequest.Header.Set("X-Session-Affinity", body.AnonymousUserID)
}
// Direct routes authenticate with their own per-model token.
upstreamRequest.Header.Set("Authorization", "Bearer "+directRouteSpec.AccessToken)
} else {
upstreamRequest.Header.Set("Authorization", "Bearer "+f.config.AccessToken)
}
// Reset the flag on the caller's body.
// NOTE(review): if the upstream payload was serialized before this method
// runs, this does not remove the field from what is sent upstream — confirm.
body.ShouldUseDirectRoute = false
}

func (f *FireworksHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody fireworksRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/cody-gateway/internal/httpapi/completions/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model strin
return body.Model, map[string]any{"stream": body.ShouldStream()}
}

func (o *GoogleHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
// transformRequest sets the JSON Content-Type header on the upstream request.
// The downstream request and the parsed request body are not used.
func (o *GoogleHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *googleRequest) {
upstreamRequest.Header.Set("Content-Type", "application/json")
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/cody-gateway/internal/httpapi/completions/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (*OpenAIHandlerMethods) getRequestMetadata(body openaiRequest) (model strin
return body.Model, map[string]any{"stream": body.Stream}
}

func (o *OpenAIHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
func (o *OpenAIHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *openaiRequest) {
upstreamRequest.Header.Set("Content-Type", "application/json")
upstreamRequest.Header.Set("Authorization", "Bearer "+o.config.AccessToken)
if o.config.OrgID != "" {
Expand Down
6 changes: 4 additions & 2 deletions cmd/cody-gateway/internal/httpapi/completions/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface {
// transformRequest can be used to modify the HTTP request before it is sent
// upstream. The downstreamRequest parameter is the request sent from the Gateway client.
// To manipulate the body, use transformBody.
transformRequest(downstreamRequest, upstreamRequest *http.Request)
transformRequest(downstreamRequest, upstreamRequest *http.Request, _ *ReqT)
// getRequestMetadata should extract details about the request we are sending
// upstream for validation and tracking purposes. Usage data does not need
// to be reported here - instead, use parseResponseAndUsage to extract usage,
Expand Down Expand Up @@ -326,7 +326,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
}

// Run the request transformer.
methods.transformRequest(downstreamRequest, upstreamRequest)
methods.transformRequest(downstreamRequest, upstreamRequest, &body)

// Retrieve metadata from the initial request.
model, requestMetadata := methods.getRequestMetadata(body)
Expand Down Expand Up @@ -424,6 +424,8 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
return
}

fmt.Println("-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* upstreamRequest\n", upstreamRequest)

resp, err := httpClient.Do(upstreamRequest)
defer modelAvailabilityTracker.record(gatewayModel, resp, err)

Expand Down
13 changes: 13 additions & 0 deletions cmd/cody-gateway/shared/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,19 @@ type AnthropicConfig struct {
FlaggingConfig FlaggingConfig
}

// FireworksDirectRouteConfig holds the access tokens used for Fireworks
// direct-route requests. It currently has a single, model-specific token.
type FireworksDirectRouteConfig struct {
// DeepSeekCoderV2LiteBaseAccessToken is the direct-route token for the
// DeepSeek Coder V2 Lite Base model, loaded from
// CODY_GATEWAY_FIREWORKS_DIRECT_ROUTE_DEEPSEEK_CODER_V2_LITE_BASE_ACCESS_TOKEN.
DeepSeekCoderV2LiteBaseAccessToken string
}

// FireworksConfig is the Cody Gateway configuration for the Fireworks upstream.
type FireworksConfig struct {
// Non-prefixed model names
AllowedModels []string
// AccessToken is sent as the bearer token on requests to the Fireworks API.
AccessToken string
// StarcoderCommunitySingleTenantPercent is the percentage of community
// traffic for Starcoder redirected to the single-tenant deployment.
StarcoderCommunitySingleTenantPercent int
// StarcoderEnterpriseSingleTenantPercent is the percentage of Enterprise
// traffic for Starcoder redirected to the single-tenant deployment.
StarcoderEnterpriseSingleTenantPercent int
FlaggingConfig FlaggingConfig
// DirectRouteConfig holds the per-model tokens for direct-route requests,
// populated by GetFireworksDirectRouteConfig during Load.
DirectRouteConfig FireworksDirectRouteConfig
}

type OpenAIConfig struct {
Expand Down Expand Up @@ -307,6 +313,7 @@ func (c *Config) Load() {
}
c.Fireworks.StarcoderCommunitySingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_COMMUNITY_SINGLE_TENANT_PERCENT", "0", "The percentage of community traffic for Starcoder to be redirected to the single-tenant deployment.")
c.Fireworks.StarcoderEnterpriseSingleTenantPercent = c.GetPercent("CODY_GATEWAY_FIREWORKS_STARCODER_ENTERPRISE_SINGLE_TENANT_PERCENT", "100", "The percentage of Enterprise traffic for Starcoder to be redirected to the single-tenant deployment.")
c.Fireworks.DirectRouteConfig = c.GetFireworksDirectRouteConfig()

// Configurations for Google Gemini models.
c.Google.AccessToken = c.GetOptional("CODY_GATEWAY_GOOGLE_ACCESS_TOKEN", "The Google AI Studio access token to be used.")
Expand Down Expand Up @@ -425,6 +432,12 @@ func (c *Config) loadFlaggingConfig(cfg *FlaggingConfig, envVarPrefix string) {
cfg.FlaggedModelNames = maybeLoadLowercaseSlice("FLAGGED_MODEL_NAMES", "LLM models that will always lead to the request getting flagged.")
}

// GetFireworksDirectRouteConfig loads the access tokens for Fireworks
// direct-route deployments from the environment.
//
// The token is read with GetOptional, like the other optional upstream access
// tokens in Load, so that leaving the variable unset is valid; an empty token
// makes GetDirectRouteSpec report that no direct route is available.
func (c *Config) GetFireworksDirectRouteConfig() FireworksDirectRouteConfig {
	return FireworksDirectRouteConfig{
		DeepSeekCoderV2LiteBaseAccessToken: c.GetOptional(
			"CODY_GATEWAY_FIREWORKS_DIRECT_ROUTE_DEEPSEEK_CODER_V2_LITE_BASE_ACCESS_TOKEN",
			"The Fireworks direct route access token for the DeepSeek Coder V2 Lite Base model.",
		),
	}
}

// splitMaybe splits the provided string on commas, but returns nil if given the empty string.
func splitMaybe(input string) []string {
if input == "" {
Expand Down
Loading