From 7326edfd93bd8439732f91b1aebd1d8d1508be88 Mon Sep 17 00:00:00 2001
From: Liu Shuang
Date: Wed, 6 Nov 2024 18:19:44 +0800
Subject: [PATCH] feat: support count tokens (#39)

* feat: support count tokens

* fix: lint

* test: add count tokens test
---
Review note: this is a revised copy of the original commit. The test
handler now returns after answering 405, the failure subtest works on a
copy of the shared request, and doc comments were added to the exported
names. The commit id above and the blob ids on the "index" lines of
count_tokens.go and count_tokens_test.go still describe the original
content. Tabs and column alignment were rebuilt from a whitespace-collapsed
copy; use --ignore-whitespace if a context line does not match.

 config.go                            | 10 ++--
 count_tokens.go                      | 39 ++++++++++++
 count_tokens_test.go                 | 96 ++++++++++++++++++++++++++++
 integrationtest/count_tokens_test.go | 35 ++++++++++
 message.go                           |  2 +-
 5 files changed, 176 insertions(+), 6 deletions(-)
 create mode 100644 count_tokens.go
 create mode 100644 count_tokens_test.go
 create mode 100644 integrationtest/count_tokens_test.go

diff --git a/config.go b/config.go
index ea54dd7..dac8ed3 100644
--- a/config.go
+++ b/config.go
@@ -18,11 +18,11 @@ const (
 type BetaVersion string
 
 const (
-	BetaTools20240404          BetaVersion = "tools-2024-04-04"
-	BetaTools20240516          BetaVersion = "tools-2024-05-16"
-	BetaPromptCaching20240731  BetaVersion = "prompt-caching-2024-07-31"
-	BetaMessageBatches20240924 BetaVersion = "message-batches-2024-09-24"
-
+	BetaTools20240404             BetaVersion = "tools-2024-04-04"
+	BetaTools20240516             BetaVersion = "tools-2024-05-16"
+	BetaPromptCaching20240731     BetaVersion = "prompt-caching-2024-07-31"
+	BetaMessageBatches20240924    BetaVersion = "message-batches-2024-09-24"
+	BetaTokenCounting20241101     BetaVersion = "token-counting-2024-11-01"
 	BetaMaxTokens35Sonnet20240715 BetaVersion = "max-tokens-3-5-sonnet-2024-07-15"
 )
 
diff --git a/count_tokens.go b/count_tokens.go
new file mode 100644
index 0000000..b3ec625
--- /dev/null
+++ b/count_tokens.go
@@ -0,0 +1,39 @@
+package anthropic
+
+import (
+	"context"
+	"net/http"
+)
+
+// CountTokensResponse is the result returned by Client.CountTokens.
+type CountTokensResponse struct {
+	httpHeader
+
+	// InputTokens is read from the "input_tokens" field of the response.
+	InputTokens int `json:"input_tokens"`
+}
+
+// CountTokens sends request as a POST to /messages/count_tokens and returns
+// the response filled in by sendRequest. Beta versions configured on the
+// client are attached through withBetaVersion.
+//
+// MaxTokens carries the omitempty tag, so a zero value is left out when the
+// request is encoded with encoding/json.
+func (c *Client) CountTokens(
+	ctx context.Context,
+	request MessagesRequest,
+) (response CountTokensResponse, err error) {
+	var setters []requestSetter
+	if len(c.config.BetaVersion) > 0 {
+		setters = append(setters, withBetaVersion(c.config.BetaVersion...))
+	}
+
+	urlSuffix := "/messages/count_tokens"
+	req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, request, setters...)
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
diff --git a/count_tokens_test.go b/count_tokens_test.go
new file mode 100644
index 0000000..597277a
--- /dev/null
+++ b/count_tokens_test.go
@@ -0,0 +1,96 @@
+package anthropic_test
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"strings"
+	"testing"
+
+	"github.com/liushuangls/go-anthropic/v2"
+	"github.com/liushuangls/go-anthropic/v2/internal/test"
+)
+
+func TestCountTokens(t *testing.T) {
+	server := test.NewTestServer()
+	server.RegisterHandler("/v1/messages/count_tokens", handleCountTokens)
+
+	ts := server.AnthropicTestServer()
+	ts.Start()
+	defer ts.Close()
+
+	baseURL := ts.URL + "/v1"
+	client := anthropic.NewClient(
+		test.GetTestToken(),
+		anthropic.WithBaseURL(baseURL),
+		anthropic.WithBetaVersion(anthropic.BetaTokenCounting20241101),
+	)
+
+	request := anthropic.MessagesRequest{
+		Model:       anthropic.ModelClaude3Dot5HaikuLatest,
+		MultiSystem: anthropic.NewMultiSystemMessages("you are an assistant", "you are snarky"),
+		Messages: []anthropic.Message{
+			anthropic.NewUserTextMessage("What is your name?"),
+			anthropic.NewAssistantTextMessage("My name is Claude."),
+			anthropic.NewUserTextMessage("What is your favorite color?"),
+		},
+	}
+
+	t.Run("count tokens success", func(t *testing.T) {
+		resp, err := client.CountTokens(context.Background(), request)
+		if err != nil {
+			t.Fatalf("CountTokens error: %v", err)
+		}
+
+		t.Logf("CountTokens resp: %+v", resp)
+	})
+
+	t.Run("count tokens failure", func(t *testing.T) {
+		// Work on a copy so the request shared with other subtests is untouched.
+		invalid := request
+		invalid.MaxTokens = 10
+		_, err := client.CountTokens(context.Background(), invalid)
+		if err == nil {
+			t.Fatalf("CountTokens expected error, got nil")
+		}
+
+		t.Logf("CountTokens error: %v", err)
+	})
+}
+
+// handleCountTokens serves /v1/messages/count_tokens for TestCountTokens. It
+// answers with an error for non-POST methods, unreadable bodies, a positive
+// max_tokens, or a missing token-counting beta header; otherwise it replies
+// with a fixed count of 100 input tokens.
+func handleCountTokens(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	if r.Method != http.MethodPost {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		// http.Error does not end the request, so stop here explicitly.
+		return
+	}
+
+	var req anthropic.MessagesRequest
+	if req, err = getRequest[anthropic.MessagesRequest](r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	if req.MaxTokens > 0 {
+		http.Error(w, "max_tokens: Extra inputs are not permitted", http.StatusBadRequest)
+		return
+	}
+
+	betaHeaders := r.Header.Get("Anthropic-Beta")
+	if !strings.Contains(betaHeaders, string(anthropic.BetaTokenCounting20241101)) {
+		http.Error(w, "missing beta version header", http.StatusBadRequest)
+		return
+	}
+
+	res := anthropic.CountTokensResponse{
+		InputTokens: 100,
+	}
+	resBytes, _ = json.Marshal(res)
+	_, _ = w.Write(resBytes)
+}
diff --git a/integrationtest/count_tokens_test.go b/integrationtest/count_tokens_test.go
new file mode 100644
index 0000000..7613ef8
--- /dev/null
+++ b/integrationtest/count_tokens_test.go
@@ -0,0 +1,35 @@
+package integrationtest
+
+import (
+	"context"
+	"testing"
+
+	"github.com/liushuangls/go-anthropic/v2"
+)
+
+func TestCountTokens(t *testing.T) {
+	testAPIKey(t)
+	client := anthropic.NewClient(
+		APIKey,
+		anthropic.WithBetaVersion(anthropic.BetaTokenCounting20241101),
+	)
+	ctx := context.Background()
+
+	request := anthropic.MessagesRequest{
+		Model:       anthropic.ModelClaude3Dot5HaikuLatest,
+		MultiSystem: anthropic.NewMultiSystemMessages("you are an assistant", "you are snarky"),
+		Messages: []anthropic.Message{
+			anthropic.NewUserTextMessage("What is your name?"),
+			anthropic.NewAssistantTextMessage("My name is Claude."),
+			anthropic.NewUserTextMessage("What is your favorite color?"),
+		},
+	}
+
+	t.Run("CountTokens on real API", func(t *testing.T) {
+		resp, err := client.CountTokens(ctx, request)
+		if err != nil {
+			t.Fatalf("CountTokens error: %s", err)
+		}
+		t.Logf("CountTokens resp: %+v", resp)
+	})
+}
diff --git a/message.go b/message.go
index 052eea1..03594f4 100644
--- a/message.go
+++ b/message.go
@@ -36,7 +36,7 @@ const (
 type MessagesRequest struct {
 	Model     Model     `json:"model"`
 	Messages  []Message `json:"messages"`
-	MaxTokens int       `json:"max_tokens"`
+	MaxTokens int       `json:"max_tokens,omitempty"`
 
 	System        string              `json:"-"`
 	MultiSystem   []MessageSystemPart `json:"-"`