Skip to content

Commit

Permalink
feat(ai-proxy): add support for vertex ai provider (#1590)
Browse files Browse the repository at this point in the history
  • Loading branch information
floreks authored Nov 19, 2024
1 parent 66b8d5a commit f83fd4a
Show file tree
Hide file tree
Showing 16 changed files with 495 additions and 25 deletions.
4 changes: 2 additions & 2 deletions charts/ai-proxy/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.1.0
version: 0.2.0

# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "v1.0.0"
appVersion: "v1.1.0"
7 changes: 6 additions & 1 deletion charts/ai-proxy/templates/secrets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,9 @@ metadata:
{{- include "ai-proxy.labels" . | nindent 4 }}
type: Opaque
data:
PLRL_PROVIDER_TOKEN: {{ .Values.secrets.token | b64enc | quote }}
{{- with .Values.secrets.token }}
PLRL_PROVIDER_TOKEN: {{ . | b64enc | quote }}
{{- end }}
{{- with .Values.secrets.serviceAccount }}
PLRL_PROVIDER_SERVICE_ACCOUNT: {{ . | b64enc | quote }}
{{- end }}
7 changes: 4 additions & 3 deletions charts/ai-proxy/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ image:
# This sets the pull policy for images.
pullPolicy: IfNotPresent
# Overrides the image tag whose default is the chart appVersion.
tag: "v1.0.0"
tag: ~

config:
# One of: ollama, openai
# One of: ollama, openai, vertex
provider: openai
# Provider API URL
providerHost: https://api.openai.com

secrets:
token: changeme
token: ~
serviceAccount: ~

# AI Proxy container args
# Note: It can override config.provider and config.providerHost
Expand Down
14 changes: 14 additions & 0 deletions go/ai-proxy/api/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/pluralsh/console/go/ai-proxy/api/ollama"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/api/vertex"
)

type Provider string
Expand All @@ -21,6 +22,8 @@ func ToProvider(s string) (Provider, error) {
return ProviderOpenAI, nil
case ProviderAnthropic.String():
return ProviderAnthropic, nil
case ProviderVertex.String():
return ProviderVertex, nil
}

return "", fmt.Errorf("invalid provider: %s", s)
Expand All @@ -30,6 +33,7 @@ const (
ProviderOpenAI Provider = "openai"
ProviderAnthropic Provider = "anthropic"
ProviderOllama Provider = "ollama"
ProviderVertex Provider = "vertex"
)

type OllamaAPI string
Expand All @@ -40,6 +44,9 @@ var (
ollamaToOpenAI ProviderAPIMapping = map[string]string{
ollama.EndpointChat: openai.EndpointChat,
}
ollamaToVertex ProviderAPIMapping = map[string]string{
ollama.EndpointChat: vertex.EndpointChat,
}
)

func ToProviderAPIPath(target Provider, path string) string {
Expand All @@ -52,6 +59,13 @@ func ToProviderAPIPath(target Provider, path string) string {
panic(fmt.Sprintf("path %s not registered for provider %s", path, target))
}

return targetPath
case ProviderVertex:
targetPath, exists := ollamaToVertex[path]
if !exists {
panic(fmt.Sprintf("path %s not registered for provider %s", path, target))
}

return targetPath
}

Expand Down
34 changes: 34 additions & 0 deletions go/ai-proxy/api/vertex/vertex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package vertex

import (
ollamaapi "github.com/ollama/ollama/api"
"github.com/pluralsh/polly/algorithms"
)

// Endpoint is a named string type declared by this package.
// The constants below are untyped string constants and do not use it.
// NOTE(review): presumably intended to type endpoint paths — confirm usage elsewhere.
type Endpoint string

const (
// EndpointChat is the chat completions path template. It embeds the
// literal placeholders ${PROJECT_ID} and ${LOCATION}, which are not
// substituted in this file.
// NOTE(review): presumably expanded by the caller before the request is sent — verify.
EndpointChat = "/v1/projects/${PROJECT_ID}/locations/${LOCATION}/endpoints/openapi/chat/completions"
// EnvProjectID is the key matching the ${PROJECT_ID} placeholder in EndpointChat.
EnvProjectID = "PROJECT_ID"
// EnvLocation is the key matching the ${LOCATION} placeholder in EndpointChat.
EnvLocation = "LOCATION"
)

// ErrorResponse models a JSON error payload that wraps its details in a
// single top-level "error" object.
// NOTE(review): presumably the error body returned by the Vertex AI API — confirm against the API docs.
type ErrorResponse struct {
// Error holds the details decoded from the "error" JSON key.
Error struct {
// Code is decoded from the "code" key.
Code int `json:"code"`
// Message is the error text decoded from the "message" key.
Message string `json:"message"`
// Status is decoded from the "status" key and omitted when empty on encode.
Status string `json:"status,omitempty"`
} `json:"error"`
}

