Skip to content

Commit

Permalink
Add Google VertexAI support. (#52)
Browse files Browse the repository at this point in the history
* Add Google VertexAI support.

Support VertexAI framework.
Support detection and error handling.

* Improvements and cleanup as suggested by package owners:
* Fix whitespace
* Remove duplicate code
* Move certain functions to be methods.
* Better error message grammar.
* Remove commented-out unneeded code.
* Fix sample code.

* Integration test for VertexAI

* Refactor vertex changes to an interface.

* Run golines.

* Unit tests for Vertex functionality.
  • Loading branch information
steveheyman authored Dec 8, 2024
1 parent 04cc4c4 commit 63bad02
Show file tree
Hide file tree
Showing 21 changed files with 816 additions and 18 deletions.
73 changes: 72 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func main() {
}
return
}
fmt.Println(resp.Content[0].Text)
fmt.Println(*resp.Content[0].GetText())
}
```
</details>
Expand Down Expand Up @@ -305,6 +305,77 @@ func main() {

</details>

<details>
<summary>VertexAI example</summary>


If you are using a Google Credentials file, you can use the following code to create a client:

```go

package main

import (
"context"
"errors"
"fmt"
"os"

"github.com/liushuangls/go-anthropic/v2"
"golang.org/x/oauth2/google"
)

func main() {
credBytes, err := os.ReadFile("<path to your credentials file>")
if err != nil {
fmt.Println("Error reading file")
return
}

ts, err := google.JWTAccessTokenSourceWithScope(credBytes, "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/cloud-platform.read-only")
if err != nil {
fmt.Println("Error creating token source")
return
}

// use JWTAccessTokenSourceWithScope
token, err := ts.Token()
if err != nil {
fmt.Println("Error getting token")
return
}

fmt.Println(token.AccessToken)

client := anthropic.NewClient(token.AccessToken, anthropic.WithVertexAI("<YOUR PROJECTID>", "<YOUR LOCATION>"))

resp, err := client.CreateMessagesStream(context.Background(), anthropic.MessagesStreamRequest{
MessagesRequest: anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Haiku20240307,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
},
OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) {
fmt.Printf("Stream Content: %s\n", *data.Delta.Text)
},
})
if err != nil {
var e *anthropic.APIError
if errors.As(err, &e) {
fmt.Printf("Messages stream error, type: %s, message: %s", e.Type, e.Message)
} else {
fmt.Printf("Messages stream error: %v\n", err)
}
return
}
fmt.Println(resp.Content[0].GetText())
}

