Skip to content

Commit

Permalink
feat: support count tokens (#39)
Browse files Browse the repository at this point in the history
* feat: support count tokens

* fix: lint

* test: add count tokens test
  • Loading branch information
liushuangls authored Nov 6, 2024
1 parent c96c8c6 commit 7326edf
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 6 deletions.
10 changes: 5 additions & 5 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ const (
type BetaVersion string

const (
BetaTools20240404 BetaVersion = "tools-2024-04-04"
BetaTools20240516 BetaVersion = "tools-2024-05-16"
BetaPromptCaching20240731 BetaVersion = "prompt-caching-2024-07-31"
BetaMessageBatches20240924 BetaVersion = "message-batches-2024-09-24"

BetaTools20240404 BetaVersion = "tools-2024-04-04"
BetaTools20240516 BetaVersion = "tools-2024-05-16"
BetaPromptCaching20240731 BetaVersion = "prompt-caching-2024-07-31"
BetaMessageBatches20240924 BetaVersion = "message-batches-2024-09-24"
BetaTokenCounting20241101 BetaVersion = "token-counting-2024-11-01"
BetaMaxTokens35Sonnet20240715 BetaVersion = "max-tokens-3-5-sonnet-2024-07-15"
)

Expand Down
31 changes: 31 additions & 0 deletions count_tokens.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package anthropic

import (
"context"
"net/http"
)

type CountTokensResponse struct {
httpHeader

InputTokens int `json:"input_tokens"`
}

func (c *Client) CountTokens(
ctx context.Context,
request MessagesRequest,
) (response CountTokensResponse, err error) {
var setters []requestSetter
if len(c.config.BetaVersion) > 0 {
setters = append(setters, withBetaVersion(c.config.BetaVersion...))
}

urlSuffix := "/messages/count_tokens"
req, err := c.newRequest(ctx, http.MethodPost, urlSuffix, request, setters...)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}
88 changes: 88 additions & 0 deletions count_tokens_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package anthropic_test

import (
"context"
"encoding/json"
"net/http"
"strings"
"testing"

"github.com/liushuangls/go-anthropic/v2"
"github.com/liushuangls/go-anthropic/v2/internal/test"
)

func TestCountTokens(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages/count_tokens", handleCountTokens)

ts := server.AnthropicTestServer()
ts.Start()
defer ts.Close()

baseUrl := ts.URL + "/v1"
client := anthropic.NewClient(
test.GetTestToken(),
anthropic.WithBaseURL(baseUrl),
anthropic.WithBetaVersion(anthropic.BetaTokenCounting20241101),
)

request := anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Dot5HaikuLatest,
MultiSystem: anthropic.NewMultiSystemMessages("you are an assistant", "you are snarky"),
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
anthropic.NewAssistantTextMessage("My name is Claude."),
anthropic.NewUserTextMessage("What is your favorite color?"),
},
}

t.Run("count tokens success", func(t *testing.T) {
resp, err := client.CountTokens(context.Background(), request)
if err != nil {
t.Fatalf("CountTokens error: %v", err)
}

t.Logf("CountTokens resp: %+v", resp)
})

t.Run("count tokens failure", func(t *testing.T) {
request.MaxTokens = 10
_, err := client.CountTokens(context.Background(), request)
if err == nil {
t.Fatalf("CountTokens expected error, got nil")
}

t.Logf("CountTokens error: %v", err)
})
}

func handleCountTokens(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

var req anthropic.MessagesRequest
if req, err = getRequest[anthropic.MessagesRequest](r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
if req.MaxTokens > 0 {
http.Error(w, "max_tokens: Extra inputs are not permitted", http.StatusBadRequest)
return
}

betaHeaders := r.Header.Get("Anthropic-Beta")
if !strings.Contains(betaHeaders, string(anthropic.BetaTokenCounting20241101)) {
http.Error(w, "missing beta version header", http.StatusBadRequest)
return
}

res := anthropic.CountTokensResponse{
InputTokens: 100,
}
resBytes, _ = json.Marshal(res)
_, _ = w.Write(resBytes)
}
35 changes: 35 additions & 0 deletions integrationtest/count_tokens_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package integrationtest

import (
"context"
"testing"

"github.com/liushuangls/go-anthropic/v2"
)

func TestCountTokens(t *testing.T) {
testAPIKey(t)
client := anthropic.NewClient(
APIKey,
anthropic.WithBetaVersion(anthropic.BetaTokenCounting20241101),
)
ctx := context.Background()

request := anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Dot5HaikuLatest,
MultiSystem: anthropic.NewMultiSystemMessages("you are an assistant", "you are snarky"),
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
anthropic.NewAssistantTextMessage("My name is Claude."),
anthropic.NewUserTextMessage("What is your favorite color?"),
},
}

t.Run("CountTokens on real API", func(t *testing.T) {
resp, err := client.CountTokens(ctx, request)
if err != nil {
t.Fatalf("CountTokens error: %s", err)
}
t.Logf("CountTokens resp: %+v", resp)
})
}
2 changes: 1 addition & 1 deletion message.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const (
type MessagesRequest struct {
Model Model `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
MaxTokens int `json:"max_tokens,omitempty"`

System string `json:"-"`
MultiSystem []MessageSystemPart `json:"-"`
Expand Down

0 comments on commit 7326edf

Please sign in to comment.