// FromErrorResponse builds a converter from ErrorResponse values to Ollama
// status errors.
//
// The returned function maps every element of its input slice to an
// ollamaapi.StatusError. Each result carries the statusCode supplied here,
// with ErrorMessage and Status copied from the element's Error.Message and
// Error.Status fields.
func FromErrorResponse(statusCode int) func(response []ErrorResponse) []ollamaapi.StatusError {
	// toStatusError closes over statusCode so every converted element
	// reports the same code.
	toStatusError := func(src ErrorResponse) ollamaapi.StatusError {
		var out ollamaapi.StatusError
		out.StatusCode = statusCode
		out.ErrorMessage = src.Error.Message
		out.Status = src.Error.Status

		return out
	}

	return func(responses []ErrorResponse) []ollamaapi.StatusError {
		return algorithms.Map(responses, toStatusError)
	}
}
30 changes: 20 additions & 10 deletions go/ai-proxy/args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ import (
)

const (
envProviderToken = "PROVIDER_TOKEN"
envProviderToken = "PROVIDER_TOKEN"
envProviderServiceAccount = "PROVIDER_SERVICE_ACCOUNT"

defaultPort = 8000
defaultProvider = api.ProviderOllama
defaultAddress = "0.0.0.0"
)

var (
argProvider = pflag.String("provider", defaultProvider.String(), "Provider name. Must be one of: ollama, openai. Defaults to 'ollama' type API.")
argProviderHost = pflag.String("provider-host", "", "Provider host address to access the API i.e. https://api.openai.com")
argProviderToken = pflag.String("provider-token", helpers.GetPluralEnv(envProviderToken, ""), "Provider token used to connect to the API if needed. Can be overridden via PLRL_PROVIDER_TOKEN env var.")
argPort = pflag.Int("port", defaultPort, "The port to listen on. Defaults to port 8000.")
argAddress = pflag.IP("address", net.ParseIP(defaultAddress), "The IP address to serve on. Defaults to 0.0.0.0 (all interfaces).")
argProvider = pflag.String("provider", defaultProvider.String(), "Provider name. Must be one of: ollama, openai, vertex. Defaults to 'ollama' type API.")
argProviderHost = pflag.String("provider-host", "", "Provider host address to access the API i.e. https://api.openai.com")
argProviderToken = pflag.String("provider-token", helpers.GetPluralEnv(envProviderToken, ""), "Provider token used to connect to the API if needed. Can be overridden via PLRL_PROVIDER_TOKEN env var.")
argProviderServiceAccount = pflag.String("provider-service-account", helpers.GetPluralEnv(envProviderServiceAccount, ""), "Provider service account file used to connect to the API if needed. Can be overridden via PLRL_PROVIDER_SERVICE_ACCOUNT env var.")
argPort = pflag.Int("port", defaultPort, "The port to listen on. Defaults to port 8000.")
argAddress = pflag.IP("address", net.ParseIP(defaultAddress), "The IP address to serve on. Defaults to 0.0.0.0 (all interfaces).")
)

func init() {
Expand Down Expand Up @@ -68,12 +70,20 @@ func ProviderHost() string {
return *argProviderHost
}

func ProviderToken() string {
if len(*argProviderToken) == 0 && Provider() != defaultProvider {
panic(fmt.Errorf("provider secret is required"))
func ProviderCredentials() string {
if len(*argProviderToken) > 0 && Provider() == api.ProviderOpenAI {
return *argProviderToken
}

return *argProviderToken
if len(*argProviderServiceAccount) > 0 && Provider() == api.ProviderVertex {
return *argProviderServiceAccount
}

if Provider() == defaultProvider {
return ""
}

panic(fmt.Errorf("provider credentials must be provided when %s provider is used", Provider()))
}

func Address() string {
Expand Down
28 changes: 28 additions & 0 deletions go/ai-proxy/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,33 @@ require (
)

require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/aiplatform v1.68.0 // indirect
cloud.google.com/go/auth v0.9.9 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.2 // indirect
cloud.google.com/go/iam v1.2.1 // indirect
cloud.google.com/go/longrunning v0.6.1 // indirect
cloud.google.com/go/vertexai v0.13.2 // indirect
github.com/bytedance/sonic v1.12.3 // indirect
github.com/bytedance/sonic/loader v0.2.1 // indirect
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.6 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.22.1 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/kr/pretty v0.3.0 // indirect
Expand All @@ -37,11 +51,25 @@ require (
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.29.0 // indirect
go.opentelemetry.io/otel/metric v1.29.0 // indirect
go.opentelemetry.io/otel/trace v1.29.0 // indirect
golang.org/x/arch v0.11.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
google.golang.org/api v0.203.0 // indirect
google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect
google.golang.org/grpc v1.67.1 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
Loading

0 comments on commit f83fd4a

Please sign in to comment.