```
</details>

## Acknowledgments
The following project had particular influence on go-anthropic is design.

Expand Down
26 changes: 19 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ func (c *Client) handlerRequestError(resp *http.Response) error {
if err != nil {
return fmt.Errorf("error, reading response body: %w", err)
}

// use the adapter to translate the error, if it can
if err, handled := c.config.Adapter.TranslateError(resp, body); handled {
return err
}

var errRes ErrorResponse
err = json.Unmarshal(body, &errRes)
if err != nil || errRes.Error == nil {
Expand All @@ -74,15 +80,12 @@ func (c *Client) handlerRequestError(resp *http.Response) error {
}
return &reqErr
}

return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
}
return nil
}

func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

type requestSetter func(req *http.Request)

func withBetaVersion(betaVersion ...BetaVersion) requestSetter {
Expand All @@ -105,6 +108,14 @@ func (c *Client) newRequest(
body any,
requestSetters ...requestSetter,
) (req *http.Request, err error) {

// prepare the request
var fullURL string
fullURL, err = c.config.Adapter.PrepareRequest(c, method, urlSuffix, body)
if err != nil {
return nil, err
}

var reqBody []byte
if body != nil {
reqBody, err = json.Marshal(body)
Expand All @@ -116,7 +127,7 @@ func (c *Client) newRequest(
req, err = http.NewRequestWithContext(
ctx,
method,
c.fullURL(urlSuffix),
fullURL,
bytes.NewBuffer(reqBody),
)
if err != nil {
Expand All @@ -125,8 +136,9 @@ func (c *Client) newRequest(

req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("X-Api-Key", c.config.apiKey)
req.Header.Set("Anthropic-Version", string(c.config.APIVersion))

// set any provider-specific headers (including Authorization)
c.config.Adapter.SetRequestHeaders(c, req)

for _, setter := range requestSetters {
setter(req)
Expand Down
15 changes: 15 additions & 0 deletions clientadapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package anthropic

import (
"net/http"
)

// ClientAdapter is an interface that defines the methods that allow use of the anthropic API with different providers.
type ClientAdapter interface {
// Translate provider specific errors. Responds with an error and a boolean indicating if the error has been successfully parsed.
TranslateError(resp *http.Response, body []byte) (error, bool)
// Prepare the request for the provider and return the full URL
PrepareRequest(c *Client, method, urlSuffix string, body any) (string, error)
// Set the request headers for the provider
SetRequestHeaders(c *Client, req *http.Request) error
}
19 changes: 19 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,22 @@ const (
RoleUser ChatRole = "user"
RoleAssistant ChatRole = "assistant"
)

func (m Model) asVertexModel() string {
switch m {
case ModelClaude3Opus20240229:
return "claude-3-opus@20240229"
case ModelClaude3Sonnet20240229:
return "claude-3-sonnet@20240229"
case ModelClaude3Dot5Sonnet20240620:
return "claude-3-5-sonnet@20240620"
case ModelClaude3Dot5Sonnet20241022:
return "claude-3-5-sonnet@20241022"
case ModelClaude3Haiku20240307:
return "claude-3-haiku@20240307"
case ModelClaude3Dot5Haiku20241022:
return "claude-3-5-haiku@20241022"
default:
return string(m)
}
}
2 changes: 1 addition & 1 deletion complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (c *Client) CreateComplete(
request.Stream = false

urlSuffix := "/complete"
req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, request)
req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, &request)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion complete_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *Client) CreateCompleteStream(
request.Stream = true

urlSuffix := "/complete"
req, err := c.newStreamRequest(ctx, http.MethodPost, urlSuffix, request)
req, err := c.newStreamRequest(ctx, http.MethodPost, urlSuffix, &request)
if err != nil {
return
}
Expand Down
38 changes: 36 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package anthropic

import (
"fmt"
"net/http"
)

Expand All @@ -12,7 +13,8 @@ const (
type APIVersion string

const (
APIVersion20230601 APIVersion = "2023-06-01"
APIVersion20230601 APIVersion = "2023-06-01"
APIVersionVertex20231016 APIVersion = "vertex-2023-10-16"
)

type BetaVersion string
Expand All @@ -26,16 +28,21 @@ const (
BetaMaxTokens35Sonnet20240715 BetaVersion = "max-tokens-3-5-sonnet-2024-07-15"
)

type ApiKeyFunc func() string

// ClientConfig is a configuration of a client.
type ClientConfig struct {
apiKey string
apiKeyFunc ApiKeyFunc
apiKey string

BaseURL string
APIVersion APIVersion
BetaVersion []BetaVersion
HTTPClient *http.Client

EmptyMessagesLimit uint

Adapter ClientAdapter
}

type ClientOption func(c *ClientConfig)
Expand All @@ -49,6 +56,7 @@ func newConfig(apiKey string, opts ...ClientOption) ClientConfig {
HTTPClient: &http.Client{},

EmptyMessagesLimit: defaultEmptyMessagesLimit,
Adapter: &DefaultAdapter{},
}

for _, opt := range opts {
Expand All @@ -58,6 +66,13 @@ func newConfig(apiKey string, opts ...ClientOption) ClientConfig {
return c
}

func (c *ClientConfig) GetApiKey() string {
if c.apiKeyFunc != nil {
return c.apiKeyFunc()
}
return c.apiKey
}

func WithBaseURL(baseUrl string) ClientOption {
return func(c *ClientConfig) {
c.BaseURL = baseUrl
Expand Down Expand Up @@ -87,3 +102,22 @@ func WithBetaVersion(betaVersion ...BetaVersion) ClientOption {
c.BetaVersion = betaVersion
}
}

func WithVertexAI(projectID string, location string) ClientOption {
return func(c *ClientConfig) {
c.BaseURL = fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models",
location,
projectID,
location,
)
c.APIVersion = APIVersionVertex20231016
c.Adapter = &VertexAdapter{}
}
}

func WithApiKeyFunc(apiKeyFunc ApiKeyFunc) ClientOption {
return func(c *ClientConfig) {
c.apiKeyFunc = apiKeyFunc
}
}
2 changes: 1 addition & 1 deletion count_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c *Client) CountTokens(
}

urlSuffix := "/messages/count_tokens"
req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, request, setters...)
req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, &request, setters...)
if err != nil {
return
}
Expand Down
36 changes: 36 additions & 0 deletions defaultadapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package anthropic

import (
"fmt"
"net/http"
)

var _ ClientAdapter = (*DefaultAdapter)(nil)

type DefaultAdapter struct {
}

func (v *DefaultAdapter) TranslateError(resp *http.Response, body []byte) (error, bool) {
return nil, false
}

func (v *DefaultAdapter) fullURL(baseUrl string, suffix string) string {
// replace the first slash with a colon
return fmt.Sprintf("%s%s", baseUrl, suffix)
}

func (v *DefaultAdapter) PrepareRequest(
c *Client,
method string,
urlSuffix string,
body any,
) (string, error) {
return v.fullURL(c.config.BaseURL, urlSuffix), nil
}

func (v *DefaultAdapter) SetRequestHeaders(c *Client, req *http.Request) error {
req.Header.Set("X-Api-Key", c.config.GetApiKey())
req.Header.Set("Anthropic-Version", string(c.config.APIVersion))

return nil
}
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
module github.com/liushuangls/go-anthropic/v2

go 1.21

require golang.org/x/oauth2 v0.24.0

require cloud.google.com/go/compute/metadata v0.3.0 // indirect
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
18 changes: 18 additions & 0 deletions integrationtest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,21 @@ func testAPIKey(t *testing.T) {
t.Fatal("ANTHROPIC_KEY must be set for integration tests")
}
}

var (
VertexAPIKey = os.Getenv("VERTEX_KEY")
VertexAPILocation = os.Getenv("VERTEX_LOCATION")
VertexAPIProject = os.Getenv("VERTEX_PROJECT")
)

func testVertexAPIKey(t *testing.T) {
if VertexAPIKey == "" {
t.Fatal("VERTEX_KEY must be set for integration tests")
}
if VertexAPILocation == "" {
t.Fatal("VERTEX_LOCATION must be set for integration tests")
}
if VertexAPIProject == "" {
t.Fatal("VERTEX_PROJECT must be set for integration tests")
}
}
Loading

0 comments on commit 63bad02

Please sign in to comment.