From be253c2d63b3df346d3521991241e5cb77e6a4cd Mon Sep 17 00:00:00 2001 From: GargantuaX Date: Thu, 11 May 2023 05:30:24 +0800 Subject: [PATCH 001/206] change azure engine config to modelMapper (#306) * change azure engine config to azure modelMapper config * Update go.mod * Revert "Update go.mod" This reverts commit 78d14c58f2a9ce668da43f6adbe20b60afcfe0d7. * lint fix * add test * lint fix * lint fix * lint fix * opt * opt * opt * opt --- api_internal_test.go | 17 +++++++----- audio.go | 2 +- chat.go | 2 +- chat_stream.go | 2 +- client.go | 22 ++++++++++++---- completion.go | 2 +- config.go | 28 +++++++++++++------- config_test.go | 62 ++++++++++++++++++++++++++++++++++++++++++++ edits.go | 3 ++- embeddings.go | 2 +- example_test.go | 3 +-- models_test.go | 2 +- moderation.go | 2 +- stream.go | 2 +- 14 files changed, 119 insertions(+), 32 deletions(-) create mode 100644 config_test.go diff --git a/api_internal_test.go b/api_internal_test.go index 9651ad402..529e7c7c4 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "") if err != nil { t.Errorf("Failed to create request: %v", err) } @@ -109,14 +109,16 @@ func TestRequestAuthHeader(t *testing.T) { func TestAzureFullURL(t *testing.T) { cases := []struct { - Name string - BaseURL string - Engine string - Expect string + Name string + BaseURL string + AzureModelMapper map[string]string + Model string + Expect string }{ { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", + nil, "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) { { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", + nil, "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) + az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL("/chat/completions", c.Model) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index d22daf98c..12c6ccc22 100644 --- a/audio.go +++ b/audio.go @@ -68,7 +68,7 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index c09861c8c..312ef8e20 100644 --- a/chat.go +++ b/chat.go @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 9ed0bc70a..f4fda882a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -46,7 +46,7 @@ func (c *Client) 
CreateChatCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) if err != nil { return } diff --git a/client.go b/client.go index 0f8aa41ba..9579ba27b 100644 --- a/client.go +++ b/client.go @@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error { return nil } -func (c *Client) fullURL(suffix string) string { - // /openai/deployments/{engine}/chat/completions?api-version={api_version} +// fullURL returns full URL for request. +// args[0] is model name, if API type is Azure, model name is required to get deployment name. +func (c *Client) fullURL(suffix string, args ...any) string { + // /openai/deployments/{model}/chat/completions?api-version={api_version} if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") @@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string { if strings.Contains(suffix, "/models") { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } + azureDeploymentName := "UNKNOWN" + if len(args) > 0 { + model, ok := args[0].(string) + if ok { + azureDeploymentName = c.config.GetAzureDeploymentByModel(model) + } + } return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + baseURL, azureAPIPrefix, azureDeploymentsPrefix, + azureDeploymentName, suffix, c.config.APIVersion, + ) } // c.config.APIType == APITypeOpenAI || c.config.APIType == "" @@ -120,8 +131,9 @@ func (c *Client) newStreamRequest( ctx context.Context, method string, urlSuffix string, - body any) (*http.Request, error) { - req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) + body any, + model string) (*http.Request, error) { + req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body) if err != nil { return nil, err } diff --git a/completion.go b/completion.go index 5eec88c29..e3d1b85eb 100644 --- a/completion.go +++ b/completion.go @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/config.go b/config.go index c800df15c..fbcf377c0 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package openai import ( "net/http" + "regexp" ) const ( @@ -26,13 +27,12 @@ const AzureAPIKeyHeader = "api-key" type ClientConfig struct { authToken string - BaseURL string - OrgID string - APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD - Engine string // required when APIType is APITypeAzure or APITypeAzureAD - - HTTPClient *http.Client + BaseURL string + OrgID string + APIType APIType + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AzureModelMapperFunc func(model string) string // replace model to azure deployment name func + HTTPClient *http.Client EmptyMessagesLimit uint } @@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { +func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { return ClientConfig{ authToken: apiKey, BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, 
APIVersion: "2023-03-15-preview", - Engine: engine, + AzureModelMapperFunc: func(model string) string { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + }, HTTPClient: &http.Client{}, @@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { func (ClientConfig) String() string { return "" } + +func (c ClientConfig) GetAzureDeploymentByModel(model string) string { + if c.AzureModelMapperFunc != nil { + return c.AzureModelMapperFunc(model) + } + + return model +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 000000000..488511b11 --- /dev/null +++ b/config_test.go @@ -0,0 +1,62 @@ +package openai_test + +import ( + "testing" + + . "github.com/sashabaranov/go-openai" +) + +func TestGetAzureDeploymentByModel(t *testing.T) { + cases := []struct { + Model string + AzureModelMapperFunc func(model string) string + Expect string + }{ + { + Model: "gpt-3.5-turbo", + Expect: "gpt-35-turbo", + }, + { + Model: "gpt-3.5-turbo-0301", + Expect: "gpt-35-turbo-0301", + }, + { + Model: "text-embedding-ada-002", + Expect: "text-embedding-ada-002", + }, + { + Model: "", + Expect: "", + }, + { + Model: "models", + Expect: "models", + }, + { + Model: "gpt-3.5-turbo", + Expect: "my-gpt35", + AzureModelMapperFunc: func(model string) string { + modelmapper := map[string]string{ + "gpt-3.5-turbo": "my-gpt35", + } + if val, ok := modelmapper[model]; ok { + return val + } + return model + }, + }, + } + + for _, c := range cases { + t.Run(c.Model, func(t *testing.T) { + conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + if c.AzureModelMapperFunc != nil { + conf.AzureModelMapperFunc = c.AzureModelMapperFunc + } + actual := conf.GetAzureDeploymentByModel(c.Model) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + }) + } +} diff --git a/edits.go b/edits.go index 858a8e537..c2c8db794 100644 --- a/edits.go +++ b/edits.go @@ -2,6 +2,7 @@ package openai import ( "context" + "fmt" "net/http" ) @@ -31,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 2deaccc3a..7fb432ead 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. 
// https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) if err != nil { return } diff --git a/example_test.go b/example_test.go index da253806d..b5dfafea9 100644 --- a/example_test.go +++ b/example_test.go @@ -305,8 +305,7 @@ func Example_chatbot() { func ExampleDefaultAzureConfig() { azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint - azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name - config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel) + config := openai.DefaultAzureConfig(azureKey, azureEndpoint) client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), diff --git a/models_test.go b/models_test.go index 70d6d756c..b017800b9 100644 --- a/models_test.go +++ b/models_test.go @@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) { ts.Start() defer ts.Close() - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine") + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL client := NewClientWithConfig(config) ctx := context.Background() diff --git a/moderation.go b/moderation.go index b386ddb95..ebd66afb9 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. 
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) if err != nil { return } diff --git a/stream.go b/stream.go index 95662db6d..cd435faea 100644 --- a/stream.go +++ b/stream.go @@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) + req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) if err != nil { return } From 83d03fca527288a2959783d3679e80526ecf60c6 Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Mon, 15 May 2023 12:29:28 +0800 Subject: [PATCH 002/206] Adjust the azure model deployment name call corresponding to README (#309) --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 272d853c6..994f26dd1 100644 --- a/README.md +++ b/README.md @@ -436,7 +436,15 @@ import ( func main() { - config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint ", "your Model deployment name") + config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") + //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function + //config.AzureModelMapperFunc = func(model string) string { + // azureModelMapping = map[string]string{ + // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", + // } + // return azureModelMapping[model] + //} + client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), From 21eef5bc8dd8a919c487fffb17214c055396bfcd Mon Sep 17 00:00:00 2001 From: JoyShi <286753440@qq.com> Date: Wed, 17 May 2023 04:38:09 +0800 Subject: [PATCH 003/206] Move form_builder into internal pkg. (#311) * Move form_uilder into internal pkg. * Fix import of audio.go * Reorganize. * Fix import. * Fix --------- Co-authored-by: JoyShi --- audio.go | 20 ++++---- client.go | 8 +-- files.go | 8 +-- files_test.go | 3 +- form_builder.go | 49 ------------------- image.go | 28 +++++------ image_test.go | 13 ++--- internal/form_builder.go | 49 +++++++++++++++++++ .../form_builder_test.go | 8 +-- 9 files changed, 96 insertions(+), 90 deletions(-) delete mode 100644 form_builder.go create mode 100644 internal/form_builder.go rename form_builder_test.go => internal/form_builder_test.go (88%) diff --git a/audio.go b/audio.go index 12c6ccc22..bf2365391 100644 --- a/audio.go +++ b/audio.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "os" + + utils "github.com/sashabaranov/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. @@ -72,7 +74,7 @@ func (c *Client) callAudioAPI( if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.formDataContentType()) + req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) @@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. 
-func audioMultipartForm(request AudioRequest, b formBuilder) error { +func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { f, err := os.Open(request.FilePath) if err != nil { return fmt.Errorf("opening audio file: %w", err) } defer f.Close() - err = b.createFormFile("file", f) + err = b.CreateFormFile("file", f) if err != nil { return fmt.Errorf("creating form file: %w", err) } - err = b.writeField("model", request.Model) + err = b.WriteField("model", request.Model) if err != nil { return fmt.Errorf("writing model name: %w", err) } // Create a form field for the prompt (if provided) if request.Prompt != "" { - err = b.writeField("prompt", request.Prompt) + err = b.WriteField("prompt", request.Prompt) if err != nil { return fmt.Errorf("writing prompt: %w", err) } @@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the format (if provided) if request.Format != "" { - err = b.writeField("response_format", string(request.Format)) + err = b.WriteField("response_format", string(request.Format)) if err != nil { return fmt.Errorf("writing format: %w", err) } @@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the temperature (if provided) if request.Temperature != 0 { - err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) + err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) if err != nil { return fmt.Errorf("writing temperature: %w", err) } @@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the language (if provided) if request.Language != "" { - err = b.writeField("language", request.Language) + err = b.WriteField("language", request.Language) if err != nil { return fmt.Errorf("writing language: %w", err) } } // Close the multipart writer - return b.close() + return b.Close() } diff --git a/client.go b/client.go index 9579ba27b..c55166aa6 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "strings" + + utils "github.com/sashabaranov/go-openai/internal" ) // Client is OpenAI GPT-3 API client. @@ -14,7 +16,7 @@ type Client struct { config ClientConfig requestBuilder requestBuilder - createFormBuilder func(io.Writer) formBuilder + createFormBuilder func(io.Writer) utils.FormBuilder } // NewClient creates new OpenAI API client. 
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) formBuilder { - return newFormBuilder(body) + createFormBuilder: func(body io.Writer) utils.FormBuilder { + return utils.NewFormBuilder(body) }, } } diff --git a/files.go b/files.go index b701b9454..5667ec861 100644 --- a/files.go +++ b/files.go @@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File var b bytes.Buffer builder := c.createFormBuilder(&b) - err = builder.writeField("purpose", request.Purpose) + err = builder.WriteField("purpose", request.Purpose) if err != nil { return } @@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - err = builder.createFormFile("file", fileData) + err = builder.CreateFormFile("file", fileData) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &file) diff --git a/files_test.go b/files_test.go index bb06498c8..56dbb414f 100644 --- a/files_test.go +++ b/files_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + . "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { config.BaseURL = "" client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) FormBuilder { return mockBuilder } diff --git a/form_builder.go b/form_builder.go deleted file mode 100644 index 7fbb1643a..000000000 --- a/form_builder.go +++ /dev/null @@ -1,49 +0,0 @@ -package openai - -import ( - "io" - "mime/multipart" - "os" -) - -type formBuilder interface { - createFormFile(fieldname string, file *os.File) error - writeField(fieldname, value string) error - close() error - formDataContentType() string -} - -type defaultFormBuilder struct { - writer *multipart.Writer -} - -func newFormBuilder(body io.Writer) *defaultFormBuilder { - return &defaultFormBuilder{ - writer: multipart.NewWriter(body), - } -} - -func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error { - fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) - if err != nil { - return err - } - - _, err = io.Copy(fieldWriter, file) - if err != nil { - return err - } - return nil -} - -func (fb *defaultFormBuilder) writeField(fieldname, value string) error { - return fb.writer.WriteField(fieldname, value) -} - -func (fb *defaultFormBuilder) close() error { - return fb.writer.Close() -} - -func (fb *defaultFormBuilder) formDataContentType() string { - return fb.writer.FormDataContentType() -} diff --git a/image.go b/image.go index 21703bda7..87ffea25e 100644 --- a/image.go +++ b/image.go @@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } // mask, it is 
optional if request.Mask != nil { - err = builder.createFormFile("mask", request.Mask) + err = builder.CreateFormFile("mask", request.Mask) if err != nil { return } } - err = builder.writeField("prompt", request.Prompt) + err = builder.WriteField("prompt", request.Prompt) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/image_test.go b/image_test.go index 4a7dad58f..5cf6a268d 100644 --- a/image_test.go +++ b/image_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -268,19 +269,19 @@ type mockFormBuilder struct { mockClose func() error } -func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { +func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { return fb.mockCreateFormFile(fieldname, file) } -func (fb *mockFormBuilder) writeField(fieldname, value string) error { +func (fb *mockFormBuilder) WriteField(fieldname, value string) error { return fb.mockWriteField(fieldname, value) } -func (fb *mockFormBuilder) close() error { +func (fb *mockFormBuilder) Close() error { return fb.mockClose() } -func (fb *mockFormBuilder) formDataContentType() string { +func (fb *mockFormBuilder) FormDataContentType() string { return "" } @@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() @@ 
-357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() diff --git a/internal/form_builder.go b/internal/form_builder.go new file mode 100644 index 000000000..359dd7e2a --- /dev/null +++ b/internal/form_builder.go @@ -0,0 +1,49 @@ +package openai + +import ( + "io" + "mime/multipart" + "os" +) + +type FormBuilder interface { + CreateFormFile(fieldname string, file *os.File) error + WriteField(fieldname, value string) error + Close() error + FormDataContentType() string +} + +type DefaultFormBuilder struct { + writer *multipart.Writer +} + +func NewFormBuilder(body io.Writer) *DefaultFormBuilder { + return &DefaultFormBuilder{ + writer: multipart.NewWriter(body), + } +} + +func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, file) + if err != nil { + return err + } + return nil +} + +func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + return fb.writer.WriteField(fieldname, value) +} + +func (fb *DefaultFormBuilder) Close() error { + return fb.writer.Close() +} + +func (fb *DefaultFormBuilder) FormDataContentType() string { + return fb.writer.FormDataContentType() +} diff --git a/form_builder_test.go b/internal/form_builder_test.go similarity index 88% rename from form_builder_test.go rename to internal/form_builder_test.go index 78e2ec968..d3faf9982 100644 --- a/form_builder_test.go +++ b/internal/form_builder_test.go @@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { defer file.Close() defer os.Remove(file.Name()) - builder := newFormBuilder(&failingWriter{}) - err = builder.createFormFile("file", file) + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFile("file", file) checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") } @@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) { defer os.Remove(file.Name()) body := &bytes.Buffer{} - builder := newFormBuilder(body) - err = builder.createFormFile("file", file) + builder := NewFormBuilder(body) + err = builder.CreateFormFile("file", file) checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } From b62a325b0af8ac7eb887667b64279ccbecaf65c4 Mon Sep 17 00:00:00 2001 From: Takahiro Ikeuchi Date: Sat, 20 May 2023 04:04:16 +0900 Subject: [PATCH 004/206] Azure OpenAI API version 2023-05-15 (#316) * chore(config.go): update Azure API version to 2023-05-15 to use the latest version available * chore(api_internal_test.go): update Azure API version to 2023-05-15 to match the latest version --- api_internal_test.go | 4 ++-- config.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 529e7c7c4..214b627bf 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -122,7 +122,7 @@ func TestAzureFullURL(t *testing.T) { "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + - "/chat/completions?api-version=2023-03-15-preview", + "/chat/completions?api-version=2023-05-15", }, { 
"AzureBaseURLWithoutSlashOK", @@ -131,7 +131,7 @@ func TestAzureFullURL(t *testing.T) { "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + - "/chat/completions?api-version=2023-03-15-preview", + "/chat/completions?api-version=2023-05-15", }, } diff --git a/config.go b/config.go index fbcf377c0..c58b71ec6 100644 --- a/config.go +++ b/config.go @@ -56,7 +56,7 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, - APIVersion: "2023-03-15-preview", + APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") }, From faae8b4b4beab65eb0178646a3ee3a4656ddd5b1 Mon Sep 17 00:00:00 2001 From: Tom Hennessy Date: Mon, 22 May 2023 05:17:16 +0100 Subject: [PATCH 005/206] Update README.md (#319) Added in `unofficial` to the README to make it clear it's not official. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 994f26dd1..caab5225a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) [![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) -This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: +This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: * ChatGPT * GPT-3, GPT-4 From a18c18d5e8381aae38b076afc136ba7207d283ef Mon Sep 17 00:00:00 2001 From: Rich Coggins Date: Mon, 22 May 2023 00:18:31 -0400 Subject: [PATCH 006/206] Update README.md with Azure OpenAI Embeddings example (#318) --- README.md | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/README.md b/README.md index caab5225a..8e76c1a3e 100644 --- a/README.md +++ b/README.md @@ -469,6 +469,54 @@ func main() { ``` +
+<details>
+<summary>Azure OpenAI Embeddings</summary>
+
+```go
+package main
+
+import (
+	"context"
+	"fmt"
+
+	openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+
+	config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint")
+	config.APIVersion = "2023-05-15" // optional update to latest API version
+
+	//If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
+	//config.AzureModelMapperFunc = func(model string) string {
+	//	azureModelMapping = map[string]string{
+	//		"gpt-3.5-turbo":"your gpt-3.5-turbo deployment name",
+	//	}
+	//	return azureModelMapping[model]
+	//}
+
+	input := "Text to vectorize"
+
+	client := openai.NewClientWithConfig(config)
+	resp, err := client.CreateEmbeddings(
+		context.Background(),
+		openai.EmbeddingRequest{
+			Input: []string{input},
+			Model: openai.AdaEmbeddingV2,
+		})
+
+	if err != nil {
+		fmt.Printf("CreateEmbeddings error: %v\n", err)
+		return
+	}
+
+	vectors := resp.Data[0].Embedding // []float32 with 1536 dimensions
+
+	fmt.Println(vectors[:10], "...", vectors[len(vectors)-10:])
+}
+```
+</details>
+
Error handling From 980504b47e043efdc464e6489ba583bd40540362 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sat, 27 May 2023 18:13:41 +0800 Subject: [PATCH 007/206] docs(readme): update format (#317) --- README.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 8e76c1a3e..7562694df 100644 --- a/README.md +++ b/README.md @@ -435,16 +435,15 @@ import ( ) func main() { - config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") - //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function - //config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ - // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", - // } - // return azureModelMapping[model] - //} - + // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function + // config.AzureModelMapperFunc = func(model string) string { + // azureModelMapping = map[string]string{ + // "gpt-3.5-turbo": "your gpt-3.5-turbo deployment name", + // } + // return azureModelMapping[model] + // } + client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), @@ -458,7 +457,6 @@ func main() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -466,6 +464,7 @@ func main() { fmt.Println(resp.Choices[0].Message.Content) } + ```
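[Editor's note] Patches 001–007 above replace the single fixed Azure `Engine` setting with per-request model-to-deployment resolution: `fullURL` now derives the deployment name from the request's model via `GetAzureDeploymentByModel`, and `DefaultAzureConfig` installs a default `AzureModelMapperFunc` that strips "." and ":" from the model name (characters Azure deployment names typically omit). A minimal standalone sketch of that default mapping follows; `defaultAzureMapper` is an illustrative name for this note, not a symbol exported by the library.

```go
package main

import (
	"fmt"
	"regexp"
)

// defaultAzureMapper mirrors the default AzureModelMapperFunc installed by
// DefaultAzureConfig in patch 001: any "." or ":" in the model name is removed.
func defaultAzureMapper(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}

func main() {
	fmt.Println(defaultAzureMapper("gpt-3.5-turbo"))           // gpt-35-turbo
	fmt.Println(defaultAzureMapper("text-embedding-ada-002"))  // unchanged: text-embedding-ada-002
}
```

Because the mapper is a function rather than a fixed table, callers can layer any lookup they like (the README example above wraps a map), while `GetAzureDeploymentByModel` falls back to the raw model name when no mapper is set.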
From 62eb4beed29f4e821a59088f043fb56761d79fa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sun, 28 May 2023 10:51:07 +0900 Subject: [PATCH 008/206] move marshaller and unmarshaler into internal pkg (#304) (#325) --- chat_stream.go | 4 +++- error_accumulator.go | 8 +++++--- error_accumulator_test.go | 7 ++++--- files_test.go | 4 ++-- internal/marshaller.go | 15 +++++++++++++++ internal/unmarshaler.go | 15 +++++++++++++++ marshaller.go | 15 --------------- request_builder.go | 8 +++++--- request_builder_test.go | 2 +- stream.go | 4 +++- stream_reader.go | 6 ++++-- unmarshaler.go | 15 --------------- 12 files changed, 57 insertions(+), 46 deletions(-) create mode 100644 internal/marshaller.go create mode 100644 internal/unmarshaler.go delete mode 100644 marshaller.go delete mode 100644 unmarshaler.go diff --git a/chat_stream.go b/chat_stream.go index f4fda882a..842835e15 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -4,6 +4,8 @@ import ( "bufio" "context" "net/http" + + utils "github.com/sashabaranov/go-openai/internal" ) type ChatCompletionStreamChoiceDelta struct { @@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream( reader: bufio.NewReader(resp.Body), response: resp, errAccumulator: newErrorAccumulator(), - unmarshaler: &jsonUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, }, } return diff --git a/error_accumulator.go b/error_accumulator.go index ca6cec6e3..568afdbcd 100644 --- a/error_accumulator.go +++ b/error_accumulator.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "io" + + utils "github.com/sashabaranov/go-openai/internal" ) type errorAccumulator interface { @@ -19,13 +21,13 @@ type errorBuffer interface { type defaultErrorAccumulator struct { buffer errorBuffer - unmarshaler unmarshaler + unmarshaler utils.Unmarshaler } func newErrorAccumulator() errorAccumulator { return &defaultErrorAccumulator{ buffer: &bytes.Buffer{}, - unmarshaler: &jsonUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, } } @@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) { return } - err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp) + err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp) if err != nil { errResp = nil } diff --git a/error_accumulator_test.go b/error_accumulator_test.go index ecf954d58..821eb21b4 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -7,6 +7,7 @@ import ( "net/http" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte { return []byte{} } -func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error { +func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { return errTestUnmarshalerFailed } @@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { func TestErrorByteWriteErrors(t *testing.T) { accumulator := &defaultErrorAccumulator{ buffer: &failingErrorBuffer{}, - unmarshaler: &jsonUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, } err := accumulator.write([]byte("{")) if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { @@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { stream.errAccumulator = &defaultErrorAccumulator{ buffer: &failingErrorBuffer{}, - unmarshaler: &jsonUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, } _, err = stream.Recv() diff --git 
a/files_test.go b/files_test.go index 56dbb414f..ffdcfa798 100644 --- a/files_test.go +++ b/files_test.go @@ -1,7 +1,7 @@ package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai/internal" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { config.BaseURL = "" client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) FormBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } diff --git a/internal/marshaller.go b/internal/marshaller.go new file mode 100644 index 000000000..223a4dc1c --- /dev/null +++ b/internal/marshaller.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type Marshaller interface { + Marshal(value any) ([]byte, error) +} + +type JSONMarshaller struct{} + +func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { + return json.Marshal(value) +} diff --git a/internal/unmarshaler.go b/internal/unmarshaler.go new file mode 100644 index 000000000..882876022 --- /dev/null +++ b/internal/unmarshaler.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type Unmarshaler interface { + Unmarshal(data []byte, v any) error +} + +type JSONUnmarshaler struct{} + +func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} diff --git a/marshaller.go b/marshaller.go deleted file mode 100644 index 308ccd154..000000000 --- a/marshaller.go +++ /dev/null @@ -1,15 +0,0 @@ -package openai - -import ( - "encoding/json" -) - -type marshaller interface { - marshal(value any) ([]byte, error) -} - -type jsonMarshaller struct{} - -func (jm *jsonMarshaller) marshal(value any) ([]byte, error) { - return json.Marshal(value) -} diff --git a/request_builder.go b/request_builder.go index f0cef10fe..b4db07c2f 100644 --- a/request_builder.go +++ b/request_builder.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "net/http" + + utils "github.com/sashabaranov/go-openai/internal" ) type requestBuilder interface { @@ -11,12 +13,12 @@ type requestBuilder interface { } type httpRequestBuilder struct { - marshaller marshaller + marshaller utils.Marshaller } func newRequestBuilder() *httpRequestBuilder { return &httpRequestBuilder{ - marshaller: &jsonMarshaller{}, + marshaller: &utils.JSONMarshaller{}, } } @@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ } var reqBytes []byte - reqBytes, err := b.marshaller.marshal(request) + reqBytes, err := b.marshaller.Marshal(request) if err != nil { return nil, err } diff --git a/request_builder_test.go b/request_builder_test.go index b1adbf1c6..ed4b69113 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -19,7 +19,7 @@ type ( failingMarshaller struct{} ) -func (*failingMarshaller) marshal(_ any) ([]byte, error) { +func (*failingMarshaller) Marshal(_ any) ([]byte, error) { return []byte{}, errTestMarshallerFailed } diff --git a/stream.go b/stream.go index cd435faea..b9e784acf 100644 --- a/stream.go +++ b/stream.go @@ -5,6 +5,8 @@ import ( "context" "errors" "net/http" + + utils "github.com/sashabaranov/go-openai/internal" ) var ( @@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream( reader: bufio.NewReader(resp.Body), response: resp, errAccumulator: newErrorAccumulator(), - unmarshaler: 
&jsonUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, }, } return diff --git a/stream_reader.go b/stream_reader.go index aa06f00ae..5eb6df7b8 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "net/http" + + utils "github.com/sashabaranov/go-openai/internal" ) type streamable interface { @@ -19,7 +21,7 @@ type streamReader[T streamable] struct { reader *bufio.Reader response *http.Response errAccumulator errorAccumulator - unmarshaler unmarshaler + unmarshaler utils.Unmarshaler } func (stream *streamReader[T]) Recv() (response T, err error) { @@ -63,7 +65,7 @@ waitForData: return } - err = stream.unmarshaler.unmarshal(line, &response) + err = stream.unmarshaler.Unmarshal(line, &response) return } diff --git a/unmarshaler.go b/unmarshaler.go deleted file mode 100644 index 05218f764..000000000 --- a/unmarshaler.go +++ /dev/null @@ -1,15 +0,0 @@ -package openai - -import ( - "encoding/json" -) - -type unmarshaler interface { - unmarshal(data []byte, v any) error -} - -type jsonUnmarshaler struct{} - -func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error { - return json.Unmarshal(data, v) -} From 61ba5f33698020982b4be97bf98fe654c043bea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 31 May 2023 17:01:42 +0900 Subject: [PATCH 009/206] move request_builder into internal pkg (#304) (#329) * move request_builder into internal pkg (#304) * add some test for internal.RequestBuilder * add a test for openai.GetEngine --- chat.go | 2 +- client.go | 6 +- client_test.go | 149 +++++++++++++++ completion.go | 2 +- edits.go | 2 +- embeddings.go | 2 +- engines.go | 4 +- engines_test.go | 34 ++++ files.go | 6 +- fine_tunes.go | 12 +- image.go | 2 +- .../request_builder.go | 18 +- internal/request_builder_test.go | 61 ++++++ models.go | 2 +- moderation.go | 2 +- request_builder_test.go | 177 ------------------ 16 files changed, 273 insertions(+), 208 deletions(-) create mode 100644 engines_test.go rename request_builder.go => internal/request_builder.go (52%) create mode 100644 internal/request_builder_test.go delete mode 100644 request_builder_test.go diff --git a/chat.go b/chat.go index 312ef8e20..a7ce5486a 100644 --- a/chat.go +++ b/chat.go @@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/client.go b/client.go index c55166aa6..2486e36b6 100644 --- a/client.go +++ b/client.go @@ -15,7 +15,7 @@ import ( type Client struct { config ClientConfig - requestBuilder requestBuilder + requestBuilder utils.RequestBuilder createFormBuilder func(io.Writer) utils.FormBuilder } @@ -29,7 +29,7 @@ func NewClient(authToken string) *Client { func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, - requestBuilder: newRequestBuilder(), + requestBuilder: utils.NewRequestBuilder(), createFormBuilder: func(body io.Writer) utils.FormBuilder { return utils.NewFormBuilder(body) }, @@ -135,7 +135,7 @@ func (c *Client) newStreamRequest( urlSuffix string, body any, model string) (*http.Request, error) { - req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body) + req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body) if err != nil { return nil, err } diff 
--git a/client_test.go b/client_test.go index e30fa399b..862cbe856 100644 --- a/client_test.go +++ b/client_test.go @@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" "net/http" "testing" + + "github.com/sashabaranov/go-openai/internal/test" ) +var errTestRequestBuilderFailed = errors.New("test request builder failed") + +type failingRequestBuilder struct{} + +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) { + return nil, errTestRequestBuilderFailed +} + func TestClient(t *testing.T) { const mockToken = "mock token" client := NewClient(mockToken) @@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) { }) } } + +func TestClientReturnsRequestBuilderErrors(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTunes(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CancelFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.DeleteFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Moderations(ctx, ModerationRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Edits(ctx, EditsRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateImage(ctx, ImageRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + err = 
client.DeleteFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFiles(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListEngines(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetEngine(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListModels(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} + +func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} diff --git a/completion.go b/completion.go index e3d1b85eb..de1360fd9 100644 --- a/completion.go +++ b/completion.go @@ -155,7 +155,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) if err != nil { return } diff --git a/edits.go b/edits.go index c2c8db794..23b1a64f0 100644 --- a/edits.go +++ b/edits.go @@ -32,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 7fb432ead..942f3ea3a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. 
// https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) if err != nil { return } diff --git a/engines.go b/engines.go index bb6a66ce4..ac01a00ed 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/engines_test.go b/engines_test.go new file mode 100644 index 000000000..dfa3187cf --- /dev/null +++ b/engines_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. +func TestGetEngine(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(Engine{}) + fmt.Fprintln(w, string(resBytes)) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetEngine(ctx, "text-davinci-003") + checks.NoError(t, err, "GetEngine error") +} diff --git a/files.go b/files.go index 5667ec861..36c024365 100644 --- a/files.go +++ b/files.go @@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) if err != nil { return } @@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. 
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil) if err != nil { return } @@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/fine_tunes.go b/fine_tunes.go index a1218670f..069ddccfd 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) if err != nil { return } @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) if err != nil { return } @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) if err != nil { return } @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) if err != nil { return } diff --git a/image.go b/image.go index 
87ffea25e..df7363865 100644 --- a/image.go +++ b/image.go @@ -44,7 +44,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/request_builder.go b/internal/request_builder.go similarity index 52% rename from request_builder.go rename to internal/request_builder.go index b4db07c2f..0a9eabfde 100644 --- a/request_builder.go +++ b/internal/request_builder.go @@ -4,25 +4,23 @@ import ( "bytes" "context" "net/http" - - utils "github.com/sashabaranov/go-openai/internal" ) -type requestBuilder interface { - build(ctx context.Context, method, url string, request any) (*http.Request, error) +type RequestBuilder interface { + Build(ctx context.Context, method, url string, request any) (*http.Request, error) } -type httpRequestBuilder struct { - marshaller utils.Marshaller +type HTTPRequestBuilder struct { + marshaller Marshaller } -func newRequestBuilder() *httpRequestBuilder { - return &httpRequestBuilder{ - marshaller: &utils.JSONMarshaller{}, +func NewRequestBuilder() *HTTPRequestBuilder { + return &HTTPRequestBuilder{ + marshaller: &JSONMarshaller{}, } } -func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { +func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) { if request == nil { return http.NewRequestWithContext(ctx, method, url, nil) } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go new file mode 100644 index 000000000..e47d0f6ca --- /dev/null +++ b/internal/request_builder_test.go @@ -0,0 +1,61 @@ +package openai //nolint:testpackage // testing private field + +import ( + "bytes" + "context" + "errors" + "net/http" + "reflect" + "testing" +) + +var errTestMarshallerFailed = errors.New("test marshaller failed") + +type failingMarshaller struct{} + +func (*failingMarshaller) Marshal(_ any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { + builder := HTTPRequestBuilder{ + marshaller: &failingMarshaller{}, + } + + _, err := builder.Build(context.Background(), "", "", struct{}{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} + +func TestRequestBuilderReturnsRequest(t *testing.T) { + b := NewRequestBuilder() + var ( + ctx = context.Background() + method = http.MethodPost + url = "/foo" + request = map[string]string{"foo": "bar"} + reqBytes, _ = b.marshaller.Marshal(request) + want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) + ) + got, _ := b.Build(ctx, method, url, request) + if !reflect.DeepEqual(got.Body, want.Body) || + !reflect.DeepEqual(got.URL, want.URL) || + !reflect.DeepEqual(got.Method, want.Method) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} + +func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { + var ( + ctx = context.Background() + method = http.MethodGet + url = "/foo" + want, _ = http.NewRequestWithContext(ctx, method, url, nil) + ) + b := NewRequestBuilder() + got, _ := 
b.Build(ctx, method, url, nil) + if !reflect.DeepEqual(got, want) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} diff --git a/models.go b/models.go index 2be91aadb..485433b26 100644 --- a/models.go +++ b/models.go @@ -40,7 +40,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil) if err != nil { return } diff --git a/moderation.go b/moderation.go index ebd66afb9..bae788035 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) + req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) if err != nil { return } diff --git a/request_builder_test.go b/request_builder_test.go deleted file mode 100644 index ed4b69113..000000000 --- a/request_builder_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "net/http" - "testing" -) - -var ( - errTestMarshallerFailed = errors.New("test marshaller failed") - errTestRequestBuilderFailed = errors.New("test request builder failed") -) - -type ( - failingRequestBuilder struct{} - failingMarshaller struct{} -) - -func (*failingMarshaller) Marshal(_ any) ([]byte, error) { - return []byte{}, errTestMarshallerFailed -} - -func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) { - return nil, errTestRequestBuilderFailed -} - -func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { - builder := httpRequestBuilder{ - marshaller: &failingMarshaller{}, - } - - _, err := builder.build(context.Background(), "", "", struct{}{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } -} - -func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = 
client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} - -func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = 
client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} From fa694c61c2196e471e7bedc689b546aa52caf527 Mon Sep 17 00:00:00 2001 From: Mariano Darc Date: Mon, 5 Jun 2023 08:07:13 +0200 Subject: [PATCH 010/206] Implement optional io.Reader in AudioRequest (#303) (#265) (#331) * Implement optional io.Reader in AudioRequest (#303) (#265) * Fix err shadowing * Add test to cover AudioRequest io.Reader usage * Add additional test cases to cover AudioRequest io.Reader usage * Add test to cover opening the file specified in an AudioRequest --- audio.go | 45 +++++++++++++++++++++------ audio_test.go | 66 ++++++++++++++++++++++++++++++++++++++-- image_test.go | 11 +++++-- internal/form_builder.go | 20 ++++++++++-- 4 files changed, 124 insertions(+), 18 deletions(-) diff --git a/audio.go b/audio.go index bf2365391..20e865f11 100644 --- a/audio.go +++ b/audio.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "os" @@ -27,8 +28,14 @@ const ( // AudioRequest represents a request structure for audio API. // ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { - Model string - FilePath string + Model string + + // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. + FilePath string + + // Reader is an optional io.Reader when you do not want to use an existing file. + Reader io.Reader + Prompt string // For translation, it should be in English Temperature float32 Language string // For translation, just do not use it. It seems "en" works, not confirmed... @@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { - f, err := os.Open(request.FilePath) - if err != nil { - return fmt.Errorf("opening audio file: %w", err) - } - defer f.Close() - - err = b.CreateFormFile("file", f) + err := createFileField(request, b) if err != nil { - return fmt.Errorf("creating form file: %w", err) + return err } err = b.WriteField("model", request.Model) @@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { // Close the multipart writer return b.Close() } + +// createFileField creates the "file" form field from either an existing file or by using the reader. 
+func createFileField(request AudioRequest, b utils.FormBuilder) error { + if request.Reader != nil { + err := b.CreateFormFileReader("file", request.Reader, request.FilePath) + if err != nil { + return fmt.Errorf("creating form using reader: %w", err) + } + return nil + } + + f, err := os.Open(request.FilePath) + if err != nil { + return fmt.Errorf("opening audio file: %w", err) + } + defer f.Close() + + err = b.CreateFormFile("file", f) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + return nil +} diff --git a/audio_test.go b/audio_test.go index daf51f28c..6452e2eb7 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" @@ -11,12 +12,10 @@ import ( "os" "path/filepath" "strings" + "testing" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" - - "context" - "testing" ) // TestAudio Tests the transcription and translation endpoints of the API using the mocked server. @@ -65,6 +64,16 @@ func TestAudio(t *testing.T) { _, err = tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") }) + + t.Run(tc.name+" (with reader)", func(t *testing.T) { + req := AudioRequest{ + FilePath: "fake.webm", + Reader: bytes.NewBuffer([]byte(`some webm binary data`)), + Model: "whisper-3", + } + _, err = tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) } } @@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") } } + +func TestCreateFileField(t *testing.T) { + t.Run("createFileField failing file", func(t *testing.T) { + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFile: func(string, *os.File) error { + return mockFailedErr + }, + } + + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails") + }) + + t.Run("createFileField failing reader", func(t *testing.T) { + req := AudioRequest{ + FilePath: "test.wav", + Reader: bytes.NewBuffer([]byte(`wav test contents`)), + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { + return mockFailedErr + }, + } + + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails") + }) + + t.Run("createFileField failing open", func(t *testing.T) { + req := AudioRequest{ + FilePath: "non_existing_file.wav", + } + + mockBuilder := &mockFormBuilder{} + + err := createFileField(req, mockBuilder) + checks.HasError(t, err, "createFileField using file should return error when open file fails") + }) +} diff --git a/image_test.go b/image_test.go index 5cf6a268d..ca9faed95 100644 --- a/image_test.go +++ b/image_test.go @@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { } type mockFormBuilder struct { - mockCreateFormFile func(string, *os.File) error - mockWriteField func(string, string) error - mockClose func() error + mockCreateFormFile func(string, *os.File) 
error + mockCreateFormFileReader func(string, io.Reader, string) error + mockWriteField func(string, string) error + mockClose func() error } func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { return fb.mockCreateFormFile(fieldname, file) } +func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + return fb.mockCreateFormFileReader(fieldname, r, filename) +} + func (fb *mockFormBuilder) WriteField(fieldname, value string) error { return fb.mockWriteField(fieldname, value) } diff --git a/internal/form_builder.go b/internal/form_builder.go index 359dd7e2a..2224fad45 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -1,13 +1,16 @@ package openai import ( + "fmt" "io" "mime/multipart" "os" + "path" ) type FormBuilder interface { CreateFormFile(fieldname string, file *os.File) error + CreateFormFileReader(fieldname string, r io.Reader, filename string) error WriteField(fieldname, value string) error Close() error FormDataContentType() string @@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder { } func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { - fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) + return fb.createFormFile(fieldname, file, file.Name()) +} + +func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + return fb.createFormFile(fieldname, r, path.Base(filename)) +} + +func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { + if filename == "" { + return fmt.Errorf("filename cannot be empty") + } + + fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) if err != nil { return err } - _, err = io.Copy(fieldWriter, file) + _, err = io.Copy(fieldWriter, r) if err != nil { return err } + return nil } From 1394329e44ef4174777acc3692b4af2a40a217c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 5 Jun 2023 23:35:46 +0900 Subject: [PATCH 011/206] move error_accumulator into internal pkg (#304) (#335) * move error_accumulator into internal pkg (#304) * move error_accumulator into internal pkg (#304) * add a test for ErrTooManyEmptyStreamMessages in stream_reader (#304) --- chat_stream.go | 2 +- chat_stream_test.go | 49 ++++++++++---- error_accumulator.go | 53 --------------- error_accumulator_test.go | 100 ----------------------------- internal/error_accumulator.go | 44 +++++++++++++ internal/error_accumulator_test.go | 41 ++++++++++++ internal/test/failer.go | 21 ++++++ internal/test/helpers.go | 24 +++++++ stream.go | 2 +- stream_reader.go | 20 +++++- stream_reader_test.go | 53 +++++++++++++++ stream_test.go | 41 +++--------- 12 files changed, 249 insertions(+), 201 deletions(-) delete mode 100644 error_accumulator.go delete mode 100644 error_accumulator_test.go create mode 100644 internal/error_accumulator.go create mode 100644 internal/error_accumulator_test.go create mode 100644 internal/test/failer.go create mode 100644 stream_reader_test.go diff --git a/chat_stream.go b/chat_stream.go index 842835e15..9378c7124 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: 
&utils.JSONUnmarshaler{}, }, } diff --git a/chat_stream_test.go b/chat_stream_test.go index afcb86d5e..77d373c6a 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,7 +1,7 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { + var err error + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", 200) + }) + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + + ctx := context.Background() + + stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + checks.NoError(t, err) + + stream.errAccumulator = &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + + _, err = stream.Recv() + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error()) +} + // Helper funcs. 
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/error_accumulator.go b/error_accumulator.go deleted file mode 100644 index 568afdbcd..000000000 --- a/error_accumulator.go +++ /dev/null @@ -1,53 +0,0 @@ -package openai - -import ( - "bytes" - "fmt" - "io" - - utils "github.com/sashabaranov/go-openai/internal" -) - -type errorAccumulator interface { - write(p []byte) error - unmarshalError() *ErrorResponse -} - -type errorBuffer interface { - io.Writer - Len() int - Bytes() []byte -} - -type defaultErrorAccumulator struct { - buffer errorBuffer - unmarshaler utils.Unmarshaler -} - -func newErrorAccumulator() errorAccumulator { - return &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } -} - -func (e *defaultErrorAccumulator) write(p []byte) error { - _, err := e.buffer.Write(p) - if err != nil { - return fmt.Errorf("error accumulator write error, %w", err) - } - return nil -} - -func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) { - if e.buffer.Len() == 0 { - return - } - - err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp) - if err != nil { - errResp = nil - } - - return -} diff --git a/error_accumulator_test.go b/error_accumulator_test.go deleted file mode 100644 index 821eb21b4..000000000 --- a/error_accumulator_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "bytes" - "context" - "errors" - "net/http" - "testing" - - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -var ( - errTestUnmarshalerFailed = errors.New("test unmarshaler failed") - errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") -) - -type ( - failingUnMarshaller struct{} - failingErrorBuffer struct{} -) - -func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) { - return 0, errTestErrorAccumulatorWriteFailed -} - -func (b *failingErrorBuffer) Len() int { - return 0 -} - -func (b *failingErrorBuffer) Bytes() []byte { - return []byte{} -} - -func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { - return errTestUnmarshalerFailed -} - -func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &failingUnMarshaller{}, - } - - respErr := accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil with empty buffer: %v", respErr) - } - - err := accumulator.write([]byte("{")) - if err != nil { - t.Fatalf("%+v", err) - } - - respErr = accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) - } -} - -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } - err := accumulator.write([]byte("{")) - if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } -} - -func TestErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer 
ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &utils.JSONUnmarshaler{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) -} diff --git a/internal/error_accumulator.go b/internal/error_accumulator.go new file mode 100644 index 000000000..3d3e805fe --- /dev/null +++ b/internal/error_accumulator.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "fmt" + "io" +) + +type ErrorAccumulator interface { + Write(p []byte) error + Bytes() []byte +} + +type errorBuffer interface { + io.Writer + Len() int + Bytes() []byte +} + +type DefaultErrorAccumulator struct { + Buffer errorBuffer +} + +func NewErrorAccumulator() ErrorAccumulator { + return &DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } +} + +func (e *DefaultErrorAccumulator) Write(p []byte) error { + _, err := e.Buffer.Write(p) + if err != nil { + return fmt.Errorf("error accumulator write error, %w", err) + } + return nil +} + +func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { + if e.Buffer.Len() == 0 { + return + } + errBytes = e.Buffer.Bytes() + return +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go new file mode 100644 index 000000000..d48f28177 --- /dev/null +++ b/internal/error_accumulator_test.go @@ -0,0 +1,41 @@ +package openai_test + +import ( + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" +) + +func TestErrorAccumulatorBytes(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } + + errBytes := accumulator.Bytes() + if len(errBytes) != 0 { + t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) + } + + err := accumulator.Write([]byte("{}")) + if err != nil { + t.Fatalf("%+v", err) + } + + errBytes = accumulator.Bytes() + if len(errBytes) == 0 { + t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + } +} + +func TestErrorByteWriteErrors(t *testing.T) { + accumulator := &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + } + err := accumulator.Write([]byte("{")) + if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("Did not return error when write failed: %v", err) + } +} diff --git a/internal/test/failer.go b/internal/test/failer.go new file mode 100644 index 000000000..10ca64e34 --- /dev/null +++ b/internal/test/failer.go @@ -0,0 +1,21 @@ +package test + +import "errors" + +var ( + ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") +) + +type FailingErrorBuffer struct{} + +func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) { + return 0, ErrTestErrorAccumulatorWriteFailed +} + +func (b *FailingErrorBuffer) Len() int { + return 0 +} + +func (b *FailingErrorBuffer) Bytes() []byte { + return []byte{} +} diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 8461e5374..0e63ae82f 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -3,6 +3,7 @@ package test import ( "github.com/sashabaranov/go-openai/internal/test/checks" + "net/http" "os" "testing" ) @@ -27,3 
+28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { return path, func() { os.RemoveAll(path) } } + +// TokenRoundTripper is a struct that implements the RoundTripper +// interface, specifically to handle the authentication token by adding a token +// to the request header. We need this because the API requires that each +// request include a valid API token in the headers for authentication and +// authorization. +type TokenRoundTripper struct { + Token string + Fallback http.RoundTripper +} + +// RoundTrip takes an *http.Request as input and returns an +// *http.Response and an error. +// +// It is expected to use the provided request to create a connection to an HTTP +// server and return the response, or an error if one occurred. The returned +// Response should have its Body closed. If the RoundTrip method returns an +// error, the Client's Get, Head, Post, and PostForm methods return the same +// error. +func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.Token) + return t.Fallback.RoundTrip(req) +} diff --git a/stream.go b/stream.go index b9e784acf..d4e352314 100644 --- a/stream.go +++ b/stream.go @@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream( emptyMessagesLimit: c.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, - errAccumulator: newErrorAccumulator(), + errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, }, } diff --git a/stream_reader.go b/stream_reader.go index 5eb6df7b8..a9940b0ae 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -20,7 +20,7 @@ type streamReader[T streamable] struct { reader *bufio.Reader response *http.Response - errAccumulator errorAccumulator + errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler } @@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) { waitForData: line, err := stream.reader.ReadBytes('\n') if err != nil { - respErr := stream.errAccumulator.unmarshalError() + respErr := stream.unmarshalError() if respErr != nil { err = fmt.Errorf("error, %w", respErr.Error) } @@ -45,7 +45,7 @@ waitForData: var headerData = []byte("data: ") line = bytes.TrimSpace(line) if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.write(line); writeErr != nil { + if writeErr := stream.errAccumulator.Write(line); writeErr != nil { err = writeErr return } @@ -69,6 +69,20 @@ waitForData: return } +func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { + errBytes := stream.errAccumulator.Bytes() + if len(errBytes) == 0 { + return + } + + err := stream.unmarshaler.Unmarshal(errBytes, &errResp) + if err != nil { + errResp = nil + } + + return +} + func (stream *streamReader[T]) Close() { stream.response.Body.Close() } diff --git a/stream_reader_test.go b/stream_reader_test.go new file mode 100644 index 000000000..0e45c0b73 --- /dev/null +++ b/stream_reader_test.go @@ -0,0 +1,53 @@ +package openai //nolint:testpackage // testing private field + +import ( + "bufio" + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" +) + +var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") + +type failingUnMarshaller struct{} + +func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { + return errTestUnmarshalerFailed +} + +func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + errAccumulator: 
utils.NewErrorAccumulator(), + unmarshaler: &failingUnMarshaller{}, + } + + respErr := stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil with empty buffer: %v", respErr) + } + + err := stream.errAccumulator.Write([]byte("{")) + if err != nil { + t.Fatalf("%+v", err) + } + + respErr = stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + } +} + +func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + emptyMessagesLimit: 3, + reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + } + _, err := stream.Recv() + if !errors.Is(err, ErrTooManyEmptyStreamMessages) { + t.Fatalf("Did not return error when recv failed: %v", err) + } +} diff --git a/stream_test.go b/stream_test.go index a5c591fde..589fc9e26 100644 --- a/stream_test.go +++ b/stream_test.go @@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { // Client portion of the test config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, } client := NewClientWithConfig(config) @@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } -// A "tokenRoundTripper" is a struct that implements the RoundTripper -// interface, specifically to handle the authentication token by adding a token -// to the request header. We need this because the API requires that each -// request include a valid API token in the headers for authentication and -// authorization. -type tokenRoundTripper struct { - token string - fallback http.RoundTripper -} - -// RoundTrip takes an *http.Request as input and returns an -// *http.Response and an error. -// -// It is expected to use the provided request to create a connection to an HTTP -// server and return the response, or an error if one occurred. The returned -// Response should have its Body closed. If the RoundTrip method returns an -// error, the Client's Get, Head, Post, and PostForm methods return the same -// error. -func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", "Bearer "+t.token) - return t.fallback.RoundTrip(req) -} - // Helper funcs. 
func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { From 6830e0040677bde88d1ed40992667cfabb017910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 5 Jun 2023 23:37:08 +0900 Subject: [PATCH 012/206] Support Retrieve model API (#340) (#341) * Support Retrieve model API (#340) * Test for GetModel error cases. (#340) * Reduce the cognitive complexity of TestClientReturnsRequestBuilderErrors (#340) --- client_test.go | 166 +++++++++++++++++++++---------------------------- models.go | 14 +++++ models_test.go | 41 ++++++++++++ 3 files changed, 127 insertions(+), 94 deletions(-) diff --git a/client_test.go b/client_test.go index 862cbe856..5e63539df 100644 --- a/client_test.go +++ b/client_test.go @@ -170,104 +170,82 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ctx := context.Background() - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + type TestCase struct { + Name string + TestFunc func() (any, error) } - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, 
errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + testCases := []TestCase{ + {"CreateCompletion", func() (any, error) { + return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) + }}, + {"CreateCompletionStream", func() (any, error) { + return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) + }}, + {"CreateChatCompletion", func() (any, error) { + return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateChatCompletionStream", func() (any, error) { + return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateFineTune", func() (any, error) { + return client.CreateFineTune(ctx, FineTuneRequest{}) + }}, + {"ListFineTunes", func() (any, error) { + return client.ListFineTunes(ctx) + }}, + {"CancelFineTune", func() (any, error) { + return client.CancelFineTune(ctx, "") + }}, + {"GetFineTune", func() (any, error) { + return client.GetFineTune(ctx, "") + }}, + {"DeleteFineTune", func() (any, error) { + return client.DeleteFineTune(ctx, "") + }}, + {"ListFineTuneEvents", func() (any, error) { + return client.ListFineTuneEvents(ctx, "") + }}, + {"Moderations", func() (any, error) { + return client.Moderations(ctx, ModerationRequest{}) + }}, + {"Edits", func() (any, error) { + return client.Edits(ctx, EditsRequest{}) + }}, + {"CreateEmbeddings", func() (any, error) { + return client.CreateEmbeddings(ctx, EmbeddingRequest{}) + }}, + {"CreateImage", func() (any, error) { + return client.CreateImage(ctx, ImageRequest{}) + }}, + {"DeleteFile", func() (any, error) { + return nil, client.DeleteFile(ctx, "") + }}, + {"GetFile", func() (any, error) { + return client.GetFile(ctx, "") + }}, + {"ListFiles", func() (any, error) { + return client.ListFiles(ctx) + }}, + {"ListEngines", func() (any, error) { + return client.ListEngines(ctx) + }}, + {"GetEngine", func() (any, error) { + return client.GetEngine(ctx, "") + }}, + {"ListModels", func() (any, error) { + return client.ListModels(ctx) + }}, + {"GetModel", func() (any, error) { + return client.GetModel(ctx, "text-davinci-003") + }}, } - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + for _, testCase := range testCases { + _, err = testCase.TestFunc() + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err) + } } } diff --git a/models.go b/models.go index 485433b26..b3d458366 100644 --- a/models.go +++ b/models.go @@ -2,6 +2,7 @@ package 
openai import ( "context" + "fmt" "net/http" ) @@ -48,3 +49,16 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) err = c.sendRequest(req, &models) return } + +// GetModel Retrieves a model instance, providing basic information about +// the model such as the owner and permissioning. +func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { + urlSuffix := fmt.Sprintf("/models/%s", modelID) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &model) + return +} diff --git a/models_test.go b/models_test.go index b017800b9..834c849c4 100644 --- a/models_test.go +++ b/models_test.go @@ -54,3 +54,44 @@ func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(ModelsList{}) fmt.Fprintln(w, string(resBytes)) } + +// TestGetModel Tests the retrieve model endpoint of the API using the mocked server. +func TestGetModel(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetModel(ctx, "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +func TestAzureGetModel(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config.BaseURL = ts.URL + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetModel(ctx, "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +// handleModelsEndpoint Handles the models endpoint by the test server. +func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(Model{}) + fmt.Fprintln(w, string(resBytes)) +} From b8c13e4c017ab031ede870da538c0b0e3cf4996b Mon Sep 17 00:00:00 2001 From: Yuki Bobier Koshimizu Date: Fri, 9 Jun 2023 00:31:25 +0900 Subject: [PATCH 013/206] Refactor streamReader: Replace goto Statement with Loop in Recv Method (#339) * test: Add tests for improved coverage before refactoring This commit adds tests to improve coverage before refactoring to ensure that the changes do not break anything. * refactor: replace goto statement with loop This commit introduces a refactor to improve the clarity of the control flow within the method. The goto statement can sometimes make the code hard to understand and maintain, hence this refactor aims to resolve that. * refactor: extract for-loop from Recv to another method This commit improves code readability and maintainability by making the Recv method simpler. 
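For reference, the caller-visible contract this refactor must preserve: callers
drive Recv in a loop until io.EOF, which is returned once the server sends
"data: [DONE]". A minimal sketch, not part of the diff below — the client and
request are assumed to be constructed elsewhere:

    package main

    import (
    	"context"
    	"errors"
    	"fmt"
    	"io"

    	openai "github.com/sashabaranov/go-openai"
    )

    // consume drains a completion stream until the server signals the end.
    func consume(ctx context.Context, client *openai.Client, req openai.CompletionRequest) error {
    	stream, err := client.CreateCompletionStream(ctx, req)
    	if err != nil {
    		return err
    	}
    	defer stream.Close()
    	for {
    		resp, err := stream.Recv()
    		if errors.Is(err, io.EOF) {
    			return nil // stream finished: "[DONE]" received
    		}
    		if err != nil {
    			return err // e.g. ErrTooManyEmptyStreamMessages
    		}
    		fmt.Println(resp.Choices[0].Text)
    	}
    }

Whether Recv is implemented with the old goto-based flow or the extracted
processLines loop, this caller-visible behaviour is identical.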
--- stream_reader.go | 65 ++++++++++--------- stream_test.go | 160 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 28 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index a9940b0ae..34161986e 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -30,43 +30,52 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } + response, err = stream.processLines() + return +} + +func (stream *streamReader[T]) processLines() (T, error) { var emptyMessagesCount uint -waitForData: - line, err := stream.reader.ReadBytes('\n') - if err != nil { - respErr := stream.unmarshalError() - if respErr != nil { - err = fmt.Errorf("error, %w", respErr.Error) + for { + rawLine, readErr := stream.reader.ReadBytes('\n') + if readErr != nil { + respErr := stream.unmarshalError() + if respErr != nil { + return *new(T), fmt.Errorf("error, %w", respErr.Error) + } + return *new(T), readErr } - return - } - var headerData = []byte("data: ") - line = bytes.TrimSpace(line) - if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.Write(line); writeErr != nil { - err = writeErr - return + var headerData = []byte("data: ") + noSpaceLine := bytes.TrimSpace(rawLine) + if !bytes.HasPrefix(noSpaceLine, headerData) { + writeErr := stream.errAccumulator.Write(noSpaceLine) + if writeErr != nil { + return *new(T), writeErr + } + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + return *new(T), ErrTooManyEmptyStreamMessages + } + + continue } - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - err = ErrTooManyEmptyStreamMessages - return + + noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + if string(noPrefixLine) == "[DONE]" { + stream.isFinished = true + return *new(T), io.EOF } - goto waitForData - } + var response T + unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) + if unmarshalErr != nil { + return *new(T), unmarshalErr + } - line = bytes.TrimPrefix(line, headerData) - if string(line) == "[DONE]" { - stream.isFinished = true - err = io.EOF - return + return response, nil } - - err = stream.unmarshaler.Unmarshal(line, &response) - return } func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { diff --git a/stream_test.go b/stream_test.go index 589fc9e26..0faa21222 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "net/http" @@ -217,6 +218,165 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Totally 301 empty messages (300 is the limit) + for i := 0; i < 299; i++ { + dataBytes = append(dataBytes, '\n') + } + + dataBytes = append(dataBytes, []byte("event: message\n")...) 
+ //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") + } +} + +func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Stream is terminated without sending "done" message + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF") + } +} + +func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Send broken json + dataBytes = append(dataBytes, []byte("event: message\n")...) 
+ data = `{"id":"2","object":"completion","created":1598069255,"model":` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + var syntaxError *json.SyntaxError + if !errors.As(streamErr, &syntaxError) { + t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { From 06b16a728172c7181131ee0dc17d912feb59cec1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Fri, 9 Jun 2023 00:32:03 +0900 Subject: [PATCH 014/206] fix json marshaling error response of azure openai (#343) (#345) * fix json marshaling error response of azure openai (#343) * add a test case for handleErrorResp func (#343) --- chat_stream_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 9 +++++++ error.go | 10 +++++--- 3 files changed, 77 insertions(+), 3 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 77d373c6a..19c2e3cd0 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -255,6 +255,67 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { + wantCode := "429" + wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + + "version 2023-03-15-preview have exceeded token rate limit of your current OpenAI S0 pricing tier. " + + "Please retry after 20 seconds. " + + "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit." 
+ + server := test.NewTestServer() + server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + // Send test responses + dataBytes := []byte(`{"error": { "code": "` + wantCode + `", "message": "` + wantMessage + `"}}`) + _, err := w.Write(dataBytes) + + checks.NoError(t, err, "Write error") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultAzureConfig(test.GetTestToken(), ts.URL) + client := NewClientWithConfig(config) + ctx := context.Background() + + request := ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + apiErr := &APIError{} + _, err = client.CreateChatCompletionStream(ctx, request) + if !errors.As(err, &apiErr) { + t.Errorf("Did not return APIError: %+v\n", apiErr) + return + } + if apiErr.HTTPStatusCode != http.StatusTooManyRequests { + t.Errorf("Did not return HTTPStatusCode got = %d, want = %d\n", apiErr.HTTPStatusCode, http.StatusTooManyRequests) + return + } + code, ok := apiErr.Code.(string) + if !ok || code != wantCode { + t.Errorf("Did not return Code. got = %v, want = %s\n", apiErr.Code, wantCode) + return + } + if apiErr.Message != wantMessage { + t.Errorf("Did not return Message. got = %s, want = %s\n", apiErr.Message, wantMessage) + return + } +} + func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { var err error server := test.NewTestServer() diff --git a/client_test.go b/client_test.go index 5e63539df..c96ceb7e6 100644 --- a/client_test.go +++ b/client_test.go @@ -134,6 +134,15 @@ func TestHandleErrorResp(t *testing.T) { }`)), expected: "error, status code: 503, message: That model...", }, + { + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + body: bytes.NewReader([]byte(` + { + "error":{} + }`)), + expected: "error, status code: 503, message: ", + }, } for _, tc := range testCases { diff --git a/error.go b/error.go index 6354f43b5..b789ed7d5 100644 --- a/error.go +++ b/error.go @@ -44,9 +44,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { return } - err = json.Unmarshal(rawMap["type"], &e.Type) - if err != nil { - return + // optional fields for azure openai + // refs: https://github.com/sashabaranov/go-openai/issues/343 + if _, ok := rawMap["type"]; ok { + err = json.Unmarshal(rawMap["type"], &e.Type) + if err != nil { + return + } } // optional fields From a243e7331f27b707f4942ddf39dac348a8aa4126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sun, 11 Jun 2023 17:49:57 +0900 Subject: [PATCH 015/206] Support Retrieve file content API (#347) (#348) * Support Retrieve file content API (#347) * add timeout test for GetFileContent (#347) --- chat_stream.go | 3 +- client.go | 43 +++++++------ client_test.go | 3 + files.go | 24 +++++++ files_test.go | 166 +++++++++++++++++++++++++++++++++++++++++++++++++ stream.go | 3 +- 6 files changed, 216 insertions(+), 26 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 9378c7124..625d436cb 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,7 +3,6 @@ package openai import ( "bufio" "context" - "net/http" utils "github.com/sashabaranov/go-openai/internal" ) @@ -57,7 +56,7 @@ 
func (c *Client) CreateChatCompletionStream( if err != nil { return } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(resp) { return nil, c.handleErrorResp(resp) } diff --git a/client.go b/client.go index 2486e36b6..f38c1dfc3 100644 --- a/client.go +++ b/client.go @@ -47,13 +47,6 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) - } // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -62,9 +55,7 @@ func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - if len(c.config.OrgID) > 0 { - req.Header.Set("OpenAI-Organization", c.config.OrgID) - } + c.setCommonHeaders(req) res, err := c.config.HTTPClient.Do(req) if err != nil { @@ -73,13 +64,31 @@ func (c *Client) sendRequest(req *http.Request, v any) error { defer res.Body.Close() - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(res) { return c.handleErrorResp(res) } return decodeResponse(res.Body, v) } +func (c *Client) setCommonHeaders(req *http.Request) { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } + if c.config.OrgID != "" { + req.Header.Set("OpenAI-Organization", c.config.OrgID) + } +} + +func isFailureStatusCode(resp *http.Response) bool { + return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest +} + func decodeResponse(body io.Reader, v any) error { if v == nil { return nil @@ -145,17 +154,7 @@ func (c *Client) newStreamRequest( req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) - } - if c.config.OrgID != "" { - req.Header.Set("OpenAI-Organization", c.config.OrgID) - } + c.setCommonHeaders(req) return req, nil } diff --git a/client_test.go b/client_test.go index c96ceb7e6..70ac81351 100644 --- a/client_test.go +++ b/client_test.go @@ -233,6 +233,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetFile", func() (any, error) { return client.GetFile(ctx, "") }}, + {"GetFileContent", func() (any, error) { + return client.GetFileContent(ctx, "") + }}, {"ListFiles", func() (any, error) { return client.ListFiles(ctx) }}, diff --git a/files.go b/files.go index 36c024365..fb9937bea 100644 --- a/files.go +++ b/files.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "os" ) @@ -103,3 +104,26 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file 
File, err err err = c.sendRequest(req, &file) return } + +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { + urlSuffix := fmt.Sprintf("/files/%s/content", fileID) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + c.setCommonHeaders(req) + + res, err := c.config.HTTPClient.Do(req) + if err != nil { + return + } + + if isFailureStatusCode(res) { + err = c.handleErrorResp(res) + return + } + + content = res.Body + return +} diff --git a/files_test.go b/files_test.go index ffdcfa798..8e8934935 100644 --- a/files_test.go +++ b/files_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -141,3 +142,168 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } + +func TestDeleteFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + err = client.DeleteFile(ctx, "deadbeef") + checks.NoError(t, err, "DeleteFile error") +} + +func TestListFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "{}") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") +} + +func TestGetFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "{}") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.GetFile(ctx, "deadbeef") + checks.NoError(t, err, "GetFile error") +} + +func TestGetFileContent(t *testing.T) { + wantRespJsonl := `{"prompt": "foo", "completion": "foo"} +{"prompt": "bar", "completion": "bar"} +{"prompt": "baz", "completion": "baz"} +` + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + // edits only accepts GET requests + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + fmt.Fprint(w, wantRespJsonl) + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + content, err := client.GetFileContent(ctx, "deadbeef") + checks.NoError(t, err, "GetFileContent error") + defer content.Close() + + actual, _ := io.ReadAll(content) + if string(actual) != wantRespJsonl { + t.Errorf("Expected 
%s, got %s", wantRespJsonl, string(actual)) + } +} + +func TestGetFileContentReturnError(t *testing.T) { + wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." + wantType := "invalid_request_error" + wantErrorResp := `{ + "error": { + "message": "` + wantMessage + `", + "type": "` + wantType + `", + "param": null, + "code": null + } +}` + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, wantErrorResp) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + + apiErr := &APIError{} + if !errors.As(err, &apiErr) { + t.Fatalf("Did not return APIError: %+v\n", apiErr) + } + if apiErr.Message != wantMessage { + t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) + return + } + if apiErr.Type != wantType { + t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) + return + } +} + +func TestGetFileContentReturnTimeoutError(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/stream.go b/stream.go index d4e352314..94cc0a0a2 100644 --- a/stream.go +++ b/stream.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "errors" - "net/http" utils "github.com/sashabaranov/go-openai/internal" ) @@ -46,7 +45,7 @@ func (c *Client) CreateCompletionStream( if err != nil { return } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(resp) { return nil, c.handleErrorResp(resp) } From b616090e699616e699d9839d1e5e3fd0cab7ef46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 12 Jun 2023 22:40:26 +0900 Subject: [PATCH 016/206] refactoring tests with mock servers (#30) (#356) --- api_test.go | 16 +-- audio_api_test.go | 162 +++++++++++++++++++++++++++ audio_test.go | 166 ---------------------------- chat_stream_test.go | 125 ++++----------------- chat_test.go | 20 +--- client_test.go | 20 +--- completion_test.go | 19 +--- edits_test.go | 17 +-- embeddings_test.go | 17 +-- engines_test.go | 28 ++--- files_api_test.go | 183 ++++++++++++++++++++++++++++++ files_test.go | 236 --------------------------------------- fine_tunes_test.go | 15 +-- image_api_test.go | 223 +++++++++++++++++++++++++++++++++++++ image_test.go | 252 ------------------------------------------ models_test.go | 77 ++++--------- moderation_test.go | 24 +--- openai_test.go | 28 +++++ stream_reader_test.go | 16 ++- stream_test.go | 147 ++++++------------------ 
20 files changed, 731 insertions(+), 1060 deletions(-) create mode 100644 audio_api_test.go create mode 100644 files_api_test.go create mode 100644 image_api_test.go create mode 100644 openai_test.go diff --git a/api_test.go b/api_test.go index 78fd5cc6d..083b67412 100644 --- a/api_test.go +++ b/api_test.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "os" "testing" @@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { } func TestRequestError(t *testing.T) { - var err error - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) - })) - defer ts.Close() + }) - config := DefaultConfig("dummy") - config.BaseURL = ts.URL - c := NewClientWithConfig(config) - ctx := context.Background() - _, err = c.ListEngines(ctx) + _, err := client.ListEngines(context.Background()) checks.HasError(t, err, "ListEngines did not fail") var reqErr *RequestError diff --git a/audio_api_test.go b/audio_api_test.go new file mode 100644 index 000000000..aad7a225a --- /dev/null +++ b/audio_api_test.go @@ -0,0 +1,162 @@ +package openai_test + +import ( + "bytes" + "context" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strings" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. +func TestAudio(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + + t.Run(tc.name+" (with reader)", func(t *testing.T) { + req := AudioRequest{ + FilePath: "fake.webm", + Reader: bytes.NewBuffer([]byte(`some webm binary data`)), + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +func TestAudioWithOptionalArgs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := 
filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + Prompt: "用简体中文", + Temperature: 0.5, + Language: "zh", + Format: AudioResponseFormatSRT, + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +// handleAudioEndpoint Handles the completion endpoint by the test server. +func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + + mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if !strings.HasPrefix(mediaType, "multipart") { + http.Error(w, "request is not multipart", http.StatusBadRequest) + } + + boundary, ok := params["boundary"] + if !ok { + http.Error(w, "no boundary in params", http.StatusBadRequest) + return + } + + fileData := &bytes.Buffer{} + mr := multipart.NewReader(r.Body, boundary) + part, err := mr.NextPart() + if err != nil && errors.Is(err, io.EOF) { + http.Error(w, "error accessing file", http.StatusBadRequest) + return + } + if _, err = io.Copy(fileData, part); err != nil { + http.Error(w, "failed to copy file", http.StatusInternalServerError) + return + } + + if len(fileData.Bytes()) == 0 { + w.WriteHeader(http.StatusInternalServerError) + http.Error(w, "received empty file data", http.StatusBadRequest) + return + } + + if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } +} diff --git a/audio_test.go b/audio_test.go index 6452e2eb7..e19a873f3 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field import ( "bytes" - "context" - "errors" "fmt" "io" - "mime" - "mime/multipart" - "net/http" "os" "path/filepath" - "strings" "testing" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) -// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. 
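Each refactored test now calls setupOpenAITestServer (or setupAzureTestServer), a helper added in the new openai_test.go that this excerpt does not reproduce. A plausible minimal sketch, assembled from the pieces the old tests used inline; the *test.ServerTest type name is inferred from how test.NewTestServer is used elsewhere in this patch:

package openai_test

import (
	. "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test"
)

// setupOpenAITestServer spins up the shared mock server and hands back a
// client pointed at it plus a teardown func, as the refactored tests expect.
func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) {
	server = test.NewTestServer()
	ts := server.OpenAITestServer()
	ts.Start()
	teardown = ts.Close

	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = ts.URL + "/v1"
	client = NewClientWithConfig(config)
	return
}

// setupAzureTestServer, used by the Azure test earlier, presumably does the
// same with DefaultAzureConfig(test.GetTestToken(), ts.URL) and no "/v1".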
-func TestAudio(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, - } - - ctx := context.Background() - - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - - t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ - FilePath: "fake.webm", - Reader: bytes.NewBuffer([]byte(`some webm binary data`)), - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - } -} - -func TestAudioWithOptionalArgs(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, - } - - ctx := context.Background() - - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - Prompt: "用简体中文", - Temperature: 0.5, - Language: "zh", - Format: AudioResponseFormatSRT, - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - } -} - -// handleAudioEndpoint Handles the completion endpoint by the test server. 
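As a usage-side counterpart to these mocked handlers, a hedged sketch of a real transcription call; the token and file path are placeholders, and Whisper1 is the library's production model constant, whereas the tests use a fake "whisper-3":

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder
	resp, err := client.CreateTranscription(context.Background(), openai.AudioRequest{
		FilePath: "speech.mp3", // placeholder file
		Model:    openai.Whisper1,
	})
	if err != nil {
		fmt.Println("transcription error:", err)
		return
	}
	fmt.Println(resp.Text)
}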
-func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - - // audio endpoints only accept POST requests - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } - - mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - http.Error(w, "failed to parse media type", http.StatusBadRequest) - return - } - - if !strings.HasPrefix(mediaType, "multipart") { - http.Error(w, "request is not multipart", http.StatusBadRequest) - } - - boundary, ok := params["boundary"] - if !ok { - http.Error(w, "no boundary in params", http.StatusBadRequest) - return - } - - fileData := &bytes.Buffer{} - mr := multipart.NewReader(r.Body, boundary) - part, err := mr.NextPart() - if err != nil && errors.Is(err, io.EOF) { - http.Error(w, "error accessing file", http.StatusBadRequest) - return - } - if _, err = io.Copy(fileData, part); err != nil { - http.Error(w, "failed to copy file", http.StatusInternalServerError) - return - } - - if len(fileData.Bytes()) == 0 { - w.WriteHeader(http.StatusInternalServerError) - http.Error(w, "received empty file data", http.StatusBadRequest) - return - } - - if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { - http.Error(w, "failed to write body", http.StatusInternalServerError) - return - } -} - func TestAudioWithFailingFormBuilder(t *testing.T) { dir, cleanup := test.CreateTestDirectory(t) defer cleanup() diff --git a/chat_stream_test.go b/chat_stream_test.go index 19c2e3cd0..c3cb9f3f7 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,8 +1,7 @@ -package openai //nolint:testpackage // testing private field +package openai_test import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + . 
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -10,7 +9,6 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "testing" ) @@ -37,7 +35,9 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } func TestCreateChatCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -57,21 +57,9 @@ func TestCreateChatCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -81,9 +69,7 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -143,7 +129,9 @@ func TestCreateChatCompletionStream(t *testing.T) { } func TestCreateChatCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -164,21 +152,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -188,9 +164,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { }, }, Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -205,7 +179,8 @@ func TestCreateChatCompletionStreamError(t *testing.T) { } func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -220,22 +195,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := 
w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - - request := ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -245,10 +205,8 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { }, }, Stream: true, - } - + }) var apiErr *APIError - _, err := client.CreateChatCompletionStream(ctx, request) if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -262,7 +220,8 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { "Please retry after 20 seconds. " + "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit." - server := test.NewTestServer() + client, server, teardown := setupAzureTestServer() + defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -273,17 +232,9 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), ts.URL) - client := NewClientWithConfig(config) - ctx := context.Background() - request := ChatCompletionRequest{ + apiErr := &APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -293,10 +244,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { }, }, Stream: true, - } - - apiErr := &APIError{} - _, err = client.CreateChatCompletionStream(ctx, request) + }) if !errors.As(err, &apiErr) { t.Errorf("Did not return APIError: %+v\n", apiErr) return @@ -316,33 +264,6 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } -func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error()) -} - // Helper funcs. 
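Read alongside these tests, the client-side consumption loop they exercise; a hedged standalone sketch, with the token and prompt as placeholders:

package main

import (
	"context"
	"errors"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Stream:   true,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
	})
	if err != nil {
		fmt.Println("stream setup error:", err)
		return
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return // server sent the [DONE] message
		}
		if err != nil {
			// e.g. ErrTooManyEmptyStreamMessages or a JSON syntax error,
			// as asserted by the tests above.
			fmt.Println("stream error:", err)
			return
		}
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}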
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/chat_test.go b/chat_test.go index ce302a69f..ebe29f9eb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -52,20 +51,10 @@ func TestChatCompletionsWithStream(t *testing.T) { // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo, Messages: []ChatCompletionMessage{ @@ -74,8 +63,7 @@ func TestChatCompletions(t *testing.T) { Content: "Hello!", }, }, - } - _, err = client.CreateChatCompletion(ctx, req) + }) checks.NoError(t, err, "CreateChatCompletion error") } diff --git a/client_test.go b/client_test.go index 70ac81351..00b66feae 100644 --- a/client_test.go +++ b/client_test.go @@ -167,16 +167,9 @@ func TestHandleErrorResp(t *testing.T) { } func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} - ctx := context.Background() type TestCase struct { @@ -254,7 +247,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } for _, testCase := range testCases { - _, err = testCase.TestFunc() + _, err := testCase.TestFunc() if !errors.Is(err, errTestRequestBuilderFailed) { t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err) } @@ -262,23 +255,14 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + _, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { t.Fatalf("Did not return error when request builder failed: %v", err) } - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { t.Fatalf("Did not return error when request builder failed: %v", err) diff --git a/completion_test.go b/completion_test.go index 2e302591a..aeddcfca1 100644 --- a/completion_test.go +++ b/completion_test.go @@ -2,7 +2,6 @@ package openai_test import ( . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -48,25 +47,15 @@ func TestCompletionWithStream(t *testing.T) { // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - req := CompletionRequest{ MaxTokens: 5, Model: "ada", + Prompt: "Lorem ipsum", } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) + _, err := client.CreateCompletion(context.Background(), req) checks.NoError(t, err, "CreateCompletion error") } diff --git a/edits_test.go b/edits_test.go index fa6c12825..c0bb84392 100644 --- a/edits_test.go +++ b/edits_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -16,19 +15,9 @@ import ( // TestEdits Tests the edits endpoint of the API using the mocked server. func TestEdits(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/edits", handleEditEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - // create an edit request model := "ada" editReq := EditsRequest{ @@ -40,7 +29,7 @@ func TestEdits(t *testing.T) { Instruction: "test instruction", N: 3, } - response, err := client.Edits(ctx, editReq) + response, err := client.Edits(context.Background(), editReq) checks.NoError(t, err, "Edits error") if len(response.Choices) != editReq.N { t.Fatalf("edits does not properly return the correct number of choices") diff --git a/embeddings_test.go b/embeddings_test.go index 252f7a5a0..d7892cd5d 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -2,7 +2,6 @@ package openai_test import ( . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -67,7 +66,8 @@ func TestEmbeddingModel(t *testing.T) { } func TestEmbeddingEndpoint(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { @@ -75,17 +75,6 @@ func TestEmbeddingEndpoint(t *testing.T) { fmt.Fprintln(w, string(resBytes)) }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") } diff --git a/engines_test.go b/engines_test.go index dfa3187cf..2beb333b3 100644 --- a/engines_test.go +++ b/engines_test.go @@ -8,27 +8,29 @@ import ( "testing" . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. func TestGetEngine(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { resBytes, _ := json.Marshal(Engine{}) fmt.Fprintln(w, string(resBytes)) }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetEngine(ctx, "text-davinci-003") + _, err := client.GetEngine(context.Background(), "text-davinci-003") checks.NoError(t, err, "GetEngine error") } + +// TestListEngines Tests the list engines endpoint of the API using the mocked server. +func TestListEngines(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(EnginesList{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.ListEngines(context.Background()) + checks.NoError(t, err, "ListEngines error") +} diff --git a/files_api_test.go b/files_api_test.go new file mode 100644 index 000000000..f0a08764d --- /dev/null +++ b/files_api_test.go @@ -0,0 +1,183 @@ +package openai_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strconv" + "testing" + "time" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestFileUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := FileRequest{ + FileName: "test.go", + FilePath: "client.go", + Purpose: "fine-tune", + } + _, err := client.CreateFile(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + +// handleCreateFile Handles the images endpoint by the test server. 
+func handleCreateFile(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// the files endpoint only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+	err = r.ParseMultipartForm(1024 * 1024 * 1024)
+	if err != nil {
+		http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
+		return
+	}
+
+	values := r.Form
+	var purpose string
+	for key, value := range values {
+		if key == "purpose" {
+			purpose = value[0]
+		}
+	}
+	file, header, err := r.FormFile("file")
+	if err != nil {
+		return
+	}
+	defer file.Close()
+
+	var fileReq = File{
+		Bytes:     int(header.Size),
+		ID:        strconv.Itoa(int(time.Now().Unix())),
+		FileName:  header.Filename,
+		Purpose:   purpose,
+		CreatedAt: time.Now().Unix(),
+		Object:    "test-object",
+		Owner:     "test-owner",
+	}
+
+	resBytes, _ = json.Marshal(fileReq)
+	fmt.Fprint(w, string(resBytes))
+}
+
+func TestDeleteFile(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {})
+	err := client.DeleteFile(context.Background(), "deadbeef")
+	checks.NoError(t, err, "DeleteFile error")
+}
+
+func TestListFile(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
+		resBytes, _ := json.Marshal(FilesList{})
+		fmt.Fprintln(w, string(resBytes))
+	})
+	_, err := client.ListFiles(context.Background())
+	checks.NoError(t, err, "ListFiles error")
+}
+
+func TestGetFile(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
+		resBytes, _ := json.Marshal(File{})
+		fmt.Fprintln(w, string(resBytes))
+	})
+	_, err := client.GetFile(context.Background(), "deadbeef")
+	checks.NoError(t, err, "GetFile error")
+}
+
+func TestGetFileContent(t *testing.T) {
+	wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
+{"prompt": "bar", "completion": "bar"}
+{"prompt": "baz", "completion": "baz"}
+`
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
+		// the file content endpoint only accepts GET requests
+		if r.Method != http.MethodGet {
+			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		}
+		fmt.Fprint(w, wantRespJsonl)
+	})
+
+	content, err := client.GetFileContent(context.Background(), "deadbeef")
+	checks.NoError(t, err, "GetFileContent error")
+	defer content.Close()
+
+	actual, _ := io.ReadAll(content)
+	if string(actual) != wantRespJsonl {
+		t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
+	}
+}
+
+func TestGetFileContentReturnError(t *testing.T) {
+	wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
+ wantType := "invalid_request_error" + wantErrorResp := `{ + "error": { + "message": "` + wantMessage + `", + "type": "` + wantType + `", + "param": null, + "code": null + } +}` + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, wantErrorResp) + }) + + _, err := client.GetFileContent(context.Background(), "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + + apiErr := &APIError{} + if !errors.As(err, &apiErr) { + t.Fatalf("Did not return APIError: %+v\n", apiErr) + } + if apiErr.Message != wantMessage { + t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) + return + } + if apiErr.Type != wantType { + t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) + return + } +} + +func TestGetFileContentReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/files_test.go b/files_test.go index 8e8934935..df6eaef7b 100644 --- a/files_test.go +++ b/files_test.go @@ -2,86 +2,15 @@ package openai //nolint:testpackage // testing private field import ( utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" - "encoding/json" - "errors" "fmt" "io" - "net/http" "os" - "strconv" "testing" - "time" ) -func TestFileUpload(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", handleCreateFile) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := FileRequest{ - FileName: "test.go", - FilePath: "client.go", - Purpose: "fine-tune", - } - _, err = client.CreateFile(ctx, req) - checks.NoError(t, err, "CreateFile error") -} - -// handleCreateFile Handles the images endpoint by the test server. 
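The timeout test above maps onto ordinary context plumbing on the caller side; a hedged sketch, with the deadline and file ID as illustrative values:

package main

import (
	"context"
	"fmt"
	"os"
	"time"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // illustrative deadline
	defer cancel()

	content, err := client.GetFileContent(ctx, "file-abc123") // illustrative ID
	if os.IsTimeout(err) {
		fmt.Println("download timed out")
		return
	}
	if err != nil {
		fmt.Println("download error:", err)
		return
	}
	defer content.Close()
}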
-func handleCreateFile(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - err = r.ParseMultipartForm(1024 * 1024 * 1024) - if err != nil { - http.Error(w, "file is more than 1GB", http.StatusInternalServerError) - return - } - - values := r.Form - var purpose string - for key, value := range values { - if key == "purpose" { - purpose = value[0] - } - } - file, header, err := r.FormFile("file") - if err != nil { - return - } - defer file.Close() - - var fileReq = File{ - Bytes: int(header.Size), - ID: strconv.Itoa(int(time.Now().Unix())), - FileName: header.Filename, - Purpose: purpose, - CreatedAt: time.Now().Unix(), - Object: "test-objecct", - Owner: "test-owner", - } - - resBytes, _ = json.Marshal(fileReq) - fmt.Fprint(w, string(resBytes)) -} - func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -142,168 +71,3 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } - -func TestDeleteFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - err = client.DeleteFile(ctx, "deadbeef") - checks.NoError(t, err, "DeleteFile error") -} - -func TestListFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "{}") - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") -} - -func TestGetFile(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "{}") - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.GetFile(ctx, "deadbeef") - checks.NoError(t, err, "GetFile error") -} - -func TestGetFileContent(t *testing.T) { - wantRespJsonl := `{"prompt": "foo", "completion": "foo"} -{"prompt": "bar", "completion": "bar"} -{"prompt": "baz", "completion": "baz"} -` - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - // edits only accepts GET requests - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - fmt.Fprint(w, wantRespJsonl) - }) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := 
NewClientWithConfig(config) - ctx := context.Background() - - content, err := client.GetFileContent(ctx, "deadbeef") - checks.NoError(t, err, "GetFileContent error") - defer content.Close() - - actual, _ := io.ReadAll(content) - if string(actual) != wantRespJsonl { - t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) - } -} - -func TestGetFileContentReturnError(t *testing.T) { - wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." - wantType := "invalid_request_error" - wantErrorResp := `{ - "error": { - "message": "` + wantMessage + `", - "type": "` + wantType + `", - "param": null, - "code": null - } -}` - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, wantErrorResp) - }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err := client.GetFileContent(ctx, "deadbeef") - if err == nil { - t.Fatal("Did not return error") - } - - apiErr := &APIError{} - if !errors.As(err, &apiErr) { - t.Fatalf("Did not return APIError: %+v\n", apiErr) - } - if apiErr.Message != wantMessage { - t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) - return - } - if apiErr.Type != wantType { - t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) - return - } -} - -func TestGetFileContentReturnTimeoutError(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { - time.Sleep(10 * time.Nanosecond) - }) - // create the test server - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) - defer cancel() - - _, err := client.GetFileContent(ctx, "deadbeef") - if err == nil { - t.Fatal("Did not return error") - } - if !os.IsTimeout(err) { - t.Fatal("Did not return timeout error") - } -} diff --git a/fine_tunes_test.go b/fine_tunes_test.go index c60254993..67f681d97 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -16,7 +15,8 @@ const testFineTuneID = "fine-tune-id" // TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. 
func TestFineTunes(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler( "/v1/fine-tunes", func(w http.ResponseWriter, r *http.Request) { @@ -59,18 +59,9 @@ func TestFineTunes(t *testing.T) { }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) ctx := context.Background() - _, err = client.ListFineTunes(ctx) + _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") _, err = client.CreateFineTune(ctx, FineTuneRequest{}) diff --git a/image_api_test.go b/image_api_test.go new file mode 100644 index 000000000..b472eb04a --- /dev/null +++ b/image_api_test.go @@ -0,0 +1,223 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestImages(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/generations", handleImageEndpoint) + _, err := client.CreateImage(context.Background(), ImageRequest{ + Prompt: "Lorem ipsum", + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleImageEndpoint Handles the images endpoint by the test server. +func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var imageReq ImageRequest + if imageReq, err = getImageBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := ImageResponse{ + Created: time.Now().Unix(), + } + for i := 0; i < imageReq.N; i++ { + imageData := ImageResponseDataInner{} + switch imageReq.ResponseFormat { + case CreateImageResponseFormatURL, "": + imageData.URL = "https://example.com/image.png" + case CreateImageResponseFormatB64JSON: + // This decodes to "{}" in base64. + imageData.B64JSON = "e30K" + default: + http.Error(w, "invalid response format", http.StatusBadRequest) + return + } + res.Data = append(res.Data, imageData) + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getImageBody Returns the body of the request to create a image. 
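As a counterpart to the fabricated responses below, a hedged sketch of a real generation call and of reading the URLs it returns; the token and prompt are placeholders:

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder
	resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
		Prompt:         "A turtle in a pool", // placeholder prompt
		N:              2,
		Size:           openai.CreateImageSize1024x1024,
		ResponseFormat: openai.CreateImageResponseFormatURL,
	})
	if err != nil {
		fmt.Println("image error:", err)
		return
	}
	for _, img := range resp.Data {
		fmt.Println(img.URL)
	}
}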
+func getImageBody(r *http.Request) (ImageRequest, error) { + image := ImageRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ImageRequest{}, err + } + err = json.Unmarshal(reqBody, &image) + if err != nil { + return ImageRequest{}, err + } + return image, nil +} + +func TestImageEdit(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + mask, err := os.Create("mask.png") + if err != nil { + t.Error("open mask file error") + return + } + + defer func() { + mask.Close() + origin.Close() + os.Remove("mask.png") + os.Remove("image.png") + }() + + _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + Image: origin, + Mask: mask, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +func TestImageEditWithoutMask(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + Image: origin, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleEditImageEndpoint Handles the images endpoint by the test server. +func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} + +func TestImageVariation(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + Image: origin, + N: 3, + Size: CreateImageSize1024x1024, + ResponseFormat: CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleVariateImageEndpoint Handles the images endpoint by the test server. 
+func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/image_test.go b/image_test.go index ca9faed95..81fff6cba 100644 --- a/image_test.go +++ b/image_test.go @@ -2,267 +2,15 @@ package openai //nolint:testpackage // testing private field import ( utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" - "encoding/json" "fmt" "io" - "net/http" "os" "testing" - "time" ) -func TestImages(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleImageEndpoint Handles the images endpoint by the test server. -func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := ImageResponse{ - Created: time.Now().Unix(), - } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: - // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" - default: - http.Error(w, "invalid response format", http.StatusBadRequest) - return - } - res.Data = append(res.Data, imageData) - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getImageBody Returns the body of the request to create a image. 
-func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ImageRequest{}, err - } - err = json.Unmarshal(reqBody, &image) - if err != nil { - return ImageRequest{}, err - } - return image, nil -} - -func TestImageEdit(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - mask, err := os.Create("mask.png") - if err != nil { - t.Error("open mask file error") - return - } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Mask: mask, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -func TestImageEditWithoutMask(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleEditImageEndpoint Handles the images endpoint by the test server. 
-func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - -func TestImageVariation(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageVariRequest{ - Image: origin, - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateVariImage(ctx, req) - checks.NoError(t, err, "CreateImage error") -} - -// handleVariateImageEndpoint Handles the images endpoint by the test server. -func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - type mockFormBuilder struct { mockCreateFormFile func(string, *os.File) error mockCreateFormFileReader func(string, io.Reader, string) error diff --git a/models_test.go b/models_test.go index 834c849c4..0b4daf4a8 100644 --- a/models_test.go +++ b/models_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -12,85 +11,47 @@ import ( "testing" ) -// TestListModels Tests the models endpoint of the API using the mocked server. +// TestListModels Tests the list models endpoint of the API using the mocked server. 
 func TestListModels(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/v1/models", handleModelsEndpoint)
-	// create the test server
-	var err error
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err = client.ListModels(ctx)
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/models", handleListModelsEndpoint)
+	_, err := client.ListModels(context.Background())
 	checks.NoError(t, err, "ListModels error")
 }
 
 func TestAzureListModels(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/openai/models", handleModelsEndpoint)
-	// create the test server
-	var err error
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
-	config.BaseURL = ts.URL
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err = client.ListModels(ctx)
+	client, server, teardown := setupAzureTestServer()
+	defer teardown()
+	server.RegisterHandler("/openai/models", handleListModelsEndpoint)
+	_, err := client.ListModels(context.Background())
 	checks.NoError(t, err, "ListModels error")
 }
 
-// handleModelsEndpoint Handles the models endpoint by the test server.
-func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
+// handleListModelsEndpoint Handles the list models endpoint by the test server.
+func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
 	resBytes, _ := json.Marshal(ModelsList{})
 	fmt.Fprintln(w, string(resBytes))
 }
 
 // TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
 func TestGetModel(t *testing.T) {
-	server := test.NewTestServer()
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
 	server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
-	// create the test server
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err := client.GetModel(ctx, "text-davinci-003")
+	_, err := client.GetModel(context.Background(), "text-davinci-003")
 	checks.NoError(t, err, "GetModel error")
 }
 
 func TestAzureGetModel(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
-	// create the test server
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
-	config.BaseURL = ts.URL
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err := client.GetModel(ctx, "text-davinci-003")
+	client, server, teardown := setupAzureTestServer()
+	defer teardown()
+	server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint)
+	_, err := client.GetModel(context.Background(), "text-davinci-003")
 	checks.NoError(t, err, "GetModel error")
 }
 
-// handleModelsEndpoint Handles the models endpoint by the test server.
+// handleGetModelEndpoint Handles the get model endpoint by the test server.
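+// It serialises an empty Model so that both the OpenAI and Azure
+// retrieve-model tests above have a well-formed body to decode.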
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(Model{}) fmt.Fprintln(w, string(resBytes)) diff --git a/moderation_test.go b/moderation_test.go index 2c1145627..4e756137e 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -2,7 +2,6 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "context" @@ -18,26 +17,13 @@ import ( // TestModeration Tests the moderations endpoint of the API using the mocked server. func TestModerations(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - // create an edit request - model := "text-moderation-stable" - moderationReq := ModerationRequest{ - Model: model, + _, err := client.Moderations(context.Background(), ModerationRequest{ + Model: ModerationTextStable, Input: "I want to kill them.", - } - _, err = client.Moderations(ctx, moderationReq) + }) checks.NoError(t, err, "Moderation error") } diff --git a/openai_test.go b/openai_test.go new file mode 100644 index 000000000..a5e7b64ee --- /dev/null +++ b/openai_test.go @@ -0,0 +1,28 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" +) + +func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client = NewClientWithConfig(config) + return +} + +func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config.BaseURL = ts.URL + client = NewClientWithConfig(config) + return +} diff --git a/stream_reader_test.go b/stream_reader_test.go index 0e45c0b73..cd6e46eff 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -7,6 +7,8 @@ import ( "testing" utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") @@ -47,7 +49,17 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { unmarshaler: &utils.JSONUnmarshaler{}, } _, err := stream.Recv() - if !errors.Is(err, ErrTooManyEmptyStreamMessages) { - t.Fatalf("Did not return error when recv failed: %v", err) + checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error()) +} + +func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), + errAccumulator: &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + }, + unmarshaler: &utils.JSONUnmarshaler{}, } + _, err := stream.Recv() + 
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } diff --git a/stream_test.go b/stream_test.go index 0faa21222..5997f27e8 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,11 +6,9 @@ import ( "errors" "io" "net/http" - "net/http/httptest" "testing" . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -32,7 +30,9 @@ func TestCompletionsStreamWrongModel(t *testing.T) { } func TestCreateCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -52,28 +52,14 @@ func TestCreateCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -116,7 +102,9 @@ func TestCreateCompletionStream(t *testing.T) { } func TestCreateCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -137,28 +125,14 @@ func TestCreateCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ MaxTokens: 5, Model: GPT3TextDavinci003, Prompt: "Hello!", Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -173,7 +147,8 @@ func TestCreateCompletionStreamError(t *testing.T) { } func TestCreateCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -188,30 +163,14 @@ func TestCreateCompletionStreamRateLimitError(t 
*testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - request := CompletionRequest{ + var apiErr *APIError + _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ MaxTokens: 5, Model: GPT3Ada, Prompt: "Hello!", Stream: true, - } - - var apiErr *APIError - _, err := client.CreateCompletionStream(ctx, request) + }) if !errors.As(err, &apiErr) { t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError") } @@ -219,7 +178,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { } func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -244,28 +205,14 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -277,7 +224,9 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { } func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -291,28 +240,14 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) 
checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() @@ -324,7 +259,9 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { } func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -344,28 +281,14 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &test.TokenRoundTripper{ - Token: test.GetTestToken(), - Fallback: http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() From 3f4e3bb0ca25ebec8e98564e745d8351dae2dd1c Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Tue, 13 Jun 2023 22:32:26 +0200 Subject: [PATCH 017/206] models: add *-0613 models (#361) Added GPT3Dot5Turbo0613, GPT3Dot5Turbo16K, GPT40613, and GPT432K0613 models from June update (https://openai.com/blog/function-calling-and-other-api-updates) Issue #360 --- completion.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/completion.go b/completion.go index de1360fd9..e7bf75acb 100644 --- a/completion.go +++ b/completion.go @@ -17,11 +17,15 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. 
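 //
 // A minimal usage sketch, assuming a constructed client and context (the new
 // identifiers are ordinary string constants, so switching models is a
 // one-line change):
 //
 //	resp, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{
 //		Model:    GPT3Dot5Turbo16K,
 //		Messages: []ChatCompletionMessage{{Role: ChatMessageRoleUser, Content: "Hello!"}},
 //	})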
const ( + GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" GPT4 = "gpt-4" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3TextDavinci003 = "text-davinci-003" GPT3TextDavinci002 = "text-davinci-002" @@ -50,10 +54,14 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo16K: true, GPT4: true, GPT40314: true, + GPT40613: true, GPT432K: true, GPT432K0314: true, + GPT432K0613: true, }, "/chat/completions": { CodexCodeDavinci002: true, From 646989cc5bb61f73335017243e0b008b149ba0ab Mon Sep 17 00:00:00 2001 From: Rich Coggins Date: Wed, 14 Jun 2023 10:19:18 -0400 Subject: [PATCH 018/206] Improve (#356) to support registration of wildcard URLs (#359) * Improve (#356) to support registration of wildcard URLs * Add TestAzureChatCompletions & TestAzureChatCompletionsWithCustomDeploymentName * Remove TestAzureChatCompletionsWithCustomDeploymentName --------- Co-authored-by: coggsflod --- chat_test.go | 18 ++++++++++++++++++ internal/test/server.go | 14 +++++++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/chat_test.go b/chat_test.go index ebe29f9eb..a43bb4aa6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -67,6 +67,24 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestAzureChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error diff --git a/internal/test/server.go b/internal/test/server.go index 79d55c405..3813ff869 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -4,6 +4,7 @@ import ( "log" "net/http" "net/http/httptest" + "regexp" ) const testAPI = "this-is-my-secure-token-do-not-steal!!" @@ -36,11 +37,14 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { return } - handlerCall, ok := ts.handlers[r.URL.Path] - if !ok { - http.Error(w, "the resource path doesn't exist", http.StatusNotFound) - return + // Handle /path/* routes. 
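+	// Each registered route is compiled as an (unanchored) regular expression
+	// and matched against the request path, so a handler registered for
+	// "/openai/deployments/*" also serves paths such as
+	// "/openai/deployments/gpt-35-turbo/chat/completions".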
+ for route, handler := range ts.handlers { + pattern, _ := regexp.Compile(route) + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } } - handlerCall(w, r) + http.Error(w, "the resource path doesn't exist", http.StatusNotFound) })) } From 7e76a682a949cf234c05896d2e2aa3f7d5c5d118 Mon Sep 17 00:00:00 2001 From: beichideyuwan <57309366+beichideyuwan@users.noreply.github.com> Date: Wed, 14 Jun 2023 22:23:03 +0800 Subject: [PATCH 019/206] Add 16k 0613 model (#365) * add 16k_0613 model * add 16k_0613 model * add model: --- completion.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/completion.go b/completion.go index e7bf75acb..efded208b 100644 --- a/completion.go +++ b/completion.go @@ -26,6 +26,7 @@ const ( GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3TextDavinci003 = "text-davinci-003" GPT3TextDavinci002 = "text-davinci-002" @@ -52,16 +53,17 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT3Dot5Turbo0613: true, - GPT3Dot5Turbo16K: true, - GPT4: true, - GPT40314: true, - GPT40613: true, - GPT432K: true, - GPT432K0314: true, - GPT432K0613: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, }, "/chat/completions": { CodexCodeDavinci002: true, From 2bd65aa720926506c49ddf89d7e619b3b83512c4 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 15 Jun 2023 16:49:54 +0800 Subject: [PATCH 020/206] feat(chat): support function call api (#369) * feat(chat): support function call api * rename struct & add const ChatMessageRoleFunction --- chat.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++--- chat_stream.go | 2 +- completion.go | 2 +- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/chat.go b/chat.go index a7ce5486a..c8cff319e 100644 --- a/chat.go +++ b/chat.go @@ -11,8 +11,11 @@ const ( ChatMessageRoleSystem = "system" ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" ) +const chatCompletionsSuffix = "/chat/completions" + var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll @@ -27,6 +30,14 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-python/blob/main/chatml.md // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. 
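 //
 // With the role added above, a function's output is normally returned to the
 // model as a follow-up message with Role: ChatMessageRoleFunction, Name set
 // to the function's name, and the JSON result in Content (per the upstream
 // chat API's convention).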
@@ -43,12 +54,70 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` + Functions []*FunctionDefine `json:"functions,omitempty"` + FunctionCall string `json:"function_call,omitempty"` } +type FunctionDefine struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + // it's required in function call + Parameters *FunctionParams `json:"parameters"` +} + +type FunctionParams struct { + // the Type must be JSONSchemaTypeObject + Type JSONSchemaType `json:"type"` + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type JSONSchemaType string + +const ( + JSONSchemaTypeObject JSONSchemaType = "object" + JSONSchemaTypeNumber JSONSchemaType = "number" + JSONSchemaTypeString JSONSchemaType = "string" + JSONSchemaTypeArray JSONSchemaType = "array" + JSONSchemaTypeNull JSONSchemaType = "null" + JSONSchemaTypeBoolean JSONSchemaType = "boolean" +) + +// JSONSchemaDefine is a struct for JSON Schema. +type JSONSchemaDefine struct { + // Type is a type of JSON Schema. + Type JSONSchemaType `json:"type,omitempty"` + // Description is a description of JSON Schema. + Description string `json:"description,omitempty"` + // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. + Enum []string `json:"enum,omitempty"` + // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. + Required []string `json:"required,omitempty"` +} + +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonContentFilter FinishReason = "content_filter" + FinishReasonNull FinishReason = "null" +) + type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + // FinishReason + // stop: API returned complete message, + // or a message terminated by one of the stop sequences provided via the stop parameter + // length: Incomplete model output due to max_tokens parameter or token limit + // function_call: The model decided to call a function + // content_filter: Omitted content due to a flag from our content filters + // null: API response still in progress or incomplete + FinishReason FinishReason `json:"finish_reason"` } // ChatCompletionResponse represents a response structure for chat completion API. 
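 //
 // Taken together, a minimal function-calling sketch (assuming a constructed
 // client and context; the weather function and its single "city" parameter
 // are illustrative only):
 //
 //	fn := &FunctionDefine{
 //		Name:        "get_weather",
 //		Description: "Get the current weather in a given city",
 //		Parameters: &FunctionParams{
 //			Type: JSONSchemaTypeObject,
 //			Properties: map[string]*JSONSchemaDefine{
 //				"city": {Type: JSONSchemaTypeString, Description: "The city name"},
 //			},
 //			Required: []string{"city"},
 //		},
 //	}
 //
 //	resp, err := client.CreateChatCompletion(ctx, ChatCompletionRequest{
 //		Model:     GPT3Dot5Turbo0613,
 //		Messages:  []ChatCompletionMessage{{Role: ChatMessageRoleUser, Content: "What is the weather in Paris?"}},
 //		Functions: []*FunctionDefine{fn},
 //	})
 //	if err == nil && resp.Choices[0].FinishReason == FinishReasonFunctionCall {
 //		// FunctionCall.Arguments is a JSON-encoded string, e.g. {"city":"Paris"}.
 //		fmt.Println(resp.Choices[0].Message.FunctionCall.Arguments)
 //	}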
@@ -71,7 +140,7 @@ func (c *Client) CreateChatCompletion( return } - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/chat_stream.go b/chat_stream.go index 625d436cb..c7341feac 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -40,7 +40,7 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, ) (stream *ChatCompletionStream, err error) { - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/completion.go b/completion.go index efded208b..e0571b007 100644 --- a/completion.go +++ b/completion.go @@ -65,7 +65,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K0314: true, GPT432K0613: true, }, - "/chat/completions": { + chatCompletionsSuffix: { CodexCodeDavinci002: true, CodexCodeCushman001: true, CodexCodeDavinci001: true, From 43de77162f7f6a1f391efce7a56b75d0b63042a9 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 15 Jun 2023 12:53:52 +0400 Subject: [PATCH 021/206] Create FUNDING.yml (#371) --- .github/FUNDING.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..d9fd885a9 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [sashabaranov] From 0bd14f9584baf8b47dd9251b674c26aed9c5a723 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 15 Jun 2023 17:58:26 +0800 Subject: [PATCH 022/206] refactor: ChatCompletionStreamChoice.FinishReason from string to FinishReason (#372) --- chat_stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat_stream.go b/chat_stream.go index c7341feac..9093bde9e 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -15,7 +15,7 @@ type ChatCompletionStreamChoiceDelta struct { type ChatCompletionStreamChoice struct { Index int `json:"index"` Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason string `json:"finish_reason"` + FinishReason FinishReason `json:"finish_reason"` } type ChatCompletionStreamResponse struct { From ac25f318ba29e1461ceec19f40bf5fd7765b7225 Mon Sep 17 00:00:00 2001 From: Alex Wormuth Date: Fri, 16 Jun 2023 08:11:50 -0500 Subject: [PATCH 023/206] add items, which is required for array type (#373) * add items, which is required for array type * use JSONSchemaDefine directly --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index c8cff319e..4764e36ba 100644 --- a/chat.go +++ b/chat.go @@ -95,6 +95,8 @@ type JSONSchemaDefine struct { Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. Required []string `json:"required,omitempty"` + // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. 
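+	// For example, an array-of-strings parameter can be declared as
+	// &JSONSchemaDefine{Type: JSONSchemaTypeArray,
+	//	Items: &JSONSchemaDefine{Type: JSONSchemaTypeString}}.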
+ Items *JSONSchemaDefine `json:"items,omitempty"` } type FinishReason string From f0770cfe1d5094d5d40a878658abf535bbdcec4c Mon Sep 17 00:00:00 2001 From: romazu Date: Fri, 16 Jun 2023 17:13:26 +0400 Subject: [PATCH 024/206] audio: add items to AudioResponseFormat enum (#382) * audio: add items to AudioResponseFormat enum * audio: expand AudioResponse struct to accommodate verbose json response --------- Co-authored-by: Roman Zubov --- audio.go | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/audio.go b/audio.go index 20e865f11..adfc52766 100644 --- a/audio.go +++ b/audio.go @@ -20,9 +20,11 @@ const ( type AudioResponseFormat string const ( - AudioResponseFormatJSON AudioResponseFormat = "json" - AudioResponseFormatSRT AudioResponseFormat = "srt" - AudioResponseFormatVTT AudioResponseFormat = "vtt" + AudioResponseFormatJSON AudioResponseFormat = "json" + AudioResponseFormatText AudioResponseFormat = "text" + AudioResponseFormatSRT AudioResponseFormat = "srt" + AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" + AudioResponseFormatVTT AudioResponseFormat = "vtt" ) // AudioRequest represents a request structure for audio API. @@ -44,6 +46,22 @@ type AudioRequest struct { // AudioResponse represents a response structure for audio API. type AudioResponse struct { + Task string `json:"task"` + Language string `json:"language"` + Duration float64 `json:"duration"` + Segments []struct { + ID int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` + Transient bool `json:"transient"` + } `json:"segments"` Text string `json:"text"` } @@ -96,7 +114,7 @@ func (c *Client) callAudioAPI( // HasJSONResponse returns true if the response format is JSON. 
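 // It also covers the new verbose_json format, which decodes into the same
 // AudioResponse struct. A hedged sketch of requesting it (assuming the
 // library's CreateTranscription helper and Whisper1 model constant):
 //
 //	resp, err := client.CreateTranscription(ctx, AudioRequest{
 //		Model:    Whisper1,
 //		FilePath: "speech.mp3",
 //		Format:   AudioResponseFormatVerboseJSON,
 //	})
 //	if err == nil {
 //		for _, seg := range resp.Segments {
 //			fmt.Printf("%.2f-%.2f: %s\n", seg.Start, seg.End, seg.Text)
 //		}
 //	}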
func (r AudioRequest) HasJSONResponse() bool { - return r.Format == "" || r.Format == AudioResponseFormatJSON + return r.Format == "" || r.Format == AudioResponseFormatJSON || r.Format == AudioResponseFormatVerboseJSON } // audioMultipartForm creates a form with audio file contents and the name of the model to use for From e49d771fff3bc699bca7cf22c9d93b67316047e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sat, 17 Jun 2023 22:57:29 +0900 Subject: [PATCH 025/206] support for parsing error response message fields even if they are arrays (#381) (#384) --- api_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++---- error.go | 10 ++++- 2 files changed, 111 insertions(+), 8 deletions(-) diff --git a/api_test.go b/api_test.go index 083b67412..34173708f 100644 --- a/api_test.go +++ b/api_test.go @@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) { } } +func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFn func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo, bar, baz" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: 
`{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := json.Unmarshal([]byte(tc.response), &apiErr) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + return + } + if tc.checkFn != nil { + tc.checkFn(t, apiErr) + } + }) + } +} + func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { var apiErr APIError response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` @@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { checks.HasError(t, err, "Type should be a string") } -func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Message should be a string") -} - func TestRequestError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/error.go b/error.go index b789ed7d5..f68e92875 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,7 @@ package openai import ( "encoding/json" "fmt" + "strings" ) // APIError provides error information returned by the OpenAI API. @@ -41,7 +42,14 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { err = json.Unmarshal(rawMap["message"], &e.Message) if err != nil { - return + // If the parameter field of a function call is invalid as a JSON schema + // refs: https://github.com/sashabaranov/go-openai/issues/381 + var messages []string + err = json.Unmarshal(rawMap["message"], &messages) + if err != nil { + return + } + e.Message = strings.Join(messages, ", ") } // optional fields for azure openai From b0959382c8fc01bf12de71a843d961f0d579f6f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sun, 18 Jun 2023 19:51:20 +0900 Subject: [PATCH 026/206] extract and split integration tests (#389) --- api_integration_test.go | 136 ++++++++++++++++ api_test.go | 353 ---------------------------------------- engines_test.go | 11 ++ error_test.go | 201 +++++++++++++++++++++++ openai_test.go | 9 + 5 files changed, 357 insertions(+), 353 deletions(-) create mode 100644 api_integration_test.go delete mode 100644 api_test.go create mode 100644 error_test.go diff --git a/api_integration_test.go b/api_integration_test.go new file mode 100644 index 000000000..3cafa24b4 --- /dev/null +++ b/api_integration_test.go @@ -0,0 +1,136 @@ +package openai_test + +import ( + "context" + "errors" + "io" + "os" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestAPI(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. 
Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken) + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.NoError(t, err, "ListEngines error") + + _, err = c.GetEngine(ctx, "davinci") + checks.NoError(t, err, "GetEngine error") + + fileRes, err := c.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") + + if len(fileRes.Files) > 0 { + _, err = c.GetFile(ctx, fileRes.Files[0].ID) + checks.NoError(t, err, "GetFile error") + } // else skip + + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: AdaSearchQuery, + } + _, err = c.CreateEmbeddings(ctx, embeddingReq) + checks.NoError(t, err, "Embedding error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + checks.NoError(t, err, "CreateChatCompletion (without name) returned error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Name: "John_Doe", + Content: "Hello!", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with name) returned error") + + stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: GPT3Ada, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.HasError(t, err, "ListEngines should fail with an invalid key") + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.HTTPStatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) + } + + switch v := apiErr.Code.(type) { + case string: + if v != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } + + if apiErr.Error() == "" { + t.Fatal("Empty error message occurred") + } +} diff --git a/api_test.go b/api_test.go deleted file mode 100644 index 34173708f..000000000 --- a/api_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package openai_test - -import ( - "context" - "encoding/json" - "errors" - "io" - "net/http" - "os" - "testing" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. 
Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.NoError(t, err, "ListEngines error") - - _, err = c.GetEngine(ctx, "davinci") - checks.NoError(t, err, "GetEngine error") - - fileRes, err := c.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") - - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) - checks.NoError(t, err, "GetFile error") - } // else skip - - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) - checks.NoError(t, err, "Embedding error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - }, - ) - - checks.NoError(t, err, "CreateChatCompletion (without name) returned error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Name: "John_Doe", - Content: "Hello!", - }, - }, - }, - ) - checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } -} - -func TestAPIError(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. 
Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken + "_invalid") - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.HasError(t, err, "ListEngines should fail with an invalid key") - - var apiErr *APIError - if !errors.As(err, &apiErr) { - t.Fatalf("Error is not an APIError: %+v", err) - } - - if apiErr.HTTPStatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) - } - - switch v := apiErr.Code.(type) { - case string: - if v != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } - - if apiErr.Error() == "" { - t.Fatal("Empty error message occurred") - } -} - -func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { - type testCase struct { - name string - response string - hasError bool - checkFn func(t *testing.T, apiErr APIError) - } - testCases := []testCase{ - { - name: "parse succeeds when the message is string", - response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with single item", - response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with multiple items", - response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo, bar, baz" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is empty array", - response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse succeeds when the message is null", - response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse failed when the message is object", - response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is int", - response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is float", - response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is bool", - response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is not exists", - response: `{"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - } - for 
_, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var apiErr APIError - err := json.Unmarshal([]byte(tc.response), &apiErr) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) - return - } - if tc.checkFn != nil { - tc.checkFn(t, apiErr) - } - }) - } -} - -func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case int: - if v != 418 { - t.Fatalf("Unexpected API code integer: %d; expected 418", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONString(t *testing.T) { - var apiErr APIError - response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case string: - if v != "teapot" { - t.Fatalf("Unexpected API code string: %s; expected `teapot`", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) { - // test integer code - response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - var apiErr APIError - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case nil: - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalInvalidData(t *testing.T) { - apiErr := APIError{} - data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`) - err := apiErr.UnmarshalJSON(data) - checks.HasError(t, err, "Expected error when unmarshaling invalid data") - - if apiErr.Code != nil { - t.Fatalf("Expected nil code, got %q", apiErr.Code) - } - if apiErr.Message != "" { - t.Fatalf("Expected empty message, got %q", apiErr.Message) - } - if apiErr.Param != nil { - t.Fatalf("Expected nil param, got %q", *apiErr.Param) - } - if apiErr.Type != "" { - t.Fatalf("Expected empty type, got %q", apiErr.Type) - } -} - -func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Param should be a string") -} - -func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Type should be a string") -} - -func TestRequestError(t *testing.T) { - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTeapot) - }) - - _, err := client.ListEngines(context.Background()) - checks.HasError(t, err, "ListEngines did not fail") - - var reqErr *RequestError - if !errors.As(err, &reqErr) { - t.Fatalf("Error is not a RequestError: %+v", err) - } - - if reqErr.HTTPStatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) - } - - if reqErr.Unwrap() == nil { - t.Fatalf("Empty 
request error occurred") - } -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} diff --git a/engines_test.go b/engines_test.go index 2beb333b3..31e7ec8be 100644 --- a/engines_test.go +++ b/engines_test.go @@ -34,3 +34,14 @@ func TestListEngines(t *testing.T) { _, err := client.ListEngines(context.Background()) checks.NoError(t, err, "ListEngines error") } + +func TestListEnginesReturnError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + }) + + _, err := client.ListEngines(context.Background()) + checks.HasError(t, err, "ListEngines did not fail") +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 000000000..e2309abd7 --- /dev/null +++ b/error_test.go @@ -0,0 +1,201 @@ +package openai_test + +import ( + "errors" + "net/http" + "testing" + + . "github.com/sashabaranov/go-openai" +) + +func TestAPIErrorUnmarshalJSON(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFunc func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + // testcase for message field + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo, bar, baz") + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, 
+ { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + // testcase for code field + { + name: "parse succeeds when the code is int", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, 418) + }, + }, + { + name: "parse succeeds when the code is string", + response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, "teapot") + }, + }, + { + name: "parse succeeds when the code is not exists", + response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + }, + }, + // testcase for param field + { + name: "parse failed when the param is bool", + response: `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`, + hasError: true, + }, + // testcase for type field + { + name: "parse failed when the type is bool", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`, + hasError: true, + }, + // testcase for error response + { + name: "parse failed when the response is invalid json", + response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + assertAPIErrorMessage(t, apiErr, "") + assertAPIErrorParam(t, apiErr, nil) + assertAPIErrorType(t, apiErr, "") + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := apiErr.UnmarshalJSON([]byte(tc.response)) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + } + if tc.checkFunc != nil { + tc.checkFunc(t, apiErr) + } + }) + } +} + +func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { + if apiErr.Message != expected { + t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) + } +} + +func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { + switch v := apiErr.Code.(type) { + case int: + if v != expected { + t.Errorf("Unexpected APIError code integer: %d; expected %d", v, expected) + } + case string: + if v != expected { + t.Errorf("Unexpected APIError code string: %s; expected %s", v, expected) + } + case nil: + default: + t.Errorf("Unexpected APIError error code type: %T", v) + } +} + +func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { + if apiErr.Param != expected { + t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) + } +} + +func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { + if apiErr.Type != typ { + t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) + } +} + +func TestRequestError(t *testing.T) { + var err error = &RequestError{ + HTTPStatusCode: http.StatusTeapot, + Err: errors.New("i am a teapot"), + } + + var reqErr *RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.HTTPStatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) + } + + if reqErr.Unwrap() == nil { + t.Fatalf("Empty request error occurred") + } +} diff --git 
a/openai_test.go b/openai_test.go
index a5e7b64ee..4fc41ecc0 100644
--- a/openai_test.go
+++ b/openai_test.go
@@ -26,3 +26,12 @@ func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown f
 	client = NewClientWithConfig(config)
 	return
 }
+
+// numTokens Returns the number of GPT-3 encoded tokens in the given text.
+// This function approximates based on the rule of thumb stated by OpenAI:
+// https://beta.openai.com/tokenizer
+//
+// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
+func numTokens(s string) int {
+	return int(float32(len(s)) / 4)
+}

From 68f9ef92beeb368eb77ea1bb206abedb5066501b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?=
 =?UTF-8?q?be?=
Date: Mon, 19 Jun 2023 17:12:38 +0900
Subject: [PATCH 027/206] split integration test using go build tag (#392)

---
 README.md               | 13 +++++++++++++
 api_integration_test.go |  2 ++
 2 files changed, 15 insertions(+)

diff --git a/README.md b/README.md
index 7562694df..9a7262332 100644
--- a/README.md
+++ b/README.md
@@ -542,3 +542,16 @@ if errors.As(err, &e) {
 
 See the `examples/` folder for more.
 
+### Integration tests:
+
+Integration tests are executed against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API.
+
+**Notes:**
+These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the tests to fail.
+
+**Run tests using:**
+```
+OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go
+```
+
+If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. 
\ No newline at end of file diff --git a/api_integration_test.go b/api_integration_test.go index 3cafa24b4..d4e7328a2 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -1,3 +1,5 @@ +//go:build integration + package openai_test import ( From 720377087fae943d15000d47c7c9ea0a214798b1 Mon Sep 17 00:00:00 2001 From: cem-unuvar <87916654+cem-unuvar@users.noreply.github.com> Date: Tue, 20 Jun 2023 18:33:53 +0300 Subject: [PATCH 028/206] feat: added function call info to chat completions (#390) --- chat_stream.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 9093bde9e..75aa6858a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -8,8 +8,9 @@ import ( ) type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` } type ChatCompletionStreamChoice struct { From e948150829ac980f3aea86ed1d73aa2fc5a7f12b Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 20 Jun 2023 23:39:19 +0800 Subject: [PATCH 029/206] fix: chat stream returns an error response with a 'data: ' prefix (#396) * fix: chat stream resp has 'data: ' prefix * fix: lint error * fix: lint error * fix: lint error --- chat_stream_test.go | 39 +++++++++++++++++++++++++++++++++++++++ stream_reader.go | 22 ++++++++++++++++++---- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index c3cb9f3f7..5fc70b032 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -178,6 +178,45 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) 
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, streamErr := stream.Recv() + checks.HasError(t, streamErr, "stream.Recv() did not return error") + + var apiErr *APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/stream_reader.go b/stream_reader.go index 34161986e..87e59e0ca 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -10,6 +10,11 @@ import ( utils "github.com/sashabaranov/go-openai/internal" ) +var ( + headerData = []byte("data: ") + errorPrefix = []byte(`data: {"error":`) +) + type streamable interface { ChatCompletionStreamResponse | CompletionResponse } @@ -34,12 +39,16 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } +//nolint:gocognit func (stream *streamReader[T]) processLines() (T, error) { - var emptyMessagesCount uint + var ( + emptyMessagesCount uint + hasErrorPrefix bool + ) for { rawLine, readErr := stream.reader.ReadBytes('\n') - if readErr != nil { + if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { return *new(T), fmt.Errorf("error, %w", respErr.Error) @@ -47,9 +56,14 @@ func (stream *streamReader[T]) processLines() (T, error) { return *new(T), readErr } - var headerData = []byte("data: ") noSpaceLine := bytes.TrimSpace(rawLine) - if !bytes.HasPrefix(noSpaceLine, headerData) { + if bytes.HasPrefix(noSpaceLine, errorPrefix) { + hasErrorPrefix = true + } + if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if hasErrorPrefix { + noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { return *new(T), writeErr From f22da8a7ed896d19661dfcce3e330e4b209b2eb3 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Wed, 21 Jun 2023 08:58:27 -0400 Subject: [PATCH 030/206] feat: allow more input types to functions, fix tests (#377) * feat: use json.rawMessage, test functions * chore: lint * fix: tests the ChatCompletion mock server doesn't actually run otherwise. 
N=0 is the default request but the server will treat it as n=1 * fix: tests should default to n=1 completions * chore: add back removed interfaces, custom marshal * chore: lint * chore: lint * chore: add some tests * chore: appease lint * clean up JSON schema + tests * chore: lint * feat: remove backwards compatible functions for illustrative purposes * fix: revert params change * chore: use interface{} * chore: add test * chore: add back FunctionDefine * chore: /s/interface{}/any * chore: add back jsonschemadefinition * chore: testcov * chore: lint * chore: remove pointers * chore: update comment * chore: address CR added test for compatibility as well --------- Co-authored-by: James --- chat.go | 34 +++++----- chat_test.go | 157 ++++++++++++++++++++++++++++++++++++++++++++- completion_test.go | 10 ++- 3 files changed, 180 insertions(+), 21 deletions(-) diff --git a/chat.go b/chat.go index 4764e36ba..f99af2735 100644 --- a/chat.go +++ b/chat.go @@ -54,23 +54,23 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` - Functions []*FunctionDefine `json:"functions,omitempty"` - FunctionCall string `json:"function_call,omitempty"` + Functions []FunctionDefinition `json:"functions,omitempty"` + FunctionCall any `json:"function_call,omitempty"` } -type FunctionDefine struct { +type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` - // it's required in function call - Parameters *FunctionParams `json:"parameters"` + // Parameters is an object describing the function. + // You can pass a raw byte array describing the schema, + // or you can pass in a struct which serializes to the proper JSONSchema. + // The JSONSchemaDefinition struct is provided for convenience, but you should + // consider another specialized library for more complex schemas. + Parameters any `json:"parameters"` } -type FunctionParams struct { - // the Type must be JSONSchemaTypeObject - Type JSONSchemaType `json:"type"` - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} +// Deprecated: use FunctionDefinition instead. +type FunctionDefine = FunctionDefinition type JSONSchemaType string @@ -83,8 +83,9 @@ const ( JSONSchemaTypeBoolean JSONSchemaType = "boolean" ) -// JSONSchemaDefine is a struct for JSON Schema. -type JSONSchemaDefine struct { +// JSONSchemaDefinition is a struct for JSON Schema. +// It is fairly limited and you may have better luck using a third-party library. +type JSONSchemaDefinition struct { // Type is a type of JSON Schema. Type JSONSchemaType `json:"type,omitempty"` // Description is a description of JSON Schema. @@ -92,13 +93,16 @@ type JSONSchemaDefine struct { // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. Enum []string `json:"enum,omitempty"` // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"` // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. Required []string `json:"required,omitempty"` // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. 
- Items *JSONSchemaDefine `json:"items,omitempty"` + Items *JSONSchemaDefinition `json:"items,omitempty"` } +// Deprecated: use JSONSchemaDefinition instead. +type JSONSchemaDefine = JSONSchemaDefinition + type FinishReason string const ( diff --git a/chat_test.go b/chat_test.go index a43bb4aa6..3c759b310 100644 --- a/chat_test.go +++ b/chat_test.go @@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestChatCompletionsFunctions tests including a function call. +func TestChatCompletionsFunctions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + t.Run("bytes", func(t *testing.T) { + //nolint:lll + msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("struct", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefine", func(t *testing.T) { + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &JSONSchemaDefinition{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefinition{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + Description: "list of words in sentence", + Items: &JSONSchemaDefinition{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) { + // this is a compatibility check + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &JSONSchemaDefine{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefine{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + 
Description: "list of words in sentence", + Items: &JSONSchemaDefine{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) +} + func TestAzureChatCompletions(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() @@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { + // if there are functions, include them + if len(completionReq.Functions) > 0 { + var fcb []byte + b := completionReq.Functions[0].Parameters + fcb, err = json.Marshal(b) + if err != nil { + http.Error(w, "could not marshal function parameters", http.StatusInternalServerError) + return + } + + res.Choices = append(res.Choices, ChatCompletionChoice{ + Message: ChatCompletionMessage{ + Role: ChatMessageRoleFunction, + // this is valid json so it should be fine + FunctionCall: &FunctionCall{ + Name: completionReq.Functions[0].Name, + Arguments: string(fcb), + }, + }, + Index: i, + }) + continue + } // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) @@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, diff --git a/completion_test.go b/completion_test.go index aeddcfca1..844ef484f 100644 --- a/completion_test.go +++ b/completion_test.go @@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) if completionReq.Echo { @@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Prompt.(string)) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, From e19b074a114a5add5f005911668d0cda8476a908 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 21 Jun 2023 23:53:15 +0900 Subject: [PATCH 031/206] docs: add requires go version in README.md (#397) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9a7262332..5f166dc31 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op ``` go get github.com/sashabaranov/go-openai ``` - +Currently, go-openai requires Go version 1.18 or greater. 
### ChatGPT example usage: @@ -554,4 +554,4 @@ These tests send real network traffic to the OpenAI API and may reach rate limit OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` -If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. \ No newline at end of file +If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. From ffa7abc050b22b068ed16680de3b96ef26211651 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 21 Jun 2023 18:54:10 +0400 Subject: [PATCH 032/206] Update README.md (#399) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5f166dc31..522a85e78 100644 --- a/README.md +++ b/README.md @@ -554,4 +554,4 @@ These tests send real network traffic to the OpenAI API and may reach rate limit OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` -If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. From 157de0680f39f7c521cdd79bf69fb66390380c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 22 Jun 2023 18:49:46 +0900 Subject: [PATCH 033/206] add vvatanabe to FUNDING.yml (#402) --- .github/FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index d9fd885a9..e36c38239 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,3 +1,3 @@ # These are supported funding model platforms -github: [sashabaranov] +github: [sashabaranov, vvatanabe] From f1b66967a426c3dfaf5e652b118d807cf1e7473f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 22 Jun 2023 18:57:52 +0900 Subject: [PATCH 034/206] refactor: refactoring http request creation and sending (#395) * refactoring http request creation and sending * fix lint error * increase the test coverage of client.go * refactor: Change the style of HTTPRequestBuilder.Build func to one-argument-per-line. 
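The diff below introduces a small functional-options layer (`requestOptions`, `withBody`, `withContentType`) in front of a single `newRequest` constructor. For readers unfamiliar with the pattern, here is a minimal self-contained sketch of the same idea. The option names mirror the patch, but the types are simplified: the real `withBody` accepts `any` and hands serialization to the request builder, so treat this as an illustration rather than the library's code.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"strings"
)

// requestOptions gathers everything a caller may want to customize.
type requestOptions struct {
	body   string
	header http.Header
}

// requestOption is a functional option that mutates requestOptions.
type requestOption func(*requestOptions)

func withBody(body string) requestOption {
	return func(o *requestOptions) { o.body = body }
}

func withContentType(contentType string) requestOption {
	return func(o *requestOptions) { o.header.Set("Content-Type", contentType) }
}

// newRequest applies defaults first, then each option in order,
// so every call site stays a single variadic constructor call.
func newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
	opts := &requestOptions{header: make(http.Header)}
	for _, set := range setters {
		set(opts)
	}
	req, err := http.NewRequestWithContext(ctx, method, url, strings.NewReader(opts.body))
	if err != nil {
		return nil, err
	}
	req.Header = opts.header
	return req, nil
}

func main() {
	req, err := newRequest(context.Background(), http.MethodPost, "https://example.com/v1/chat/completions",
		withBody(`{"model":"gpt-3.5-turbo"}`), withContentType("application/json"))
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Method, req.URL, req.Header.Get("Content-Type"))
}
```

Defaults are applied first and each option mutates them in order, so later options win; that is what lets the exported API surface stay a single variadic call while covering JSON bodies, multipart forms, and bodiless GETs alike.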
--- api_internal_test.go | 2 +- audio.go | 4 +- chat.go | 2 +- chat_stream.go | 22 ++------ client.go | 94 ++++++++++++++++++++++++-------- client_test.go | 25 +++++++-- completion.go | 2 +- edits.go | 2 +- embeddings.go | 2 +- engines.go | 4 +- files.go | 28 +++------- fine_tunes.go | 12 ++-- image.go | 13 ++--- internal/request_builder.go | 42 +++++++++----- internal/request_builder_test.go | 6 +- models.go | 4 +- models_test.go | 22 ++++++++ moderation.go | 2 +- stream.go | 21 ++----- stream_test.go | 26 +++++++++ 20 files changed, 209 insertions(+), 126 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 214b627bf..0fb0f8993 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "") + req, err := cli.newRequest(context.Background(), "POST", "/chat/completions") if err != nil { t.Errorf("Failed to create request: %v", err) } diff --git a/audio.go b/audio.go index adfc52766..9f469159d 100644 --- a/audio.go +++ b/audio.go @@ -95,11 +95,11 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), + withBody(&formBody), withContentType(builder.FormDataContentType())) if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) diff --git a/chat.go b/chat.go index f99af2735..b74720d38 100644 --- a/chat.go +++ b/chat.go @@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 75aa6858a..9f4e80cff 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -1,10 +1,8 @@ package openai import ( - "bufio" "context" - - utils "github.com/sashabaranov/go-openai/internal" + "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &ChatCompletionStream{ - streamReader: &streamReader[ChatCompletionStreamResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/client.go b/client.go index f38c1dfc3..5779a8e1c 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package openai import ( + "bufio" "context" "encoding/json" "fmt" @@ -45,6 +46,42 @@ func 
NewOrgClient(authToken, org string) *Client { return NewClientWithConfig(config) } +type requestOptions struct { + body any + header http.Header +} + +type requestOption func(*requestOptions) + +func withBody(body any) requestOption { + return func(args *requestOptions) { + args.body = body + } +} + +func withContentType(contentType string) requestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + +func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { + // Default Options + args := &requestOptions{ + body: nil, + header: make(http.Header), + } + for _, setter := range setters { + setter(args) + } + req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header) + if err != nil { + return nil, err + } + c.setCommonHeaders(req) + return req, nil +} + func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") @@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - c.setCommonHeaders(req) - res, err := c.config.HTTPClient.Do(req) if err != nil { return err @@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return decodeResponse(res.Body, v) } +func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { + resp, err := c.config.HTTPClient.Do(req) + if err != nil { + return + } + + if isFailureStatusCode(resp) { + err = c.handleErrorResp(resp) + return + } + return resp.Body, nil +} + +func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return new(streamReader[T]), err + } + if isFailureStatusCode(resp) { + return new(streamReader[T]), client.handleErrorResp(resp) + } + return &streamReader[T]{ + emptyMessagesLimit: client.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + }, nil +} + func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication @@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } -func (c *Client) newStreamRequest( - ctx context.Context, - method string, - urlSuffix string, - body any, - model string) (*http.Request, error) { - req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - c.setCommonHeaders(req) - return req, nil -} - func (c *Client) handleErrorResp(resp *http.Response) error { var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) diff --git a/client_test.go b/client_test.go index 00b66feae..29d84edfa 100644 --- a/client_test.go +++ b/client_test.go @@ -16,7 +16,7 @@ 
var errTestRequestBuilderFailed = errors.New("test request builder failed") type failingRequestBuilder struct{} -func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) { +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) { return nil, errTestRequestBuilderFailed } @@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) { stringInput := "" testCases := []struct { - name string - value interface{} - body io.Reader + name string + value interface{} + body io.Reader + hasError bool }{ { name: "nil input", @@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) { value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), }, + { + name: "reader return error", + value: &stringInput, + body: &errorReader{err: errors.New("dummy")}, + hasError: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if err != nil { + if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) } }) } } +type errorReader struct { + err error +} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, e.err +} + func TestHandleErrorResp(t *testing.T) { // var errRes *ErrorResponse var errRes ErrorResponse diff --git a/completion.go b/completion.go index e0571b007..b3b3abd1c 100644 --- a/completion.go +++ b/completion.go @@ -165,7 +165,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/edits.go b/edits.go index 23b1a64f0..3d3fc8950 100644 --- a/edits.go +++ b/edits.go @@ -32,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 942f3ea3a..ba327ce77 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) if err != nil { return } diff --git a/engines.go b/engines.go index ac01a00ed..adf6025c2 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. 
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines")) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/files.go b/files.go index fb9937bea..ea1f50a73 100644 --- a/files.go +++ b/files.go @@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) - err = c.sendRequest(req, &file) - return } // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID)) if err != nil { return } @@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files")) if err != nil { return } @@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. 
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) - if err != nil { - return - } - - c.setCommonHeaders(req) - - res, err := c.config.HTTPClient.Do(req) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - if isFailureStatusCode(res) { - err = c.handleErrorResp(res) - return - } - - content = res.Body + content, err = c.sendRequestRaw(req) return } diff --git a/fine_tunes.go b/fine_tunes.go index 069ddccfd..96e731d51 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return } @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { return } @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { return } @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, 
c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { return } diff --git a/image.go b/image.go index df7363865..cb96f4f5e 100644 --- a/image.go +++ b/image.go @@ -44,7 +44,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -107,13 +107,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - urlSuffix := "/images/edits" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -158,14 +157,12 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - //https://platform.openai.com/docs/api-reference/images/create-variation - urlSuffix := "/images/variations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/internal/request_builder.go b/internal/request_builder.go index 0a9eabfde..5699f6b18 100644 --- a/internal/request_builder.go +++ b/internal/request_builder.go @@ -3,11 +3,12 @@ package openai import ( "bytes" "context" + "io" "net/http" ) type RequestBuilder interface { - Build(ctx context.Context, method, url string, request any) (*http.Request, error) + Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) } type HTTPRequestBuilder struct { @@ -20,21 +21,32 @@ func NewRequestBuilder() *HTTPRequestBuilder { } } -func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) { - if request == nil { - return http.NewRequestWithContext(ctx, method, url, nil) +func (b *HTTPRequestBuilder) Build( + ctx context.Context, + method string, + url string, + body any, + header http.Header, +) (req *http.Request, err error) { + var bodyReader io.Reader + if body != nil { + if v, ok := body.(io.Reader); ok { + bodyReader = v + } else { + var reqBytes []byte + reqBytes, err = b.marshaller.Marshal(body) + if err != nil { + return + } + bodyReader = bytes.NewBuffer(reqBytes) + } } - - var reqBytes []byte - reqBytes, err := b.marshaller.Marshal(request) + req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) if err != nil { - return nil, err + return } - - return http.NewRequestWithContext( - ctx, - method, - url, - bytes.NewBuffer(reqBytes), - ) + if header != nil { + req.Header = header + } + return } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e47d0f6ca..e26022a6b 100644 --- 
a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -22,7 +22,7 @@ func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { marshaller: &failingMarshaller{}, } - _, err := builder.Build(context.Background(), "", "", struct{}{}) + _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) if !errors.Is(err, errTestMarshallerFailed) { t.Fatalf("Did not return error when marshaller failed: %v", err) } @@ -38,7 +38,7 @@ func TestRequestBuilderReturnsRequest(t *testing.T) { reqBytes, _ = b.marshaller.Marshal(request) want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) ) - got, _ := b.Build(ctx, method, url, request) + got, _ := b.Build(ctx, method, url, request, nil) if !reflect.DeepEqual(got.Body, want.Body) || !reflect.DeepEqual(got.URL, want.URL) || !reflect.DeepEqual(got.Method, want.Method) { @@ -54,7 +54,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { want, _ = http.NewRequestWithContext(ctx, method, url, nil) ) b := NewRequestBuilder() - got, _ := b.Build(ctx, method, url, nil) + got, _ := b.Build(ctx, method, url, nil, nil) if !reflect.DeepEqual(got, want) { t.Errorf("Build() got = %v, want %v", got, want) } diff --git a/models.go b/models.go index b3d458366..560402e3f 100644 --- a/models.go +++ b/models.go @@ -41,7 +41,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models")) if err != nil { return } @@ -54,7 +54,7 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) // the model such as the owner and permissioning. func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { urlSuffix := fmt.Sprintf("/models/%s", modelID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/models_test.go b/models_test.go index 0b4daf4a8..59b4f5ef7 100644 --- a/models_test.go +++ b/models_test.go @@ -1,6 +1,9 @@ package openai_test import ( + "os" + "time" + . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -56,3 +59,22 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(Model{}) fmt.Fprintln(w, string(resBytes)) } + +func TestGetModelReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetModel(ctx, "text-davinci-003") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/moderation.go b/moderation.go index bae788035..a58d759c0 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. 
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) if err != nil { return } diff --git a/stream.go b/stream.go index 94cc0a0a2..b277f3c29 100644 --- a/stream.go +++ b/stream.go @@ -1,11 +1,8 @@ package openai import ( - "bufio" "context" "errors" - - utils "github.com/sashabaranov/go-openai/internal" ) var ( @@ -36,27 +33,17 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[CompletionResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &CompletionStream{ - streamReader: &streamReader[CompletionResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/stream_test.go b/stream_test.go index 5997f27e8..f3f8f85cd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,7 +6,9 @@ import ( "errors" "io" "net/http" + "os" "testing" + "time" . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -300,6 +302,30 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { } } +func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { From 5f4ef298e3d4d74784ac53d75d0d43379efa2efc Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 23 Jun 2023 13:07:43 +0400 Subject: [PATCH 035/206] Update README.md (#406) --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 522a85e78..ef1db98cc 100644 --- a/README.md +++ b/README.md @@ -555,3 +555,10 @@ OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. 
+ +## Thank you + +We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: +- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com) + +To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together! From 0ca4ea48671c631fb15cc01d50b89c6c3658dafb Mon Sep 17 00:00:00 2001 From: James MacWhyte Date: Sat, 24 Jun 2023 18:22:11 +0200 Subject: [PATCH 036/206] move json schema to directory/package (#407) * move json schema to directory/package * added jsonschema to README --- README.md | 60 ++++++++++++++++++++++++++++++++++++++++++++++ chat.go | 39 ++++-------------------------- chat_test.go | 39 +++++++++++++++--------------- jsonschema/json.go | 35 +++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 54 deletions(-) create mode 100644 jsonschema/json.go diff --git a/README.md b/README.md index ef1db98cc..da1a2804d 100644 --- a/README.md +++ b/README.md @@ -516,6 +516,66 @@ func main() { ``` +
+
+JSON Schema for function calling
+
+It is now possible for a chat completion to choose to call a function for more information ([see developer docs here](https://platform.openai.com/docs/guides/gpt/function-calling)).
+
+In order to describe the type of functions that can be called, a JSON schema must be provided. Many JSON schema libraries exist and are more advanced than what we can offer in this library; however, we have included a simple `jsonschema` package for those who want to use this feature without formatting their own JSON schema payload.
+
+The developer documents give this JSON schema definition as an example:
+
+```json
+{
+    "name":"get_current_weather",
+    "description":"Get the current weather in a given location",
+    "parameters":{
+        "type":"object",
+        "properties":{
+            "location":{
+                "type":"string",
+                "description":"The city and state, e.g. San Francisco, CA"
+            },
+            "unit":{
+                "type":"string",
+                "enum":[
+                    "celsius",
+                    "fahrenheit"
+                ]
+            }
+        },
+        "required":[
+            "location"
+        ]
+    }
+}
+```
+
+Using the `jsonschema` package, this schema can be created with structs as follows:
+
+```go
+FunctionDefinition{
+    Name: "get_current_weather",
+    Parameters: jsonschema.Definition{
+        Type: jsonschema.Object,
+        Properties: map[string]jsonschema.Definition{
+            "location": {
+                Type:        jsonschema.String,
+                Description: "The city and state, e.g. San Francisco, CA",
+            },
+            "unit": {
+                Type: jsonschema.String,
+                Enum: []string{"celsius", "fahrenheit"},
+            },
+        },
+        Required: []string{"location"},
+    },
+}
+```
+
+The `Parameters` field of a `FunctionDefinition` can accept either of the above styles, or even a nested struct from another library (as long as it can be marshalled into JSON). 
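Putting the section above into practice, the end-to-end flow is: define the function with `jsonschema.Definition`, attach it to the request, and read back the model's `FunctionCall`. The following is a minimal sketch, not part of the patch itself; it assumes an API key in the `OPENAI_API_KEY` environment variable, and the weather function is illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"os"

	openai "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	// Schema matching the get_current_weather example above.
	weatherFn := openai.FunctionDefinition{
		Name:        "get_current_weather",
		Description: "Get the current weather in a given location",
		Parameters: jsonschema.Definition{
			Type: jsonschema.Object,
			Properties: map[string]jsonschema.Definition{
				"location": {
					Type:        jsonschema.String,
					Description: "The city and state, e.g. San Francisco, CA",
				},
				"unit": {
					Type: jsonschema.String,
					Enum: []string{"celsius", "fahrenheit"},
				},
			},
			Required: []string{"location"},
		},
	}

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo0613,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "What's the weather like in Boston?"},
		},
		Functions: []openai.FunctionDefinition{weatherFn},
	})
	if err != nil {
		fmt.Printf("ChatCompletion error: %v\n", err)
		return
	}

	// When the model elects to call the function, the reply carries a
	// FunctionCall with the name and JSON-encoded arguments instead of content.
	if fc := resp.Choices[0].Message.FunctionCall; fc != nil {
		fmt.Printf("call %s with %s\n", fc.Name, fc.Arguments)
	} else {
		fmt.Println(resp.Choices[0].Message.Content)
	}
}
```

The `Arguments` field arrives as a JSON string shaped by the schema, so unmarshal it into your own struct before acting on it.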
+
Error handling diff --git a/chat.go b/chat.go index b74720d38..e4f23df07 100644 --- a/chat.go +++ b/chat.go @@ -62,47 +62,16 @@ type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` // Parameters is an object describing the function. - // You can pass a raw byte array describing the schema, - // or you can pass in a struct which serializes to the proper JSONSchema. - // The JSONSchemaDefinition struct is provided for convenience, but you should - // consider another specialized library for more complex schemas. + // You can pass a []byte describing the schema, + // or you can pass in a struct which serializes to the proper JSON schema. + // The jsonschema package is provided for convenience, but you should + // consider another specialized library if you require more complex schemas. Parameters any `json:"parameters"` } // Deprecated: use FunctionDefinition instead. type FunctionDefine = FunctionDefinition -type JSONSchemaType string - -const ( - JSONSchemaTypeObject JSONSchemaType = "object" - JSONSchemaTypeNumber JSONSchemaType = "number" - JSONSchemaTypeString JSONSchemaType = "string" - JSONSchemaTypeArray JSONSchemaType = "array" - JSONSchemaTypeNull JSONSchemaType = "null" - JSONSchemaTypeBoolean JSONSchemaType = "boolean" -) - -// JSONSchemaDefinition is a struct for JSON Schema. -// It is fairly limited and you may have better luck using a third-party library. -type JSONSchemaDefinition struct { - // Type is a type of JSON Schema. - Type JSONSchemaType `json:"type,omitempty"` - // Description is a description of JSON Schema. - Description string `json:"description,omitempty"` - // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. - Enum []string `json:"enum,omitempty"` - // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. - Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"` - // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. - Required []string `json:"required,omitempty"` - // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. - Items *JSONSchemaDefinition `json:"items,omitempty"` -} - -// Deprecated: use JSONSchemaDefinition instead. -type JSONSchemaDefine = JSONSchemaDefinition - type FinishReason string const ( diff --git a/chat_test.go b/chat_test.go index 3c759b310..d5879e60f 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,10 @@ import ( "strings" "testing" "time" + + . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestChatCompletionsWrongModel(t *testing.T) { @@ -128,22 +129,22 @@ func TestChatCompletionsFunctions(t *testing.T) { }, Functions: []FunctionDefinition{{ Name: "test", - Parameters: &JSONSchemaDefinition{ - Type: JSONSchemaTypeObject, - Properties: map[string]JSONSchemaDefinition{ + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "count": { - Type: JSONSchemaTypeNumber, + Type: jsonschema.Number, Description: "total number of words in sentence", }, "words": { - Type: JSONSchemaTypeArray, + Type: jsonschema.Array, Description: "list of words in sentence", - Items: &JSONSchemaDefinition{ - Type: JSONSchemaTypeString, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, }, "enumTest": { - Type: JSONSchemaTypeString, + Type: jsonschema.String, Enum: []string{"hello", "world"}, }, }, @@ -165,22 +166,22 @@ func TestChatCompletionsFunctions(t *testing.T) { }, Functions: []FunctionDefine{{ Name: "test", - Parameters: &JSONSchemaDefine{ - Type: JSONSchemaTypeObject, - Properties: map[string]JSONSchemaDefine{ + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "count": { - Type: JSONSchemaTypeNumber, + Type: jsonschema.Number, Description: "total number of words in sentence", }, "words": { - Type: JSONSchemaTypeArray, + Type: jsonschema.Array, Description: "list of words in sentence", - Items: &JSONSchemaDefine{ - Type: JSONSchemaTypeString, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, }, "enumTest": { - Type: JSONSchemaTypeString, + Type: jsonschema.String, Enum: []string{"hello", "world"}, }, }, diff --git a/jsonschema/json.go b/jsonschema/json.go new file mode 100644 index 000000000..24af8584e --- /dev/null +++ b/jsonschema/json.go @@ -0,0 +1,35 @@ +// Package jsonschema provides very simple functionality for representing a JSON schema as a +// (nested) struct. This struct can be used with the chat completion "function call" feature. +// For more complicated schemas, it is recommended to use a dedicated JSON schema library +// and/or pass in the schema in []byte format. +package jsonschema + +type DataType string + +const ( + Object DataType = "object" + Number DataType = "number" + Integer DataType = "integer" + String DataType = "string" + Array DataType = "array" + Null DataType = "null" + Boolean DataType = "boolean" +) + +// Definition is a struct for describing a JSON Schema. +// It is fairly limited and you may have better luck using a third-party library. +type Definition struct { + // Type specifies the data type of the schema. + Type DataType `json:"type,omitempty"` + // Description is the description of the schema. + Description string `json:"description,omitempty"` + // Enum is used to restrict a value to a fixed set of values. It must be an array with at least + // one element, where each element is unique. You will probably only use this with strings. + Enum []string `json:"enum,omitempty"` + // Properties describes the properties of an object, if the schema type is Object. + Properties map[string]Definition `json:"properties,omitempty"` + // Required specifies which properties are required, if the schema type is Object. + Required []string `json:"required,omitempty"` + // Items specifies which data type an array contains, if the schema type is Array. 
+ Items *Definition `json:"items,omitempty"` +} From a3c0b36b35dac5168c9ef07dacb4c1ad55efc51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 26 Jun 2023 23:32:57 +0900 Subject: [PATCH 037/206] chore: add an issue template for feature request (#410) --- .github/ISSUE_TEMPLATE/feature_request.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..2359e5c00 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. From 581f70b102d7443aeea4f19cf04d570150e8cc42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 26 Jun 2023 23:33:32 +0900 Subject: [PATCH 038/206] chore: add an issue template for bug report (#408) --- .github/ISSUE_TEMPLATE/bug_report.md | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..536a2ee29 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Describe the bug** +A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s). + +**To Reproduce** +Steps to reproduce the behavior, including any relevant code snippets. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots/Logs** +If applicable, add screenshots to help explain your problem. For non-graphical issues, please provide any relevant logs or stack traces. + +**Environment (please complete the following information):** + - go-openai version: [e.g. v1.12.0] + - Go version: [e.g. 1.18] + - OpenAI API version: [e.g. v1] + - OS: [e.g. Ubuntu 20.04] + +**Additional context** +Add any other context about the problem here. 
From 86d0f48d2ddd88fed5ed4036ab32218e90d2ee4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 29 Jun 2023 02:18:34 +0900 Subject: [PATCH 039/206] chore: add a pull request template (#412) --- .../pull_request_template.md | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 000000000..b078d1964 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,27 @@ +--- +name: Pull Request +about: Propose changes to the codebase +title: '' +labels: '' +assignees: '' + +--- + +A similar PR may already be submitted! +Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. + +Thanks for submitting a pull request! Please provide enough information so that others can review your pull request. + +**Describe the change** +Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. + +**Describe your solution** +Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. + +**Tests** +Briefly describe how you have tested these changes. + +**Additional context** +Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. + +Issue: #XXXX From 9c99f3626f1d80382e187df8adc38f7e8e929a75 Mon Sep 17 00:00:00 2001 From: ryomak <21288308+ryomak@users.noreply.github.com> Date: Thu, 29 Jun 2023 09:41:22 +0900 Subject: [PATCH 040/206] replace deprecated FunctionDefine in chat_test.go (#416) --- chat_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chat_test.go b/chat_test.go index d5879e60f..5723d6ccf 100644 --- a/chat_test.go +++ b/chat_test.go @@ -85,7 +85,7 @@ func TestChatCompletionsFunctions(t *testing.T) { Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -117,7 +117,7 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) - t.Run("JSONSchemaDefine", func(t *testing.T) { + t.Run("JSONSchemaDefinition", func(t *testing.T) { _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo0613, @@ -153,7 +153,7 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) - t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) { + t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, From 1efcf2d23de7866701bce946c65f271fff2f05e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Fri, 30 Jun 2023 19:49:36 +0900 Subject: [PATCH 041/206] fix: move pull request template (#420) --- .../pull_request_template.md => PULL_REQUEST_TEMPLATE.md} | 2 ++ 1 file changed, 2 insertions(+) rename .github/{PULL_REQUEST_TEMPLATE/pull_request_template.md => 
PULL_REQUEST_TEMPLATE.md} (79%) diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE.md similarity index 79% rename from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md rename to .github/PULL_REQUEST_TEMPLATE.md index b078d1964..f7e45401b 100644 --- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -10,6 +10,8 @@ assignees: '' A similar PR may already be submitted! Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. +If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly. + Thanks for submitting a pull request! Please provide enough information so that others can review your pull request. **Describe the change** From 177c143be7c373a9b5d33e6c7c64ad3e6670a32c Mon Sep 17 00:00:00 2001 From: Rick Date: Sat, 1 Jul 2023 06:38:22 +0800 Subject: [PATCH 042/206] Fix OpenAI error when properties is empty in function call : object schema missing properties (#419) Co-authored-by: Rick --- jsonschema/json.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index 24af8584e..c02d250aa 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -27,7 +27,7 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties,omitempty"` + Properties map[string]Definition `json:"properties"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. From 204260818e9987c4a84f520d63d9c8758c986cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 3 Jul 2023 19:46:38 +0900 Subject: [PATCH 043/206] docs: remove medatada in PULL_REQUEST_TEMPLATE.md (#423) --- .github/PULL_REQUEST_TEMPLATE.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index f7e45401b..44bf697ed 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,12 +1,3 @@ ---- -name: Pull Request -about: Propose changes to the codebase -title: '' -labels: '' -assignees: '' - ---- - A similar PR may already be submitted! Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. 
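The `properties` fix in #419 above is easy to reproduce in isolation: with the `omitempty` tag, an object schema that declares no properties drops its `properties` key entirely when marshalled, which is what triggered OpenAI's "object schema missing properties" error. A minimal sketch of the difference, using hypothetical stand-ins for `jsonschema.Definition`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// withOmitempty and withoutOmitempty are illustrative stand-ins
// showing only the Properties field of a schema definition.
type withOmitempty struct {
	Properties map[string]any `json:"properties,omitempty"`
}

type withoutOmitempty struct {
	Properties map[string]any `json:"properties"`
}

func main() {
	a, _ := json.Marshal(withOmitempty{})
	fmt.Println(string(a)) // {} ... the "properties" key vanishes

	b, _ := json.Marshal(withoutOmitempty{Properties: map[string]any{}})
	fmt.Println(string(b)) // {"properties":{}}
}
```

Note that a nil `Properties` map still serializes as `"properties":null` without the tag, which is why a later patch in this series adds a custom `MarshalJSON` that substitutes an empty map.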
From 5c7d88212f6e73fdac89723d42b9e3a1b113931c Mon Sep 17 00:00:00 2001
From: Jackson Stone
Date: Wed, 5 Jul 2023 16:53:53 -0500
Subject: [PATCH 044/206] Allow embeddings requests to be tokens or strings
 (#417)

* Allow raw tokens to be used as embedding input
* fix linting issues (lines too long)
* add endpoint test for embedding from tokens
* remove redundant comments
* fix comment to match new param name
* change interface to any
* Rename methods and implement convert for base req
* add comments to CreateEmbeddings
* update tests
* shorten line length
* rename parameter
---
 embeddings.go      | 62 +++++++++++++++++++++++++++++++++++++++++-----
 embeddings_test.go | 38 ++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 6 deletions(-)

diff --git a/embeddings.go b/embeddings.go
index ba327ce77..41af50b4b 100644
--- a/embeddings.go
+++ b/embeddings.go
@@ -113,10 +113,25 @@ type EmbeddingResponse struct {
 	Usage  Usage          `json:"usage"`
 }
 
-// EmbeddingRequest is the input to a Create embeddings request.
+type EmbeddingRequestConverter interface {
+	// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
+	Convert() EmbeddingRequest
+}
+
 type EmbeddingRequest struct {
+	Input any            `json:"input"`
+	Model EmbeddingModel `json:"model"`
+	User  string         `json:"user"`
+}
+
+func (r EmbeddingRequest) Convert() EmbeddingRequest {
+	return r
+}
+
+// EmbeddingRequestStrings is the input to a create embeddings request with a slice of strings.
+type EmbeddingRequestStrings struct {
 	// Input is a slice of strings for which you want to generate an Embedding vector.
-	// Each input must not exceed 2048 tokens in length.
+	// Each input must not exceed 8192 tokens in length.
 	// OpenAI suggests replacing newlines (\n) in your input with a single space, as they
 	// have observed inferior results when newlines are present.
 	// E.g.
 	//	"The food was delicious and the waiter..."
 	Input []string `json:"input"`
 	// ID of the model to use. You can use the List models API to see all of your available models,
 	// or see our Model overview for descriptions of them.
 	Model EmbeddingModel `json:"model"`
 	// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
 	User string `json:"user"`
 }
 
+func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
+	return EmbeddingRequest{
+		Input: r.Input,
+		Model: r.Model,
+		User:  r.User,
+	}
+}
+
+type EmbeddingRequestTokens struct {
+	// Input is a slice of slices of ints ([][]int) for which you want to generate an Embedding vector.
+	// Each input must not exceed 8192 tokens in length.
+	// OpenAI suggests replacing newlines (\n) in your input with a single space, as they
+	// have observed inferior results when newlines are present.
+	// E.g.
+	//	"The food was delicious and the waiter..."
+	Input [][]int `json:"input"`
+	// ID of the model to use. You can use the List models API to see all of your available models,
+	// or see our Model overview for descriptions of them.
+	Model EmbeddingModel `json:"model"`
+	// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
+	User string `json:"user"`
+}
+
+func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
+	return EmbeddingRequest{
+		Input: r.Input,
+		Model: r.Model,
+		User:  r.User,
+	}
+}
+
+// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |body.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) +// +// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens +// for embedding groups of text already converted to tokens. +func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll + baseReq := conv.Convert() + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &resp) + err = c.sendRequest(req, &res) return } diff --git a/embeddings_test.go b/embeddings_test.go index d7892cd5d..47c4f5108 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -32,6 +32,7 @@ func TestEmbedding(t *testing.T) { BabbageCodeSearchText, } for _, model := range embeddedModels { + // test embedding request with strings (simple embedding request) embeddingReq := EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", @@ -46,6 +47,34 @@ func TestEmbedding(t *testing.T) { if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { t.Fatalf("Expected embedding request to contain model field") } + + // test embedding request with strings + embeddingReqStrings := EmbeddingRequestStrings{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqStrings) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with tokens + embeddingReqTokens := EmbeddingRequestTokens{ + Input: [][]int{ + {464, 2057, 373, 12625, 290, 262, 46612}, + {6395, 6096, 286, 11525, 12083, 2581}, + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqTokens) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } } } @@ -75,6 +104,15 @@ func TestEmbeddingEndpoint(t *testing.T) { fmt.Fprintln(w, string(resBytes)) }, ) + // test create embeddings with strings (simple embedding request) _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") + + // test create embeddings with strings + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + checks.NoError(t, err, "CreateEmbeddings strings error") + + // test create embeddings with tokens + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + checks.NoError(t, err, "CreateEmbeddings tokens error") } From 619ad717353d8b9d5f4d9049b1ce1b168c5851b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 6 Jul 2023 06:54:27 +0900 Subject: [PATCH 045/206] docs: added instructions for obtaining OpenAI API key to README (#421) * docs: added instructions for obtaining OpenAI API key to README * docs: move 'Getting an OpenAI API key' before 'Other examples' --- README.md | 11 +++++++++++ 1 file 
changed, 11 insertions(+) diff --git a/README.md b/README.md index da1a2804d..1f708af70 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,17 @@ func main() { ``` +### Getting an OpenAI API Key: + +1. Visit the OpenAI website at [https://platform.openai.com/account/api-keys](https://platform.openai.com/account/api-keys). +2. If you don't have an account, click on "Sign Up" to create one. If you do, click "Log In". +3. Once logged in, navigate to your API key management page. +4. Click on "Create new secret key". +5. Enter a name for your new key, then click "Create secret key". +6. Your new API key will be displayed. Use this key to interact with the OpenAI API. + +**Note:** Your API key is sensitive information. Do not share it with anyone. + ### Other examples:
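The converter interface introduced in #417 above means `CreateEmbeddings` now accepts either request shape. A minimal sketch of both calls (the token IDs are arbitrary examples, and a valid API key is assumed):

```go
package main

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your token")
	ctx := context.Background()

	// Plain strings are tokenized by the API itself.
	strResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{
		Input: []string{"The food was delicious and the waiter..."},
		Model: openai.AdaEmbeddingV2,
	})
	if err != nil {
		fmt.Printf("CreateEmbeddings (strings) error: %v\n", err)
		return
	}

	// Inputs that are already tokenized can be passed as token IDs directly.
	tokResp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestTokens{
		Input: [][]int{{464, 2057, 373, 12625}},
		Model: openai.AdaEmbeddingV2,
	})
	if err != nil {
		fmt.Printf("CreateEmbeddings (tokens) error: %v\n", err)
		return
	}

	fmt.Println(len(strResp.Data), len(tokResp.Data))
}
```

Both request types satisfy `EmbeddingRequestConverter`, so the client reduces them to the same underlying request before sending.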
From 7b22898f5d3fd86232057ed61e83adb47bf24cb0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Sun, 9 Jul 2023 17:09:50 +0800 Subject: [PATCH 046/206] Implement OpenAI July 2023 Updates (#427) * Implement OpenAI July 2023 Updates * fix: golangci-lint * add comment * fix: remove some model Deprecated --- completion.go | 55 +++++++++++++++++++++++++++++++-------------------- edits.go | 6 +++++- embeddings.go | 16 +++++++++++++++ 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/completion.go b/completion.go index b3b3abd1c..61bfed654 100644 --- a/completion.go +++ b/completion.go @@ -17,29 +17,42 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4 = "gpt-4" - GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3TextDavinci003 = "text-davinci-003" - GPT3TextDavinci002 = "text-davinci-002" - GPT3TextCurie001 = "text-curie-001" - GPT3TextBabbage001 = "text-babbage-001" - GPT3TextAda001 = "text-ada-001" - GPT3TextDavinci001 = "text-davinci-001" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4 = "gpt-4" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci003 = "text-davinci-003" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci002 = "text-davinci-002" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextCurie001 = "text-curie-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextBabbage001 = "text-babbage-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextAda001 = "text-ada-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci001 = "text-davinci-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" GPT3Davinci = "davinci" - GPT3CurieInstructBeta = "curie-instruct-beta" - GPT3Curie = "curie" - GPT3Ada = "ada" - GPT3Babbage = "babbage" + GPT3Davinci002 = "davinci-002" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3CurieInstructBeta = "curie-instruct-beta" + GPT3Curie = "curie" + GPT3Curie002 = "curie-002" + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/edits.go b/edits.go index 3d3fc8950..831aade2f 100644 --- a/edits.go +++ b/edits.go @@ -30,7 +30,11 @@ type EditsResponse struct { Choices []EditsChoice `json:"choices"` } -// Perform an API call to the Edits endpoint. +// Edits Perform an API call to the Edits endpoint. 
+/* Deprecated: Users of the Edits API and its associated models (e.g., text-davinci-edit-001 or code-davinci-edit-001)
+will need to migrate to GPT-3.5 Turbo by January 4, 2024.
+You can use CreateChatCompletion or CreateChatCompletionStream instead.
+*/
 func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
 	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
 	if err != nil {
diff --git a/embeddings.go b/embeddings.go
index 41af50b4b..1d3199597 100644
--- a/embeddings.go
+++ b/embeddings.go
@@ -34,21 +34,37 @@ func (e *EmbeddingModel) UnmarshalText(b []byte) error {
 
 const (
 	Unknown EmbeddingModel = iota
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	AdaSimilarity
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	BabbageSimilarity
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	CurieSimilarity
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	DavinciSimilarity
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	AdaSearchDocument
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	AdaSearchQuery
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	BabbageSearchDocument
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	BabbageSearchQuery
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	CurieSearchDocument
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	CurieSearchQuery
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	DavinciSearchDocument
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	DavinciSearchQuery
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	AdaCodeSearchCode
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	AdaCodeSearchText
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	BabbageCodeSearchCode
+	// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
 	BabbageCodeSearchText
 	AdaEmbeddingV2
 )

From 181fc2ade904c7d6a0910cfebdd6c90e7a4d80ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?=
 =?UTF-8?q?be?=
Date: Sun, 9 Jul 2023 18:11:39 +0900
Subject: [PATCH 047/206] docs: explanation about LogitBias. (129) (#426)

---
 chat.go       | 11 +++++++----
 completion.go | 35 +++++++++++++++++++----------------
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/chat.go b/chat.go
index e4f23df07..17d7cd574 100644
--- a/chat.go
+++ b/chat.go
@@ -52,10 +52,13 @@ type ChatCompletionRequest struct {
 	Stop             []string       `json:"stop,omitempty"`
 	PresencePenalty  float32        `json:"presence_penalty,omitempty"`
 	FrequencyPenalty float32        `json:"frequency_penalty,omitempty"`
-	LogitBias        map[string]int `json:"logit_bias,omitempty"`
-	User             string         `json:"user,omitempty"`
-	Functions        []FunctionDefinition `json:"functions,omitempty"`
-	FunctionCall     any            `json:"function_call,omitempty"`
+	// LogitBias must be keyed by token ID strings (as specified by the tokenizer), not by word strings.
+	// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
+	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
+	LogitBias map[string]int `json:"logit_bias,omitempty"`
+	User      string         `json:"user,omitempty"`
+	Functions []FunctionDefinition `json:"functions,omitempty"`
+	FunctionCall any `json:"function_call,omitempty"`
 }
 
 type FunctionDefinition struct {
diff --git a/completion.go b/completion.go
index 61bfed654..7b9ae89e7 100644
--- a/completion.go
+++ b/completion.go
@@ -109,22 +109,25 @@ func checkPromptType(prompt any) bool {
 
 // CompletionRequest represents a request structure for completion API.
 type CompletionRequest struct {
-	Model            string         `json:"model"`
-	Prompt           any            `json:"prompt,omitempty"`
-	Suffix           string         `json:"suffix,omitempty"`
-	MaxTokens        int            `json:"max_tokens,omitempty"`
-	Temperature      float32        `json:"temperature,omitempty"`
-	TopP             float32        `json:"top_p,omitempty"`
-	N                int            `json:"n,omitempty"`
-	Stream           bool           `json:"stream,omitempty"`
-	LogProbs         int            `json:"logprobs,omitempty"`
-	Echo             bool           `json:"echo,omitempty"`
-	Stop             []string       `json:"stop,omitempty"`
-	PresencePenalty  float32        `json:"presence_penalty,omitempty"`
-	FrequencyPenalty float32        `json:"frequency_penalty,omitempty"`
-	BestOf           int            `json:"best_of,omitempty"`
-	LogitBias        map[string]int `json:"logit_bias,omitempty"`
-	User             string         `json:"user,omitempty"`
+	Model            string   `json:"model"`
+	Prompt           any      `json:"prompt,omitempty"`
+	Suffix           string   `json:"suffix,omitempty"`
+	MaxTokens        int      `json:"max_tokens,omitempty"`
+	Temperature      float32  `json:"temperature,omitempty"`
+	TopP             float32  `json:"top_p,omitempty"`
+	N                int      `json:"n,omitempty"`
+	Stream           bool     `json:"stream,omitempty"`
+	LogProbs         int      `json:"logprobs,omitempty"`
+	Echo             bool     `json:"echo,omitempty"`
+	Stop             []string `json:"stop,omitempty"`
+	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
+	BestOf           int      `json:"best_of,omitempty"`
+	// LogitBias must be keyed by token ID strings (as specified by the tokenizer), not by word strings.
+	// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
+	// refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias
+	LogitBias map[string]int `json:"logit_bias,omitempty"`
+	User      string         `json:"user,omitempty"`
 }
 
 // CompletionChoice represents one of possible completions.

From f028c289d2e2ae7562d97594d122447fd23a632d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?=
 =?UTF-8?q?be?=
Date: Mon, 10 Jul 2023 02:07:01 +0900
Subject: [PATCH 048/206] fix: function call error due to nil properties (429)
 (#431)

* fix: fix function call error due to nil properties (429)
* refactor: refactoring initializeProperties func in jsonschema pkg (429)
---
 jsonschema/json.go      |  25 ++++-
 jsonschema/json_test.go | 201 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 225 insertions(+), 1 deletion(-)
 create mode 100644 jsonschema/json_test.go

diff --git a/jsonschema/json.go b/jsonschema/json.go
index c02d250aa..e4eef98e7 100644
--- a/jsonschema/json.go
+++ b/jsonschema/json.go
@@ -4,6 +4,8 @@
 // and/or pass in the schema in []byte format.
 package jsonschema
 
+import "encoding/json"
+
 type DataType string
 
 const (
@@ -17,7 +19,7 @@ const (
 )
 
 // Definition is a struct for describing a JSON Schema.
-// It is fairly limited and you may have better luck using a third-party library.
+// It is fairly limited, and you may have better luck using a third-party library. type Definition struct { // Type specifies the data type of the schema. Type DataType `json:"type,omitempty"` @@ -33,3 +35,24 @@ type Definition struct { // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` } + +func (d *Definition) MarshalJSON() ([]byte, error) { + d.initializeProperties() + return json.Marshal(*d) +} + +func (d *Definition) initializeProperties() { + if d.Properties == nil { + d.Properties = make(map[string]Definition) + return + } + + for k, v := range d.Properties { + if v.Properties == nil { + v.Properties = make(map[string]Definition) + } else { + v.initializeProperties() + } + d.Properties[k] = v + } +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go new file mode 100644 index 000000000..0dc31a58a --- /dev/null +++ b/jsonschema/json_test.go @@ -0,0 +1,201 @@ +package jsonschema_test + +import ( + "encoding/json" + "reflect" + "testing" + + . "github.com/sashabaranov/go-openai/jsonschema" +) + +func TestDefinition_MarshalJSON(t *testing.T) { + tests := []struct { + name string + def Definition + want string + }{ + { + name: "Test with empty Definition", + def: Definition{}, + want: `{"properties":{}}`, + }, + { + name: "Test with Definition properties set", + def: Definition{ + Type: String, + Description: "A string type", + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + { + name: "Test with nested Definition properties", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + } + } + } + } +}`, + }, + { + name: "Test with complex nested Definition", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + "address": { + Type: Object, + Properties: map[string]Definition{ + "city": { + Type: String, + }, + "country": { + Type: String, + }, + }, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string", + "properties":{} + }, + "country":{ + "type":"string", + "properties":{} + } + } + } + } + } + } +}`, + }, + { + name: "Test with Array type Definition", + def: Definition{ + Type: Array, + Items: &Definition{ + Type: String, + }, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"array", + "items":{ + "type":"string", + "properties":{ + + } + }, + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBytes, err := json.Marshal(&tt.def) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return + } + + var 
got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + wantBytes := []byte(tt.want) + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + }) + } +} From c3b2451f7c7dc477d98e1baa10993ac55392c7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Tue, 11 Jul 2023 20:48:15 +0900 Subject: [PATCH 049/206] fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) (#434) * fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) * test: add integration test for function call (#429)(#432) * style: remove duplicate import (#429)(#432) --- api_integration_test.go | 32 ++++++++++++++++++++++++++++++++ jsonschema/json.go | 23 +++++++---------------- jsonschema/json_test.go | 38 ++++++++++++++++++++++++-------------- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index d4e7328a2..254fbeb03 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -11,6 +11,7 @@ import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestAPI(t *testing.T) { @@ -100,6 +101,37 @@ func TestAPI(t *testing.T) { if counter == 0 { t.Error("Stream did not return any responses") } + + _, err = c.CreateChatCompletion( + context.Background(), + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "What is the weather like in Boston?", + }, + }, + Functions: []FunctionDefinition{{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }}, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } func TestAPIError(t *testing.T) { diff --git a/jsonschema/json.go b/jsonschema/json.go index e4eef98e7..cb941eb75 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -36,23 +36,14 @@ type Definition struct { Items *Definition `json:"items,omitempty"` } -func (d *Definition) MarshalJSON() ([]byte, error) { - d.initializeProperties() - return json.Marshal(*d) -} - -func (d *Definition) initializeProperties() { +func (d Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) - return - } - - for k, v := range d.Properties { - if v.Properties == nil { - v.Properties = make(map[string]Definition) - } else { - v.initializeProperties() - } - d.Properties[k] = v } + type Alias Definition + return json.Marshal(struct { + Alias + }{ + Alias: (Alias)(d), + }) } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 0dc31a58a..c8d0c1d9e 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -172,30 +172,40 @@ func TestDefinition_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotBytes, err := json.Marshal(&tt.def) - if err != nil { - t.Errorf("Failed to Marshal JSON: error = %v", err) - return - } - - var got map[string]interface{} - err = json.Unmarshal(gotBytes, &got) - if err != nil { - t.Errorf("Failed to Unmarshal JSON: error = %v", err) - return - } - wantBytes := []byte(tt.want) var want map[string]interface{} - err = json.Unmarshal(wantBytes, &want) + err := json.Unmarshal(wantBytes, &want) if err != nil { t.Errorf("Failed to Unmarshal JSON: error = %v", err) return } + got := structToMap(t, tt.def) + gotPtr := structToMap(t, &tt.def) + if !reflect.DeepEqual(got, want) { t.Errorf("MarshalJSON() got = %v, want %v", got, want) } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } }) } } + +func structToMap(t *testing.T, v any) map[string]any { + t.Helper() + gotBytes, err := json.Marshal(v) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return nil + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return nil + } + return got +} From 39b2acb5c93c3ee12020cda8d1a5cc0cf2bea1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 12 Jul 2023 23:15:39 +0900 Subject: [PATCH 050/206] ci: set up closing-inactive-issues in GitHub Action (129) (#428) --- .github/workflows/close-inactive-issues.yml | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/close-inactive-issues.yml diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml new file mode 100644 index 000000000..bfe9b5c96 --- /dev/null +++ b/.github/workflows/close-inactive-issues.yml @@ -0,0 +1,23 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v5 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + 
exempt-issue-labels: 'bug,enhancement'
+          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
+          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
+          days-before-pr-stale: -1
+          days-before-pr-close: -1
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file

From e22a29d84ebb8c5c911937669f27ac3265f3c982 Mon Sep 17 00:00:00 2001
From: Munar <118156704+MunaerYesiyan@users.noreply.github.com>
Date: Thu, 13 Jul 2023 13:30:58 +0900
Subject: [PATCH 051/206] Check if the model param is valid for moderations
 endpoint (#437)

* chore: check for models before sending moderation requests to openai endpoint
* chore: table driven tests to include more model cases for moderations endpoint
---
 moderation.go      | 17 ++++++++++++++++-
 moderation_test.go | 35 +++++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/moderation.go b/moderation.go
index a58d759c0..a32f123f3 100644
--- a/moderation.go
+++ b/moderation.go
@@ -2,6 +2,7 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"net/http"
 )
 
@@ -15,9 +16,19 @@ import (
 const (
 	ModerationTextStable = "text-moderation-stable"
 	ModerationTextLatest = "text-moderation-latest"
-	ModerationText001    = "text-moderation-001"
+	// Deprecated: use ModerationTextStable and ModerationTextLatest instead.
+	ModerationText001 = "text-moderation-001"
 )
 
+var (
+	ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll
+)
+
+var validModerationModel = map[string]struct{}{
+	ModerationTextStable: {},
+	ModerationTextLatest: {},
+}
+
 // ModerationRequest represents a request structure for moderation API.
 type ModerationRequest struct {
 	Input string `json:"input,omitempty"`
@@ -63,6 +74,10 @@ type ModerationResponse struct {
 // Moderations — perform a moderation api call over a string.
 // Input can be an array or slice but a string will reduce the complexity.
 func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
+	if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok {
+		err = ErrModerationInvalidModel
+		return
+	}
 	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
 	if err != nil {
 		return
diff --git a/moderation_test.go b/moderation_test.go
index 4e756137e..68f9565e1 100644
--- a/moderation_test.go
+++ b/moderation_test.go
@@ -27,6 +27,41 @@ func TestModerations(t *testing.T) {
 	checks.NoError(t, err, "Moderation error")
 }
 
+// TestModerationsWithDifferentModelOptions Tests passing valid and invalid models to the moderations endpoint.
+func TestModerationsWithDifferentModelOptions(t *testing.T) { + var modelOptions []struct { + model string + expect error + } + modelOptions = append(modelOptions, + getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), + getModerationModelTestOption(ModerationTextStable, nil), + getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption("", nil), + ) + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/moderations", handleModerationEndpoint) + for _, modelTest := range modelOptions { + _, err := client.Moderations(context.Background(), ModerationRequest{ + Model: modelTest.model, + Input: "I want to kill them.", + }) + checks.ErrorIs(t, err, modelTest.expect, + fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err)) + } +} + +func getModerationModelTestOption(model string, expect error) struct { + model string + expect error +} { + return struct { + model string + expect error + }{model: model, expect: expect} +} + // handleModerationEndpoint Handles the moderation endpoint by the test server. func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { var err error From 0234c1e0c2769c9599f0799259fb8db5c4e3e011 Mon Sep 17 00:00:00 2001 From: Mehul Gohil Date: Sat, 15 Jul 2023 03:43:05 +0530 Subject: [PATCH 052/206] add example: fine tune (#438) * add example for fine tune * update example for fine tune * fix comments --- README.md | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/README.md b/README.md index 1f708af70..19aadde2a 100644 --- a/README.md +++ b/README.md @@ -611,6 +611,73 @@ if errors.As(err, &e) { ```
+
+Fine Tune Model + +```go +package main + +import ( + "context" + "fmt" + "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + // create a .jsonl file with your training data + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + + // you can use openai cli tool to validate the data + // For more info - https://platform.openai.com/docs/guides/fine-tuning + + file, err := client.CreateFile(ctx, openai.FileRequest{ + FilePath: "training_prepared.jsonl", + Purpose: "fine-tune", + }) + if err != nil { + fmt.Printf("Upload JSONL file error: %v\n", err) + return + } + + // create a fine tune job + // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) + // use below get method to know the status of your model + tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + TrainingFile: file.ID, + Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + }) + if err != nil { + fmt.Printf("Creating new fine tune model error: %v\n", err) + return + } + + getTune, err := client.GetFineTune(ctx, tune.ID) + if err != nil { + fmt.Printf("Getting fine tune model error: %v\n", err) + return + } + fmt.Println(getTune.FineTunedModel) + + // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + + // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ + // Model: getTune.FineTunedModel, + // Prompt: "your prompt", + // }) + // if err != nil { + // fmt.Printf("Create completion error %v\n", err) + // return + // } + // + // fmt.Println(resp.Choices[0].Text) +} +``` +
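+Rather than calling `GetFineTune` once, one option is to poll until the job settles. A rough sketch that could follow the code above (it assumes the `Status` field reported by `GetFineTune`, uses terminal status strings from the fine-tuning docs, and needs the `time` package imported):
+
+```go
+for {
+    ft, err := client.GetFineTune(ctx, tune.ID)
+    if err != nil {
+        fmt.Printf("Getting fine tune model error: %v\n", err)
+        return
+    }
+    // Stop polling once the job reaches a terminal state.
+    if ft.Status == "succeeded" || ft.Status == "failed" || ft.Status == "cancelled" {
+        fmt.Println(ft.Status, ft.FineTunedModel)
+        break
+    }
+    time.Sleep(30 * time.Second)
+}
+```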
See the `examples/` folder for more. ### Integration tests: From 1876e0c20716afc4d012688bc393ccd5f28def79 Mon Sep 17 00:00:00 2001 From: Savannah Ostrowski Date: Fri, 14 Jul 2023 21:33:55 -0700 Subject: [PATCH 053/206] update to json.RawMessage (#441) --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 17d7cd574..7a6438e7f 100644 --- a/chat.go +++ b/chat.go @@ -65,7 +65,7 @@ type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` // Parameters is an object describing the function. - // You can pass a []byte describing the schema, + // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. // The jsonschema package is provided for convenience, but you should // consider another specialized library if you require more complex schemas. From 1153eb2595d1529927757dd6df4de71faaafde02 Mon Sep 17 00:00:00 2001 From: ZeroDeng Date: Fri, 21 Jul 2023 00:25:58 +0800 Subject: [PATCH 054/206] Add support for azure openai new version API (2023-07-01-preview) (#451) --- chat.go | 29 +++++++++++++++++++++++++++++ chat_stream.go | 18 ++++++++++-------- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/chat.go b/chat.go index 7a6438e7f..514aaee75 100644 --- a/chat.go +++ b/chat.go @@ -21,6 +21,35 @@ var ( ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll ) +type Hate struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type SelfHarm struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Sexual struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Violence struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} + +type ContentFilterResults struct { + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` +} + +type PromptAnnotation struct { + PromptIndex int `json:"prompt_index,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` diff --git a/chat_stream.go b/chat_stream.go index 9f4e80cff..f1faa3964 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -12,17 +12,19 @@ type ChatCompletionStreamChoiceDelta struct { } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` } // ChatCompletionStream From 
62dc817b395d16fb0e65be490d49e294ac8c40b0 Mon Sep 17 00:00:00 2001 From: Yu <1095780+yuikns@users.noreply.github.com> Date: Fri, 28 Jul 2023 12:06:48 +0800 Subject: [PATCH 055/206] feat: make finish reason nullable in json marshal (#449) --- chat.go | 7 +++++++ chat_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/chat.go b/chat.go index 514aaee75..8d29b3237 100644 --- a/chat.go +++ b/chat.go @@ -114,6 +114,13 @@ const ( FinishReasonNull FinishReason = "null" ) +func (r FinishReason) MarshalJSON() ([]byte, error) { + if r == FinishReasonNull || r == "" { + return []byte("null"), nil + } + return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes +} + type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` diff --git a/chat_test.go b/chat_test.go index 5723d6ccf..38d66fa64 100644 --- a/chat_test.go +++ b/chat_test.go @@ -298,3 +298,34 @@ func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { } return completion, nil } + +func TestFinishReason(t *testing.T) { + c := &ChatCompletionChoice{ + FinishReason: FinishReasonNull, + } + resBytes, _ := json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + c.FinishReason = "" + + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + otherReasons := []FinishReason{ + FinishReasonStop, + FinishReasonLength, + FinishReasonFunctionCall, + FinishReasonContentFilter, + } + for _, r := range otherReasons { + c.FinishReason = r + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) { + t.Errorf("%s should be quoted", r) + } + } +} From 71a24931dbc5b7029901ff963dc4d0d2509aa7ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 31 Jul 2023 04:58:49 +0900 Subject: [PATCH 056/206] docs: add Frequently Asked Questions to README.md (#462) * docs: add Frequently Asked Questions to README.md * Update README.md Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- README.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/README.md b/README.md index 19aadde2a..d627a19ce 100644 --- a/README.md +++ b/README.md @@ -694,6 +694,37 @@ OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. +## Frequently Asked Questions + +### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? + +Even when specifying a temperature field of 0, it doesn't guarantee that you'll always get the same response. Several factors come into play. + +1. Go OpenAI Behavior: When you specify a temperature field of 0 in Go OpenAI, the omitempty tag causes that field to be removed from the request. Consequently, the OpenAI API applies the default value of 1. +2. Token Count for Input/Output: If there's a large number of tokens in the input and output, setting the temperature to 0 can still result in non-deterministic behavior. 
In particular, when using around 32k tokens, the likelihood of non-deterministic behavior becomes highest even with a temperature of 0. + +Due to the factors mentioned above, different answers may be returned even for the same question. + +**Workarounds:** +1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. + +By adopting these strategies, you can expect more consistent results. + +**Related Issues:** +[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) + +### Does Go OpenAI provide a method to count tokens? + +No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository. + +For counting tokens, you might find the following links helpful: +- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls) +- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) + +**Related Issues:** +[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) + ## Thank you We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: From 34569895f6a0ab4a3ccc497573a8343ee33dc3b1 Mon Sep 17 00:00:00 2001 From: ZeroDeng Date: Wed, 9 Aug 2023 12:05:39 +0800 Subject: [PATCH 057/206] Compatible with the 2023-07-01-preview API interface of Azure Openai, when content interception is triggered, the error message will contain innererror (#460) * Compatible with Azure Openai's 2023-07-01-preview version API interface about the error information returned by the intercepted interface * Compatible with the 2023-07-01-preview API interface of Azure Openai, when content interception is triggered, the error message will contain innererror.InnerError struct is only valid for Azure OpenAI Service. --- error.go | 25 +++++++++++++---- error_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/error.go b/error.go index f68e92875..b2d01e22e 100644 --- a/error.go +++ b/error.go @@ -7,12 +7,20 @@ import ( ) // APIError provides error information returned by the OpenAI API. +// InnerError struct is only valid for Azure OpenAI Service. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - HTTPStatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatusCode int `json:"-"` + InnerError *InnerError `json:"innererror,omitempty"` +} + +// InnerError Azure Content filtering. Only valid for Azure OpenAI Service. 
+type InnerError struct {
+	Code                 string               `json:"code,omitempty"`
+	ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"`
 }
 
 // RequestError provides information about generic request errors.
@@ -61,6 +69,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
 		}
 	}
 
+	if _, ok := rawMap["innererror"]; ok {
+		err = json.Unmarshal(rawMap["innererror"], &e.InnerError)
+		if err != nil {
+			return
+		}
+	}
+
 	// optional fields
 	if _, ok := rawMap["param"]; ok {
 		err = json.Unmarshal(rawMap["param"], &e.Param)

diff --git a/error_test.go b/error_test.go
index e2309abd7..a0806b7ed 100644
--- a/error_test.go
+++ b/error_test.go
@@ -3,6 +3,7 @@ package openai_test
 import (
 	"errors"
 	"net/http"
+	"reflect"
 	"testing"
 
 	. "github.com/sashabaranov/go-openai"
 )
@@ -57,6 +58,77 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
 				assertAPIErrorMessage(t, apiErr, "")
 			},
 		},
+		{
+			name: "parse succeeds when the innerError exists (Azure OpenAI)",
+			response: `{
+				"message": "test message",
+				"type": null,
+				"param": "prompt",
+				"code": "content_filter",
+				"status": 400,
+				"innererror": {
+					"code": "ResponsibleAIPolicyViolation",
+					"content_filter_result": {
+						"hate": {
+							"filtered": false,
+							"severity": "safe"
+						},
+						"self_harm": {
+							"filtered": false,
+							"severity": "safe"
+						},
+						"sexual": {
+							"filtered": true,
+							"severity": "medium"
+						},
+						"violence": {
+							"filtered": false,
+							"severity": "safe"
+						}
+					}
+				}
+			}`,
+			hasError: false,
+			checkFunc: func(t *testing.T, apiErr APIError) {
+				assertAPIErrorInnerError(t, apiErr, &InnerError{
+					Code: "ResponsibleAIPolicyViolation",
+					ContentFilterResults: ContentFilterResults{
+						Hate: Hate{
+							Filtered: false,
+							Severity: "safe",
+						},
+						SelfHarm: SelfHarm{
+							Filtered: false,
+							Severity: "safe",
+						},
+						Sexual: Sexual{
+							Filtered: true,
+							Severity: "medium",
+						},
+						Violence: Violence{
+							Filtered: false,
+							Severity: "safe",
+						},
+					},
+				})
+			},
+		},
+		{
+			name:     "parse succeeds when the innerError is empty (Azure OpenAI)",
+			response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`,
+			hasError: false,
+			checkFunc: func(t *testing.T, apiErr APIError) {
+				assertAPIErrorInnerError(t, apiErr, &InnerError{})
+			},
+		},
+		{
+			name:     "parse failed when the innerError is not an InnerError struct (Azure OpenAI)",
+			response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`,
+			hasError: true,
+			checkFunc: func(t *testing.T, apiErr APIError) {
+				assertAPIErrorInnerError(t, apiErr, &InnerError{})
+			},
+		},
 		{
 			name:     "parse failed when the message is object",
 			response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`,
@@ -152,6 +224,12 @@ func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) {
 	}
 }
 
+func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) {
+	if !reflect.DeepEqual(apiErr.InnerError, expected) {
+		t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected)
+	}
+}
+
 func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) {
 	switch v := apiErr.Code.(type) {
 	case int:

From a14bc103f4bc2b3ac40c844079fdf59dfdf62b0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?=
 =?UTF-8?q?be?=
Date: Wed, 9 Aug 2023 13:07:14 +0900
Subject: [PATCH 058/206] docs: Add Contributing Guidelines (#463)

---
 CONTRIBUTING.md | 88 +++++++++++++++++++++++++++++++++++++++++++++++++
 README.md       | 24 +++++---------
 2 files changed, 97 insertions(+), 15 deletions(-)
 create mode 100644 CONTRIBUTING.md

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..4dd184042
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,88 @@
+# Contributing Guidelines
+
+## Overview
+Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
+
+## Reporting Bugs
+If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
+
+## Suggesting Features
+If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
+
+## Reporting Vulnerabilities
+If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable by repository maintainers. You will be credited if the advisory is published.
+
+## Questions for Users
+If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions).
+
+## Contributing Code
+There might already be a similar pull request submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one.
+
+### Requirements for Merging a Pull Request
+
+The requirements to accept a pull request are as follows:
+
+- Features not provided by the OpenAI API will not be accepted.
+- The functionality of the feature must match that of the official OpenAI API.
+- All pull requests should be written in Go according to common conventions, formatted with `goimports`, and free of warnings from tools like `golangci-lint`.
+- Include tests and ensure all tests pass.
+- Maintain test coverage without any reduction.
+- All pull requests require approval from at least one Go OpenAI maintainer.
+
+**Note:**
+The merging method for pull requests in this repository is squash merge.
+
+### Creating a Pull Request
+- Fork the repository.
+- Create a new branch and commit your changes.
+- Push that branch to GitHub.
+- Start a new Pull Request on GitHub. (Please use the pull request template to provide detailed information.)
+
+**Note:**
+If your changes introduce breaking changes, please prefix your pull request title with "[BREAKING_CHANGES]".
+
+### Code Style
+In this project, we adhere to the standard coding style of Go. Your code should maintain consistency with the rest of the codebase. To achieve this, please format your code using tools like `goimports` and resolve any syntax or style issues with `golangci-lint`.
+
+**Run goimports:**
+```
+go install golang.org/x/tools/cmd/goimports@latest
+```
+
+```
+goimports -w .
+``` + +**Run golangci-lint:** +``` +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +``` +golangci-lint run --out-format=github-actions +``` + +### Unit Test +Please create or update tests relevant to your changes. Ensure all tests run successfully to verify that your modifications do not adversely affect other functionalities. + +**Run test:** +``` +go test -v ./... +``` + +### Integration Test +Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. + +**Notes:** +These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. + +**Run integration test:** +``` +OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go +``` + +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. + +--- + +We wholeheartedly welcome your active participation. Let's build an amazing project together! diff --git a/README.md b/README.md index d627a19ce..9714d89fe 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,16 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * DALL·E 2 * Whisper -### Installation: +## Installation + ``` go get github.com/sashabaranov/go-openai ``` Currently, go-openai requires Go version 1.18 or greater. + +## Usage + ### ChatGPT example usage: ```go @@ -680,20 +684,6 @@ func main() {
See the `examples/` folder for more. -### Integration tests: - -Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. - -**Notes:** -These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. - -**Run tests using:** -``` -OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go -``` - -If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. - ## Frequently Asked Questions ### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? @@ -725,6 +715,10 @@ For counting tokens, you might find the following links helpful: **Related Issues:** [Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) +## Contributing + +By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently. + ## Thank you We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: From a2ca01bb6dae1a7d58860a5b2d5d5273667e089e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 29 Aug 2023 14:04:27 +0200 Subject: [PATCH 059/206] feat: implement new fine tuning job API (#479) * feat: implement new fine tuning job API * fix: export ListFineTuningJobEventsParameter * fix: lint errors * fix: test errors * fix: code test coverage * fix: code test coverage * fix: use any * chore: use url.Values --- client_test.go | 12 ++++ fine_tuning_job.go | 153 ++++++++++++++++++++++++++++++++++++++++ fine_tuning_job_test.go | 90 +++++++++++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 fine_tuning_job.go create mode 100644 fine_tuning_job_test.go diff --git a/client_test.go b/client_test.go index 29d84edfa..9b5046899 100644 --- a/client_test.go +++ b/client_test.go @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 000000000..a840b7ec3 --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,153 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + 
TrainingFile    string          `json:"training_file"`
+	ValidationFile  string          `json:"validation_file,omitempty"`
+	ResultFiles     []string        `json:"result_files"`
+	TrainedTokens   int             `json:"trained_tokens"`
+}
+
+type Hyperparameters struct {
+	Epochs int `json:"n_epochs"`
+}
+
+type FineTuningJobRequest struct {
+	TrainingFile    string           `json:"training_file"`
+	ValidationFile  string           `json:"validation_file,omitempty"`
+	Model           string           `json:"model,omitempty"`
+	Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
+	Suffix          string           `json:"suffix,omitempty"`
+}
+
+type FineTuningJobEventList struct {
+	Object  string          `json:"object"`
+	Data    []FineTuneEvent `json:"data"`
+	HasMore bool            `json:"has_more"`
+}
+
+type FineTuningJobEvent struct {
+	Object    string `json:"object"`
+	ID        string `json:"id"`
+	CreatedAt int    `json:"created_at"`
+	Level     string `json:"level"`
+	Message   string `json:"message"`
+	Data      any    `json:"data"`
+	Type      string `json:"type"`
+}
+
+// CreateFineTuningJob creates a fine tuning job.
+func (c *Client) CreateFineTuningJob(
+	ctx context.Context,
+	request FineTuningJobRequest,
+) (response FineTuningJob, err error) {
+	urlSuffix := "/fine_tuning/jobs"
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// CancelFineTuningJob cancels a fine tuning job.
+func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// RetrieveFineTuningJob retrieves a fine tuning job.
+func (c *Client) RetrieveFineTuningJob(
+	ctx context.Context,
+	fineTuningJobID string,
+) (response FineTuningJob, err error) {
+	urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
+	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+type listFineTuningJobEventsParameters struct {
+	after *string
+	limit *int
+}
+
+type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)
+
+func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
+	return func(args *listFineTuningJobEventsParameters) {
+		args.after = &after
+	}
+}
+
+func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
+	return func(args *listFineTuningJobEventsParameters) {
+		args.limit = &limit
+	}
+}
+
+// ListFineTuningJobEvents lists fine tuning job events.
+func (c *Client) ListFineTuningJobEvents(
+	ctx context.Context,
+	fineTuningJobID string,
+	setters ...ListFineTuningJobEventsParameter,
+) (response FineTuningJobEventList, err error) {
+	parameters := &listFineTuningJobEventsParameters{
+		after: nil,
+		limit: nil,
+	}
+
+	for _, setter := range setters {
+		setter(parameters)
+	}
+
+	urlValues := url.Values{}
+	if parameters.after != nil {
+		urlValues.Add("after", *parameters.after)
+	}
+	if parameters.limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	req, err := c.newRequest(
+		ctx,
+		http.MethodGet,
+		c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
+	)
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}

diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go
new file mode 100644
index 000000000..519c6cd2d
--- /dev/null
+++ b/fine_tuning_job_test.go
@@ -0,0 +1,90 @@
+package openai_test
+
+import (
+	"context"
+
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
+
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"testing"
+)
+
+const testFineTuninigJobID = "fine-tuning-job-id"
+
+// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server.
+func TestFineTuningJob(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs",
+		func(w http.ResponseWriter, r *http.Request) {
+			var resBytes []byte
+			resBytes, _ = json.Marshal(FineTuningJob{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
+		func(w http.ResponseWriter, r *http.Request) {
+			resBytes, _ := json.Marshal(FineTuningJob{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
+		func(w http.ResponseWriter, r *http.Request) {
+			var resBytes []byte
+			resBytes, _ = json.Marshal(FineTuningJob{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	server.RegisterHandler(
+		"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
+		func(w http.ResponseWriter, r *http.Request) {
+			resBytes, _ := json.Marshal(FineTuningJobEventList{})
+			fmt.Fprintln(w, string(resBytes))
+		},
+	)
+
+	ctx := context.Background()
+
+	_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
+	checks.NoError(t, err, "CreateFineTuningJob error")
+
+	_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
+	checks.NoError(t, err, "CancelFineTuningJob error")
+
+	_, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID)
+	checks.NoError(t, err, "RetrieveFineTuningJob error")
+
+	_, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuninigJobID,
+		ListFineTuningJobEventsWithAfter("last-event-id"),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuninigJobID,
+		ListFineTuningJobEventsWithLimit(10),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+
+	_, err = client.ListFineTuningJobEvents(
+		ctx,
+		testFineTuninigJobID,
+		ListFineTuningJobEventsWithAfter("last-event-id"),
+		ListFineTuningJobEventsWithLimit(10),
+	)
+	checks.NoError(t, err, "ListFineTuningJobEvents error")
+}

From 25da859c189c62c2454717fb2214da079017ff8e Mon Sep 17 00:00:00 2001
From: Simone Vellei
Date: Thu, 31 Aug 2023 12:14:39 +0200
Subject: [PATCH 060/206] Chore Deprecate legacy fine tunes API (#484)

* chore: add deprecation message

* chore: use new fine tuning API in README example

---
 README.md     | 21 +++++++++++++--------
 fine_tunes.go | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 9714d89fe..440c40968 100644
--- a/README.md
+++ b/README.md
@@ -631,11 +631,16 @@ func main() {
 	client := openai.NewClient("your token")
 	ctx := 
context.Background() - // create a .jsonl file with your training data + // create a .jsonl file with your training data for conversational model // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} + // chat models are trained using the following file format: + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} + // you can use openai cli tool to validate the data // For more info - https://platform.openai.com/docs/guides/fine-tuning @@ -648,29 +653,29 @@ func main() { return } - // create a fine tune job + // create a fine tuning job // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) // use below get method to know the status of your model - tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + fineTuningJob, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{ TrainingFile: file.ID, - Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + Model: "davinci-002", // gpt-3.5-turbo-0613, babbage-002. }) if err != nil { fmt.Printf("Creating new fine tune model error: %v\n", err) return } - getTune, err := client.GetFineTune(ctx, tune.ID) + fineTuningJob, err = client.RetrieveFineTuningJob(ctx, fineTuningJob.ID) if err != nil { fmt.Printf("Getting fine tune model error: %v\n", err) return } - fmt.Println(getTune.FineTunedModel) + fmt.Println(fineTuningJob.FineTunedModel) - // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + // once the status of fineTuningJob is `succeeded`, you can use your fine tune model in Completion Request or Chat Completion Request // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ - // Model: getTune.FineTunedModel, + // Model: fineTuningJob.FineTunedModel, // Prompt: "your prompt", // }) // if err != nil { diff --git a/fine_tunes.go b/fine_tunes.go index 96e731d51..7d3b59dbd 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,9 @@ import ( "net/http" ) +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneRequest struct { TrainingFile string `json:"training_file"` ValidationFile string `json:"validation_file,omitempty"` @@ -21,6 +24,9 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. 
+// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTune struct { ID string `json:"id"` Object string `json:"object"` @@ -37,6 +43,9 @@ type FineTune struct { UpdatedAt int64 `json:"updated_at"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEvent struct { Object string `json:"object"` CreatedAt int64 `json:"created_at"` @@ -44,6 +53,9 @@ type FineTuneEvent struct { Message string `json:"message"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneHyperParams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -51,21 +63,34 @@ type FineTuneHyperParams struct { PromptLossWeight float64 `json:"prompt_loss_weight"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` } + +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) @@ -78,6 +103,9 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r } // CancelFineTune cancel a fine-tune job. +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
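+// Migration is mechanical; for example (a sketch using the new API added earlier in this series):
+//   old: tune, err := client.CancelFineTune(ctx, fineTuneID)
+//   new: job, err := client.CancelFineTuningJob(ctx, fineTuningJobID)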
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { @@ -88,6 +116,9 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { @@ -98,6 +129,9 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) @@ -109,6 +143,9 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { @@ -119,6 +156,9 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { From 3589837b229aeace205f312aa839bf73154e2820 Mon Sep 17 00:00:00 2001 From: NullpointerW <58949721+NullpointerW@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:52:47 +0800 Subject: [PATCH 061/206] Update OpenAPI file return struct (#486) * completionBatchingRequestSupport * lint fix * fix Run test fail * fix TestClientReturnsRequestBuilderErrors fail * fix Codecov check * ignore TestClientReturnsRequestBuilderErrors lint * fix lint again * lint again*2 * replace checkPromptType implementation * remove nil check * update file return struct --------- Co-authored-by: W <825708370@qq.com> --- files.go | 15 ++++++++------- files_api_test.go | 1 - 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/files.go b/files.go index ea1f50a73..8b933c362 100644 --- a/files.go +++ b/files.go @@ -17,13 +17,14 @@ type FileRequest struct { // File struct represents an OpenAPI file. 
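// Status and StatusDetails below are populated by the server; per the OpenAI file
// API docs at the time, Status is one of "uploaded", "processed", or "error"
// (values assumed here, they are not stated in this patch).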
type File struct { - Bytes int `json:"bytes"` - CreatedAt int64 `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object string `json:"object"` + Status string `json:"status"` + Purpose string `json:"purpose"` + StatusDetails string `json:"status_details"` } // FilesList is a list of files that belong to the user or organization. diff --git a/files_api_test.go b/files_api_test.go index f0a08764d..1cbc72894 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -64,7 +64,6 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { Purpose: purpose, CreatedAt: time.Now().Unix(), Object: "test-objecct", - Owner: "test-owner", } resBytes, _ = json.Marshal(fileReq) From 8e4b7963a3f378332bd512a5040d75d8504505c8 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 11 Sep 2023 15:44:46 +0200 Subject: [PATCH 062/206] Chore Support base64 embedding format (#485) * chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix --- embeddings.go | 116 +++++++++++++++++++++++++++++++++++---- embeddings_test.go | 131 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 229 insertions(+), 18 deletions(-) diff --git a/embeddings.go b/embeddings.go index 1d3199597..5ba91f235 100644 --- a/embeddings.go +++ b/embeddings.go @@ -2,6 +2,9 @@ package openai import ( "context" + "encoding/base64" + "encoding/binary" + "math" "net/http" ) @@ -129,15 +132,83 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } +type base64String string + +func (b base64String) Decode() ([]float32, error) { + decodedData, err := base64.StdEncoding.DecodeString(string(b)) + if err != nil { + return nil, err + } + + const sizeOfFloat32 = 4 + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + + return floats, nil +} + +// Base64Embedding is a container for base64 encoded embeddings. +type Base64Embedding struct { + Object string `json:"object"` + Embedding base64String `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format. +type EmbeddingResponseBase64 struct { + Object string `json:"object"` + Data []Base64Embedding `json:"data"` + Model EmbeddingModel `json:"model"` + Usage Usage `json:"usage"` +} + +// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. 
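+// Each base64 payload decodes to IEEE-754 float32 values (4 bytes each,
+// little-endian; see base64String.Decode above), so callers always receive
+// numeric embeddings regardless of the requested encoding_format.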
+func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) { + data := make([]Embedding, len(r.Data)) + + for i, base64Embedding := range r.Data { + embedding, err := base64Embedding.Embedding.Decode() + if err != nil { + return EmbeddingResponse{}, err + } + + data[i] = Embedding{ + Object: base64Embedding.Object, + Embedding: embedding, + Index: base64Embedding.Index, + } + } + + return EmbeddingResponse{ + Object: r.Object, + Model: r.Model, + Data: data, + Usage: r.Usage, + }, nil +} + type EmbeddingRequestConverter interface { // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens Convert() EmbeddingRequest } +// EmbeddingEncodingFormat is the format of the embeddings data. +// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. +// If not specified OpenAI will use "float". +type EmbeddingEncodingFormat string + +const ( + EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float" + EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64" +) + type EmbeddingRequest struct { - Input any `json:"input"` - Model EmbeddingModel `json:"model"` - User string `json:"user"` + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { // // Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens // for embedding groups of text already converted to tokens. 
-func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll +func (c *Client) CreateEmbeddings( + ctx context.Context, + conv EmbeddingRequestConverter, +) (res EmbeddingResponse, err error) { baseReq := conv.Convert() req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &res) + if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 { + err = c.sendRequest(req, &res) + return + } + + base64Response := &EmbeddingResponseBase64{} + err = c.sendRequest(req, base64Response) + if err != nil { + return + } + res, err = base64Response.ToEmbeddingResponse() return } diff --git a/embeddings_test.go b/embeddings_test.go index 47c4f5108..9c48c5b8f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,15 +1,16 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" "fmt" "net/http" + "reflect" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { @@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() + + sampleEmbeddings := []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + sampleBase64Embeddings := []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + } + server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EmbeddingResponse{}) + var req struct { + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + + var resBytes []byte + switch { + case req.User == "invalid": + w.WriteHeader(http.StatusBadRequest) + return + case req.EncodingFormat == EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + default: + resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (simple embedding request) + res, err = client.CreateEmbeddings( + context.Background(), + EmbeddingRequest{ + EncodingFormat: EmbeddingEncodingFormatBase64, + }, + ) checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with strings - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got 
%#v", sampleEmbeddings, res.Data) + } // test create embeddings with tokens - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test failed sendRequest + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + User: "invalid", + EncodingFormat: EmbeddingEncodingFormatBase64, + }) + checks.HasError(t, err, "CreateEmbeddings error") +} + +func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { + type fields struct { + Object string + Data []Base64Embedding + Model EmbeddingModel + Usage Usage + } + tests := []struct { + name string + fields fields + want EmbeddingResponse + wantErr bool + }{ + { + name: "test embedding response base64 to embedding response", + fields: fields{ + Data: []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + }, + }, + want: EmbeddingResponse{ + Data: []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + }, + }, + wantErr: false, + }, + { + name: "Invalid embedding", + fields: fields{ + Data: []Base64Embedding{ + { + Embedding: "----", + }, + }, + }, + want: EmbeddingResponse{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EmbeddingResponseBase64{ + Object: tt.fields.Object, + Data: tt.fields.Data, + Model: tt.fields.Model, + Usage: tt.fields.Usage, + } + got, err := r.ToEmbeddingResponse() + if (err != nil) != tt.wantErr { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want) + } + }) + } } From 0d5256fb820a34a95b8944b9410a1e562087cd8f Mon Sep 17 00:00:00 2001 From: Brendan Martin Date: Mon, 25 Sep 2023 04:08:45 -0400 Subject: [PATCH 063/206] added delete fine tune model endpoint (#497) --- client_test.go | 3 +++ models.go | 20 ++++++++++++++++++++ models_test.go | 15 +++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/client_test.go b/client_test.go index 9b5046899..2c1d749ed 100644 --- a/client_test.go +++ b/client_test.go @@ -271,6 +271,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetModel", func() (any, error) { return client.GetModel(ctx, "text-davinci-003") }}, + {"DeleteFineTuneModel", func() (any, error) { + return client.DeleteFineTuneModel(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/models.go b/models.go index 560402e3f..c207f0a86 100644 --- a/models.go +++ b/models.go @@ -33,6 +33,13 @@ type Permission struct { IsBlocking bool `json:"is_blocking"` } +// FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model. +type FineTuneModelDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` @@ -62,3 +69,16 @@ func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err err = c.sendRequest(req, &model) return } + +// DeleteFineTuneModel Deletes a fine-tune model. 
You must have the Owner +// role in your organization to delete a model. +func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( + response FineTuneModelDeleteResponse, err error) { + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/models_test.go b/models_test.go index 59b4f5ef7..9ff73042a 100644 --- a/models_test.go +++ b/models_test.go @@ -14,6 +14,8 @@ import ( "testing" ) +const testFineTuneModelID = "fine-tune-model-id" + // TestListModels Tests the list models endpoint of the API using the mocked server. func TestListModels(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -78,3 +80,16 @@ func TestGetModelReturnTimeoutError(t *testing.T) { t.Fatal("Did not return timeout error") } } + +func TestDeleteFineTuneModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint) + _, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID) + checks.NoError(t, err, "DeleteFineTuneModel error") +} + +func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + fmt.Fprintln(w, string(resBytes)) +} From 84f77a0acda6eb541f3312ed8f7711c89e661443 Mon Sep 17 00:00:00 2001 From: "e. alvarez" <55966724+ealvar3z@users.noreply.github.com> Date: Mon, 2 Oct 2023 07:39:10 -0700 Subject: [PATCH 064/206] Add DotProduct Method and README Example for Embedding Similarity Search (#492) * Add DotProduct Method and README Example for Embedding Similarity Search - Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings. - Add a custom error type for vector length mismatch. - Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries. - Add unit tests to validate the new DotProduct() method and error handling. 
* Update README to focus on Embedding Semantic Similarity

---
 README.md          | 56 ++++++++++++++++++++++++++++++++++++++++++++++
 embeddings.go      | 20 +++++++++++++++++
 embeddings_test.go | 38 +++++++++++++++++++++++++++++++
 3 files changed, 114 insertions(+)

diff --git a/README.md b/README.md
index 440c40968..c618cd7fa 100644
--- a/README.md
+++ b/README.md
@@ -483,6 +483,62 @@ func main() {
 ```
 
+
+Embedding Semantic Similarity
+
+```go
+package main
+
+import (
+	"context"
+	"log"
+	openai "github.com/sashabaranov/go-openai"
+
+)
+
+func main() {
+	client := openai.NewClient("your-token")
+
+	// Create an EmbeddingRequest for the user query
+	queryReq := openai.EmbeddingRequest{
+		Input: []string{"How many chucks would a woodchuck chuck"},
+		Model: openai.AdaEmbeddingV2,
+	}
+
+	// Create an embedding for the user query
+	queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
+	if err != nil {
+		log.Fatal("Error creating query embedding:", err)
+	}
+
+	// Create an EmbeddingRequest for the target text
+	targetReq := openai.EmbeddingRequest{
+		Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
+		Model: openai.AdaEmbeddingV2,
+	}
+
+	// Create an embedding for the target text
+	targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
+	if err != nil {
+		log.Fatal("Error creating target embedding:", err)
+	}
+
+	// Now that we have the embeddings for the user query and the target text, we
+	// can calculate their similarity.
+	queryEmbedding := queryResponse.Data[0]
+	targetEmbedding := targetResponse.Data[0]
+
+	similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
+	if err != nil {
+		log.Fatal("Error calculating dot product:", err)
+	}
+
+	log.Printf("The similarity score between the query and the target is %f", similarity)
+}
+
+```
+
+
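+Note: OpenAI embedding vectors are normalized to length 1, so the dot product above is equivalent to their cosine similarity.
+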
Azure OpenAI Embeddings diff --git a/embeddings.go b/embeddings.go index 5ba91f235..660bc24c3 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,10 +4,13 @@ import ( "context" "encoding/base64" "encoding/binary" + "errors" "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. type EmbeddingModel int @@ -124,6 +127,23 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { Object string `json:"object"` diff --git a/embeddings_test.go b/embeddings_test.go index 9c48c5b8f..72e8c245f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" "reflect" "testing" @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { }) } } + +func TestDotProduct(t *testing.T) { + v1 := &Embedding{Embedding: []float32{1, 2, 3}} + v2 := &Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. 
Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } +} From 533935e4fc31f2542ef77d3e545a527c756b641c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 6 Oct 2023 11:32:21 +0200 Subject: [PATCH 065/206] fix: use any for n_epochs (#499) * fix: use custom marshaler for n_epochs * chore: use any for n_epochs --- fine_tuning_job.go | 2 +- fine_tuning_job_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index a840b7ec3..07b0c337c 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -24,7 +24,7 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs int `json:"n_epochs"` + Epochs any `json:"n_epochs,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 519c6cd2d..f6d41c33d 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -21,8 +21,23 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/v1/fine_tuning/jobs", func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ := json.Marshal(FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: Hyperparameters{ + Epochs: "auto", + }, + TrainedTokens: 5768, + }) fmt.Fprintln(w, string(resBytes)) }, ) From 8e165dc9aadc9f7045b91dd1b02d6404940dc023 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 9 Oct 2023 17:41:54 +0200 Subject: [PATCH 066/206] Feat Add headers to openai responses (#506) * feat: add headers to http response * chore: add test * fix: rename to httpHeader --- audio.go | 19 ++++++++++++++++++- chat.go | 2 ++ chat_test.go | 30 ++++++++++++++++++++++++++++++ client.go | 20 +++++++++++++++++++- completion.go | 2 ++ edits.go | 2 ++ embeddings.go | 4 ++++ engines.go | 4 ++++ files.go | 4 ++++ fine_tunes.go | 8 ++++++++ fine_tuning_job.go | 4 ++++ image.go | 2 ++ models.go | 6 ++++++ moderation.go | 2 ++ 14 files changed, 107 insertions(+), 2 deletions(-) diff --git a/audio.go b/audio.go index 9f469159d..4cbe4fe64 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. 
@@ -104,7 +119,9 @@ func (c *Client) callAudioAPI( if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8d29b3237..df0e5f970 100644 --- a/chat.go +++ b/chat.go @@ -142,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. diff --git a/chat_test.go b/chat_test.go index 38d66fa64..52cd0bdef 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,11 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 5779a8e1c..19902285b 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,20 @@ type Client struct { createFormBuilder func(io.Writer) utils.FormBuilder } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h httpHeader) Header() http.Header { + return http.Header(h) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... 
return req, nil } -func (c *Client) sendRequest(req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return c.handleErrorResp(res) } + if v != nil { + v.SetHeader(res.Header) + } + return decodeResponse(res.Body, v) } diff --git a/completion.go b/completion.go index 7b9ae89e7..c7ff94afc 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well diff --git a/edits.go b/edits.go index 831aade2f..97d026029 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. diff --git a/embeddings.go b/embeddings.go index 660bc24c3..7e2aa7eb0 100644 --- a/embeddings.go +++ b/embeddings.go @@ -150,6 +150,8 @@ type EmbeddingResponse struct { Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } type base64String string @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct { Data []Base64Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } // ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. diff --git a/engines.go b/engines.go index adf6025c2..5a0dba858 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic diff --git a/files.go b/files.go index 8b933c362..9e521fbbe 100644 --- a/files.go +++ b/files.go @@ -25,11 +25,15 @@ type File struct { Status string `json:"status"` Purpose string `json:"purpose"` StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 diff --git a/fine_tunes.go b/fine_tunes.go index 7d3b59dbd..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -41,6 +41,8 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -69,6 +71,8 @@ type FineTuneHyperParams struct { type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -77,6 +81,8 @@ type FineTuneList struct { type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 
@@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 07b0c337c..9dcb49de1 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -21,6 +21,8 @@ type FineTuningJob struct { ValidationFile string `json:"validation_file,omitempty"` ResultFiles []string `json:"result_files"` TrainedTokens int `json:"trained_tokens"` + + httpHeader } type Hyperparameters struct { @@ -39,6 +41,8 @@ type FineTuningJobEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` HasMore bool `json:"has_more"` + + httpHeader } type FineTuningJobEvent struct { diff --git a/image.go b/image.go index cb96f4f5e..4addcdb1e 100644 --- a/image.go +++ b/image.go @@ -33,6 +33,8 @@ type ImageRequest struct { type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + + httpHeader } // ImageResponseDataInner represents a response data structure for image API. diff --git a/models.go b/models.go index c207f0a86..d94f98836 100644 --- a/models.go +++ b/models.go @@ -15,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, diff --git a/moderation.go b/moderation.go index a32f123f3..f8d20ee51 100644 --- a/moderation.go +++ b/moderation.go @@ -69,6 +69,8 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string. From b77d01edca43500f267c4b43333f645b84a4fcf0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 10 Oct 2023 10:29:41 -0500 Subject: [PATCH 067/206] Support get http header and x-ratelimit-* headers (#507) * feat: add headers to http response * feat: support rate limit headers * fix: go lint * fix: test coverage * refactor streamReader * refactor streamReader * refactor: NewRateLimitHeaders to newRateLimitHeaders * refactor: RateLimitHeaders Resets filed * refactor: move RateLimitHeaders struct --- chat_stream_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++-- chat_test.go | 53 +++++++++++++++++++++++++++ client.go | 9 ++++- ratelimit.go | 43 ++++++++++++++++++++++ stream_reader.go | 2 + 5 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 ratelimit.go diff --git a/chat_stream_test.go b/chat_stream_test.go index 5fc70b032..2c109d454 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) 
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	headers := stream.GetRateLimitHeaders()
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit headers %s, got %s", bs2, bs1)
+	}
+}
+
 func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
diff --git a/chat_test.go b/chat_test.go
index 52cd0bdef..329b2b9cb 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -21,6 +21,17 @@ const (
 	xCustomHeaderValue = "test"
 )
 
+var (
+	rateLimitHeaders = map[string]any{
+		"x-ratelimit-limit-requests":     60,
+		"x-ratelimit-limit-tokens":       150000,
+		"x-ratelimit-remaining-requests": 59,
+		"x-ratelimit-remaining-tokens":   149984,
+		"x-ratelimit-reset-requests":     "1s",
+		"x-ratelimit-reset-tokens":       "6m0s",
+	}
+)
+
 func TestChatCompletionsWrongModel(t *testing.T) {
 	config := DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
 	}
 }
 
+// TestChatCompletionsWithRateLimitHeaders Tests that the chat completions endpoint surfaces x-ratelimit-* headers, using the mocked server.
+func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+
+	headers := resp.GetRateLimitHeaders()
+	resetRequests := headers.ResetRequests.String()
+	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
+		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
+	}
+	resetRequestsTime := headers.ResetRequests.Time()
+	if resetRequestsTime.Before(time.Now()) {
+		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
+	}
+
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit headers %s, got %s", bs2, bs1)
+	}
+}
+
 // TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } resBytes, _ = json.Marshal(res) w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 19902285b..65ece812f 100644 --- a/client.go +++ b/client.go @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) { *h = httpHeader(header) } -func (h httpHeader) Header() http.Header { - return http.Header(h) +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) } // NewClient creates new OpenAI API client. @@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), }, nil } diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 000000000..e8953f716 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} diff --git a/stream_reader.go b/stream_reader.go index 87e59e0ca..d17412591 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -27,6 +27,8 @@ type streamReader[T streamable] struct { response *http.Response errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) { From c47ddfc1a13b850115a80b03f3f9dd1822733bf7 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 10 Oct 2023 21:22:45 +0400 Subject: [PATCH 068/206] Update README.md (#511) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c618cd7fa..b41947be5 100644 --- a/README.md +++ b/README.md @@ -483,7 +483,7 @@ func main() { ```
- +
Embedding Semantic Similarity ```go @@ -537,7 +537,7 @@ func main() { } ``` - +
Azure OpenAI Embeddings From 6c52952b691ec294b7987689a5292a87a9acdbcb Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Mon, 6 Nov 2023 21:22:48 +0100 Subject: [PATCH 069/206] feat(completion): add constants for new GPT models (#520) Added constants for new GPT models including `gpt-4-1106-preview`, `gpt-4-vision-preview` and `gpt-3.5-turbo-1106`. The models were announced in the following blog post: https://openai.com/blog/new-models-and-developer-products-announced-at-devday --- completion.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/completion.go b/completion.go index c7ff94afc..2709c8b03 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,10 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4TurboPreview = "gpt-4-1106-preview" + GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" @@ -69,9 +72,12 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, GPT40314: true, GPT40613: true, GPT432K: true, From 9e0232f941a0f2c1780bf20743effd051a39e4d3 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Mon, 6 Nov 2023 12:27:08 -0800 Subject: [PATCH 070/206] Fix typo in README: AdaEmbeddingV2 (#516) Copy-pasting the old sample caused compilation errors --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b41947be5..f0b609088 100644 --- a/README.md +++ b/README.md @@ -502,7 +502,7 @@ func main() { // Create an EmbeddingRequest for the user query queryReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the user query @@ -514,7 +514,7 @@ func main() { // Create an EmbeddingRequest for the target text targetReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the target text From 0664105387f52c99b13bb40fcbf966a8b8c8d838 Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Tue, 7 Nov 2023 10:23:06 +0100 Subject: [PATCH 071/206] lint: fix linter warnings reported by golangci-lint (#522) - Fix #519 --- api_integration_test.go | 1 - audio_api_test.go | 14 ++-- audio_test.go | 2 +- chat_stream_test.go | 110 ++++++++++++++-------------- chat_test.go | 154 ++++++++++++++++++++-------------------- completion_test.go | 42 +++++------ config_test.go | 4 +- edits_test.go | 24 +++---- embeddings_test.go | 110 ++++++++++++++-------------- engines_test.go | 12 ++-- error_test.go | 60 ++++++++-------- example_test.go | 2 - files_api_test.go | 12 ++-- files_test.go | 6 +- fine_tunes.go | 1 + fine_tunes_test.go | 24 +++---- fine_tuning_job_test.go | 35 +++++---- image_api_test.go | 52 +++++++------- jsonschema/json_test.go | 62 ++++++++-------- models_test.go | 17 +++-- moderation_test.go | 52 +++++++------- openai_test.go | 14 ++-- stream_test.go | 46 ++++++------ 23 files changed, 425 insertions(+), 431 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 254fbeb03..6be188bc6 100644 --- 
a/api_integration_test.go +++ b/api_integration_test.go @@ -9,7 +9,6 @@ import ( "os" "testing" - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) diff --git a/audio_api_test.go b/audio_api_test.go index aad7a225a..a0efc7921 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -12,7 +12,7 @@ import ( "strings" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -26,7 +26,7 @@ func TestAudio(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -48,7 +48,7 @@ func TestAudio(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", } @@ -57,7 +57,7 @@ func TestAudio(t *testing.T) { }) t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: "fake.webm", Reader: bytes.NewBuffer([]byte(`some webm binary data`)), Model: "whisper-3", @@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", Prompt: "用简体中文", Temperature: 0.5, Language: "zh", - Format: AudioResponseFormatSRT, + Format: openai.AudioResponseFormatSRT, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index e19a873f3..5346244c8 100644 --- a/audio_test.go +++ b/audio_test.go @@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/chat_stream_test.go b/chat_stream_test.go index 2c109d454..bd571cb48 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -10,28 +10,28 @@ import ( "strconv" "testing" - . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { + if !errors.Is(err, openai.ErrChatCompletionInvalidModel) { t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } @@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []ChatCompletionStreamResponse{ + expectedResponses := []openai.ChatCompletionStreamResponse{ { ID: "1", Object: "completion", Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response1", }, FinishReason: "max_tokens", @@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response2", }, FinishReason: "max_tokens", @@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) { func TestCreateChatCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t 
*testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set(xCustomHeader, xCustomHeaderValue) @@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") for k, v := range rateLimitHeaders { switch val := v.(type) { @@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -276,12 +276,12 @@ func 
TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, }) - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) // Send test responses @@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - apiErr := &APIError{} - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + apiErr := &openai.APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } // Helper funcs. 
-func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { +func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { return true } -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { +func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false } diff --git a/chat_test.go b/chat_test.go index 329b2b9cb..5bf1eaf6c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -21,49 +21,47 @@ const ( xCustomHeaderValue = "test" ) -var ( - rateLimitHeaders = map[string]any{ - "x-ratelimit-limit-requests": 60, - "x-ratelimit-limit-tokens": 150000, - "x-ratelimit-remaining-requests": 59, - "x-ratelimit-remaining-tokens": 149984, - "x-ratelimit-reset-requests": "1s", - "x-ratelimit-reset-tokens": "6m0s", - } -) +var rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", +} func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletion(ctx, req) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) - checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) + checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") + checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. 
@@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) { t.Run("bytes", func(t *testing.T) { //nolint:lll msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) { Count: 2, Words: []string{"hello", "world"}, } - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -192,16 +190,16 @@ func 
TestChatCompletionsFunctions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion with functions error") }) t.Run("JSONSchemaDefinition", func(t *testing.T) { - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) { }) t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []openai.FunctionDefine{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq ChatCompletionRequest + var completionReq openai.ChatCompletionRequest if completionReq, err = getChatCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ChatCompletionResponse{ + res := openai.ChatCompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { return } - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleFunction, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleFunction, // this is valid json so it should be fine - FunctionCall: &FunctionCall{ + FunctionCall: &openai.FunctionCall{ Name: completionReq.Functions[0].Name, Arguments: string(fcb), }, @@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: 
ChatMessageRoleAssistant, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, Content: completionStr, }, Index: i, @@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens := numTokens(completionReq.Messages[0].Content) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} +func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { + completion := openai.ChatCompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } return completion, nil } func TestFinishReason(t *testing.T) { - c := &ChatCompletionChoice{ - FinishReason: FinishReasonNull, + c := &openai.ChatCompletionChoice{ + FinishReason: openai.FinishReasonNull, } resBytes, _ := json.Marshal(c) if !strings.Contains(string(resBytes), `"finish_reason":null`) { @@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) { t.Error("null should not be quoted") } - otherReasons := []FinishReason{ - FinishReasonStop, - FinishReasonLength, - FinishReasonFunctionCall, - FinishReasonContentFilter, + otherReasons := []openai.FinishReason{ + openai.FinishReasonStop, + openai.FinishReasonLength, + openai.FinishReasonFunctionCall, + openai.FinishReasonContentFilter, } for _, r := range otherReasons { c.FinishReason = r diff --git a/completion_test.go b/completion_test.go index 844ef484f..89950bf94 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" @@ -14,33 +11,36 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletion( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + config := openai.DefaultConfig("whatever") + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := CompletionRequest{Stream: true} + req := openai.CompletionRequest{Stream: true} _, err := client.CreateCompletion(ctx, req) - if !errors.Is(err, ErrCompletionStreamNotSupported) { + if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") } } @@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - req := CompletionRequest{ + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", Prompt: "Lorem ipsum", @@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt.(string) + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, openai.CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt.(string)) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. 
-func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } diff --git a/config_test.go b/config_test.go index 488511b11..3e528c3e9 100644 --- a/config_test.go +++ b/config_test.go @@ -3,7 +3,7 @@ package openai_test import ( "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestGetAzureDeploymentByModel(t *testing.T) { @@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) { for _, c := range cases { t.Run(c.Model, func(t *testing.T) { - conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") if c.AzureModelMapperFunc != nil { conf.AzureModelMapperFunc = c.AzureModelMapperFunc } diff --git a/edits_test.go b/edits_test.go index c0bb84392..d2a6db40d 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -11,6 +8,9 @@ import ( "net/http" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. @@ -20,7 +20,7 @@ func TestEdits(t *testing.T) { server.RegisterHandler("/v1/edits", handleEditEndpoint) // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: time.Now().Unix(), } @@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { } // getEditBody Returns the body of the request to create an edit. 
-func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } diff --git a/embeddings_test.go b/embeddings_test.go index 72e8c245f..af04d96bf 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -11,32 +11,32 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { // test embedding request with strings (simple embedding request) - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with strings - embeddingReqStrings := EmbeddingRequestStrings{ + embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with tokens - embeddingReqTokens := EmbeddingRequestTokens{ + embeddingReqTokens := openai.EmbeddingRequestTokens{ Input: [][]int{ {464, 2057, 373, 12625, 290, 262, 46612}, {6395, 6096, 286, 11525, 12083, 2581}, @@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) { } func TestEmbeddingModel(t *testing.T) { - var em EmbeddingModel + var em openai.EmbeddingModel err := em.UnmarshalText([]byte("text-similarity-ada-001")) checks.NoError(t, err, "Could not marshal embedding model") - if em != AdaSimilarity { + if em != openai.AdaSimilarity { t.Errorf("Model is not equal to AdaSimilarity") } err = em.UnmarshalText([]byte("some-non-existent-model")) checks.NoError(t, err, "Could not marshal embedding model") - if em != Unknown { + if em != openai.Unknown { t.Errorf("Model is not equal to Unknown") } } @@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - sampleEmbeddings := []Embedding{ + sampleEmbeddings := []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, } - sampleBase64Embeddings := []Base64Embedding{ + 
sampleBase64Embeddings := []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, } @@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) { "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { var req struct { - EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` - User string `json:"user"` + EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` } _ = json.NewDecoder(r.Body).Decode(&req) @@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) { case req.User == "invalid": w.WriteHeader(http.StatusBadRequest) return - case req.EncodingFormat == EmbeddingEncodingFormatBase64: - resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings}) default: - resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) @@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) { // test create embeddings with strings (simple embedding request) res, err = client.CreateEmbeddings( context.Background(), - EmbeddingRequest{ - EncodingFormat: EmbeddingEncodingFormatBase64, + openai.EmbeddingRequest{ + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }, ) checks.NoError(t, err, "CreateEmbeddings error") @@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) { } // test create embeddings with strings - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test create embeddings with tokens - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test failed sendRequest - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ User: "invalid", - EncodingFormat: EmbeddingEncodingFormatBase64, + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }) checks.HasError(t, err, "CreateEmbeddings error") } @@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string - Data []Base64Embedding - Model EmbeddingModel - Usage Usage + Data []openai.Base64Embedding + Model openai.EmbeddingModel + Usage openai.Usage } tests := []struct { name string fields fields - want 
EmbeddingResponse + want openai.EmbeddingResponse wantErr bool }{ { name: "test embedding response base64 to embedding response", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, }, }, - want: EmbeddingResponse{ - Data: []Embedding{ + want: openai.EmbeddingResponse{ + Data: []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, }, @@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { { name: "Invalid embedding", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ { Embedding: "----", }, }, }, - want: EmbeddingResponse{}, + want: openai.EmbeddingResponse{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &EmbeddingResponseBase64{ + r := &openai.EmbeddingResponseBase64{ Object: tt.fields.Object, Data: tt.fields.Data, Model: tt.fields.Model, @@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { } func TestDotProduct(t *testing.T) { - v1 := &Embedding{Embedding: []float32{1, 2, 3}} - v2 := &Embedding{Embedding: []float32{2, 4, 6}} + v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}} + v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}} expected := float32(28.0) result, err := v1.DotProduct(v2) @@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) { t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) } - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1, 0}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}} expected = float32(0.0) result, err = v1.DotProduct(v2) @@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) { } // Test for VectorLengthMismatchError - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1}} _, err = v1.DotProduct(v2) - if !errors.Is(err, ErrVectorLengthMismatch) { + if !errors.Is(err, openai.ErrVectorLengthMismatch) { t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) } } diff --git a/engines_test.go b/engines_test.go index 31e7ec8be..d26aa5541 100644 --- a/engines_test.go +++ b/engines_test.go @@ -7,7 +7,7 @@ import ( "net/http" "testing" - . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -15,8 +15,8 @@ import ( func TestGetEngine(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(Engine{}) + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Engine{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetEngine(context.Background(), "text-davinci-003") @@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) { func TestListEngines(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EnginesList{}) + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EnginesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListEngines(context.Background()) @@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) { func TestListEnginesReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }) diff --git a/error_test.go b/error_test.go index a0806b7ed..48cbe4f29 100644 --- a/error_test.go +++ b/error_test.go @@ -6,7 +6,7 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestAPIErrorUnmarshalJSON(t *testing.T) { @@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name string response string hasError bool - checkFunc func(t *testing.T, apiErr APIError) + checkFunc func(t *testing.T, apiErr openai.APIError) } testCases := []testCase{ // testcase for message field @@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is string", response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with single item", response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with multiple items", response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo, bar, baz") }, }, @@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is empty array", response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t 
*testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is null", response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } }`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{ + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{ Code: "ResponsibleAIPolicyViolation", - ContentFilterResults: ContentFilterResults{ - Hate: Hate{ + ContentFilterResults: openai.ContentFilterResults{ + Hate: openai.Hate{ Filtered: false, Severity: "safe", }, - SelfHarm: SelfHarm{ + SelfHarm: openai.SelfHarm{ Filtered: false, Severity: "safe", }, - Sexual: Sexual{ + Sexual: openai.Sexual{ Filtered: true, Severity: "medium", }, - Violence: Violence{ + Violence: openai.Violence{ Filtered: false, Severity: "safe", }, @@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the innerError is empty (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { @@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is int", response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, 418) }, }, @@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is string", response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, "teapot") }, }, @@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is not exists", response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) }, }, @@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse failed when the response is invalid json", response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: true, - checkFunc: func(t 
*testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) assertAPIErrorMessage(t, apiErr, "") assertAPIErrorParam(t, apiErr, nil) @@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var apiErr APIError + var apiErr openai.APIError err := apiErr.UnmarshalJSON([]byte(tc.response)) if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) @@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } } -func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { +func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) { if apiErr.Message != expected { t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) } } -func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) { if !reflect.DeepEqual(apiErr.InnerError, expected) { t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) } } -func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: if v != expected { @@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { } } -func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { +func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) { if apiErr.Param != expected { t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) } } -func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { +func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) { if apiErr.Type != typ { t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) } } func TestRequestError(t *testing.T) { - var err error = &RequestError{ + var err error = &openai.RequestError{ HTTPStatusCode: http.StatusTeapot, Err: errors.New("i am a teapot"), } - var reqErr *RequestError + var reqErr *openai.RequestError if !errors.As(err, &reqErr) { t.Fatalf("Error is not a RequestError: %+v", err) } diff --git a/example_test.go b/example_test.go index b5dfafea9..de67c57cd 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,6 @@ func Example() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return diff --git a/files_api_test.go b/files_api_test.go index 1cbc72894..330b88159 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", handleCreateFile) - req := FileRequest{ + req := openai.FileRequest{ FileName: "test.go", FilePath: "client.go", Purpose: "fine-tune", @@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { } defer file.Close() - var fileReq = File{ + fileReq := openai.File{ Bytes: int(header.Size), ID: strconv.Itoa(int(time.Now().Unix())), FileName: header.Filename, @@ -82,7 +82,7 @@ func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FilesList{}) + resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListFiles(context.Background()) @@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(File{}) + resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetFile(context.Background(), "deadbeef") @@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) { t.Fatal("Did not return error") } - apiErr := &APIError{} + apiErr := &openai.APIError{} if !errors.As(err, &apiErr) { t.Fatalf("Did not return APIError: %+v\n", apiErr) } diff --git a/files_test.go b/files_test.go index df6eaef7b..f588b30dc 100644 --- a/files_test.go +++ b/files_test.go @@ -1,14 +1,14 @@ package openai //nolint:testpackage // testing private field import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "fmt" "io" "os" "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestFileUploadWithFailingFormBuilder(t *testing.T) { diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..46f89f165 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 67f681d97..2ab6817f7 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,14 +1,14 @@ package openai_test import ( - . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" @@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) + resBytes, _ = json.Marshal(openai.FineTuneList{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTune{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) { _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index f6d41c33d..c892ef775 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -2,14 +2,13 @@ package openai_test import ( "context" - - . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuninigJobID = "fine-tuning-job-id" @@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) { defer teardown() server.RegisterHandler( "/v1/fine_tuning/jobs", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{ + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{ Object: "fine_tuning.job", ID: testFineTuninigJobID, Model: "davinci-002", @@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) { Status: "succeeded", ValidationFile: "", TrainingFile: "file-abc123", - Hyperparameters: Hyperparameters{ + Hyperparameters: openai.Hyperparameters{ Epochs: "auto", }, TrainedTokens: 5768, @@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID, - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ = json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJobEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) ctx := context.Background() - _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) checks.NoError(t, err, "CreateFineTuningJob error") _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) @@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) { _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") } diff --git a/image_api_test.go b/image_api_test.go index b472eb04a..422f831fe 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . 
"github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -12,13 +9,16 @@ import ( "os" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestImages(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - _, err := client.CreateImage(context.Background(), ImageRequest{ + _, err := client.CreateImage(context.Background(), openai.ImageRequest{ Prompt: "Lorem ipsum", }) checks.NoError(t, err, "CreateImage error") @@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest + var imageReq openai.ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + res := openai.ImageResponse{ Created: time.Now().Unix(), } for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} + imageData := openai.ImageResponseDataInner{} switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": + case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + case openai.CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: @@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } return image, nil } @@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: 
[]ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", @@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index c8d0c1d9e..744706082 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -5,28 +5,28 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai/jsonschema" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestDefinition_MarshalJSON(t *testing.T) { tests := []struct { name string - def Definition + def jsonschema.Definition want string }{ { name: "Test with empty Definition", - def: Definition{}, + def: jsonschema.Definition{}, want: `{"properties":{}}`, }, { name: "Test with Definition properties set", - def: Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.String, Description: "A string type", - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with nested Definition properties", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, }, }, @@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with complex nested Definition", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, "address": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "city": { - Type: String, + Type: jsonschema.String, }, "country": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with Array type Definition", - def: Definition{ - Type: Array, - Items: &Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.Array, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, - Properties: map[string]Definition{ + 
Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, diff --git a/models_test.go b/models_test.go index 9ff73042a..4a4c759dc 100644 --- a/models_test.go +++ b/models_test.go @@ -1,17 +1,16 @@ package openai_test import ( - "os" - "time" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" + "os" "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneModelID = "fine-tune-model-id" @@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) { // handleListModelsEndpoint Handles the list models endpoint by the test server. func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) + resBytes, _ := json.Marshal(openai.ModelsList{}) fmt.Fprintln(w, string(resBytes)) } @@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) { // handleGetModelsEndpoint Handles the get model endpoint by the test server. func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(Model{}) + resBytes, _ := json.Marshal(openai.Model{}) fmt.Fprintln(w, string(resBytes)) } @@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) { } func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/moderation_test.go b/moderation_test.go index 68f9565e1..059f0d1c7 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,9 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. 
@@ -20,8 +20,8 @@ func TestModerations(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - _, err := client.Moderations(context.Background(), ModerationRequest{ - Model: ModerationTextStable, + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: openai.ModerationTextStable, Input: "I want to kill them.", }) checks.NoError(t, err, "Moderation error") @@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { expect error } modelOptions = append(modelOptions, - getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), - getModerationModelTestOption(ModerationTextStable, nil), - getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), + getModerationModelTestOption(openai.ModerationTextStable, nil), + getModerationModelTestOption(openai.ModerationTextLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) for _, modelTest := range modelOptions { - _, err := client.Moderations(context.Background(), ModerationRequest{ + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ Model: modelTest.model, Input: "I want to kill them.", }) @@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var moderationReq ModerationRequest + var moderationReq openai.ModerationRequest if moderationReq, err = getModerationBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} + resCat := openai.ResultCategories{} + resCatScore := openai.ResultCategoryScores{} switch { case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} + resCat = openai.ResultCategories{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} + resCat = openai.ResultCategories{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} + resCat = openai.ResultCategories{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} } - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - res := ModerationResponse{ + res := openai.ModerationResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Model: moderationReq.Model, } @@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { } // getModerationBody Returns the body of the request to do a moderation. 
-func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} +func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { + moderation := openai.ModerationRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } return moderation, nil } diff --git a/openai_test.go b/openai_test.go index 4fc41ecc0..729d8880c 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,29 +1,29 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" ) -func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultConfig(test.GetTestToken()) + config := openai.DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } -func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } diff --git a/stream_test.go b/stream_test.go index f3f8f85cd..35c52ae3b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -10,23 +10,23 @@ import ( "testing" "time" - . 
"github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletionStream( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []CompletionResponse{ + expectedResponses := []openai.CompletionResponse{ { ID: "1", Object: "completion", Created: 1598069254, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, }, { ID: "2", Object: "completion", Created: 1598069255, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, } @@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3TextDavinci003, + Model: openai.GPT3TextDavinci003, Prompt: "Hello!", Stream: true, }) @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - var apiErr *APIError - _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + var apiErr *openai.APIError + _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Ada, + Model: openai.GPT3Ada, Prompt: "Hello!", Stream: true, }) @@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -220,7 +220,7 @@ func 
TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, _ = stream.Recv() _, streamErr := stream.Recv() - if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) { t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") } } @@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() - _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { } // Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { +func compareResponses(r1, r2 openai.CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool { return true } -func compareResponseChoices(c1, c2 CompletionChoice) bool { +func compareResponseChoices(c1, c2 openai.CompletionChoice) bool { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { return false } From d07833e19bfbb2f26011c8881f7fb61366c07e75 Mon Sep 17 00:00:00 2001 From: Carson Kahn Date: Tue, 7 Nov 2023 04:27:29 -0500 Subject: [PATCH 072/206] Doc ways to improve reproducability besides Temp (#532) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0b609088..4cb77db6b 100644 --- a/README.md +++ b/README.md @@ -757,8 +757,9 @@ Even when specifying a temperature field of 0, it doesn't guarantee that you'll Due to the factors mentioned above, different answers may be returned even for the same question. **Workarounds:** -1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. -2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. +1. As of November 2023, use [the new `seed` parameter](https://platform.openai.com/docs/guides/text-generation/reproducible-outputs) in conjunction with the `system_fingerprint` response field, alongside Temperature management. +2. Try using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +3. 
Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. By adopting these strategies, you can expect more consistent results. From 6d9c3a6365643d02692ecc6f0b34a5fa3e7fea45 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 15:25:21 +0100 Subject: [PATCH 073/206] Feat Support chat completion response format and seed new fields (#525) * feat: support chat completion response format * fix linting error * fix * fix linting * Revert "fix linting" This reverts commit 015c6ad62aad561218b693225f58670b5619dba8. * Revert "fix" This reverts commit 7b2ffe28c3e586b629d23479ec1728bf52f0c66f. * Revert "fix linting error" This reverts commit 29960423784e296cb6d22c5db8f8ccf00cac59fd. * chore: add seed new parameter * fix --- chat.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index df0e5f970..88db8cf1d 100644 --- a/chat.go +++ b/chat.go @@ -69,18 +69,31 @@ type FunctionCall struct { Arguments string `json:"arguments,omitempty"` } +type ChatCompletionResponseFormatType string + +const ( + ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" +) + +type ChatCompletionResponseFormat struct { + Type ChatCompletionResponseFormatType `json:"type"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From 3063e676bf5932024d76be8e8d9e41df06d4e8cc Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:20:59 +0100 Subject: [PATCH 074/206] Feat Implement assistants API (#535) * chore: implement assistants API * fix * fix * chore: add tests * fix tests * fix linting --- assistant.go | 260 ++++++++++++++++++++++++++++++++++++++++++++++ assistant_test.go | 202 +++++++++++++++++++++++++++++++++++ client_test.go | 27 +++++ 3 files changed, 489 insertions(+) create mode 100644 assistant.go create mode 100644 assistant_test.go diff --git a/assistant.go b/assistant.go new file mode 100644 index 000000000..d75eebef3 --- /dev/null +++ b/assistant.go @@ -0,0 +1,260 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + assistantsSuffix = "/assistants" + assistantsFilesSuffix = "/files" +) + +type Assistant struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + + httpHeader +} + +type AssistantTool struct { + Type string `json:"type"` +} + +type AssistantToolCodeInterpreter struct { + AssistantTool +} + +type AssistantToolRetrieval struct { + AssistantTool +} + +type AssistantToolFunction struct { + AssistantTool + Function FunctionDefinition `json:"function"` +} + +type AssistantRequest struct { + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// AssistantsList is a list of assistants. +type AssistantsList struct { + Assistants []Assistant `json:"data"` + + httpHeader +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + + httpHeader +} + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type AssistantFilesList struct { + AssistantFiles []AssistantFile `json:"data"` + + httpHeader +} + +// CreateAssistant creates a new assistant. +func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistant retrieves an assistant. +func (c *Client) RetrieveAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyAssistant modifies an assistant. 
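+//
+// An illustrative call might look like this (example added for clarity, not
+// part of the original patch; the assistant ID and instructions are
+// placeholders):
+//
+//	instructions := "You are a personal math tutor."
+//	updated, err := client.ModifyAssistant(ctx, "asst_abc123", AssistantRequest{
+//		Model:        GPT4TurboPreview,
+//		Instructions: &instructions,
+//	})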
+func (c *Client) ModifyAssistant(
+	ctx context.Context,
+	assistantID string,
+	request AssistantRequest,
+) (response Assistant, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID)
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// DeleteAssistant deletes an assistant.
+func (c *Client) DeleteAssistant(
+	ctx context.Context,
+	assistantID string,
+) (response Assistant, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID)
+	req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// ListAssistants Lists the currently available assistants.
+func (c *Client) ListAssistants(
+	ctx context.Context,
+	limit *int,
+	order *string,
+	after *string,
+	before *string,
+) (response AssistantsList, err error) {
+	urlValues := url.Values{}
+	if limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *limit))
+	}
+	if order != nil {
+		urlValues.Add("order", *order)
+	}
+	if after != nil {
+		urlValues.Add("after", *after)
+	}
+	if before != nil {
+		urlValues.Add("before", *before)
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues)
+	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// CreateAssistantFile creates a new assistant file.
+func (c *Client) CreateAssistantFile(
+	ctx context.Context,
+	assistantID string,
+	request AssistantFileRequest,
+) (response AssistantFile, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix)
+	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
+		withBody(request))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// RetrieveAssistantFile retrieves an assistant file.
+func (c *Client) RetrieveAssistantFile(
+	ctx context.Context,
+	assistantID string,
+	fileID string,
+) (response AssistantFile, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID)
+	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// DeleteAssistantFile deletes an existing file.
+func (c *Client) DeleteAssistantFile(
+	ctx context.Context,
+	assistantID string,
+	fileID string,
+) (err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID)
+	req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, nil)
+	return
+}
+
+// ListAssistantFiles Lists the currently available files for an assistant.
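+//
+// The pagination arguments are optional; passing nil leaves them unset. An
+// illustrative call (example added for clarity, not part of the original
+// patch; the assistant ID is a placeholder):
+//
+//	limit := 20
+//	files, err := client.ListAssistantFiles(ctx, "asst_abc123", &limit, nil, nil, nil)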
+func (c *Client) ListAssistantFiles( + ctx context.Context, + assistantID string, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantFilesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/assistant_test.go b/assistant_test.go new file mode 100644 index 000000000..eb6f42458 --- /dev/null +++ b/assistant_test.go @@ -0,0 +1,202 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assitantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }) + fmt.Fprintln(w, 
string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client_test.go b/client_test.go index 2c1d749ed..bff2597c5 100644 --- a/client_test.go +++ b/client_test.go @@ -274,6 +274,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteFineTuneModel", func() (any, error) { return client.DeleteFineTuneModel(ctx, "") }}, + {"CreateAssistant", func() (any, error) { + return client.CreateAssistant(ctx, AssistantRequest{}) + }}, + {"RetrieveAssistant", func() (any, error) { + return client.RetrieveAssistant(ctx, "") + }}, + 
{"ModifyAssistant", func() (any, error) {
+			return client.ModifyAssistant(ctx, "", AssistantRequest{})
+		}},
+		{"DeleteAssistant", func() (any, error) {
+			return client.DeleteAssistant(ctx, "")
+		}},
+		{"ListAssistants", func() (any, error) {
+			return client.ListAssistants(ctx, nil, nil, nil, nil)
+		}},
+		{"CreateAssistantFile", func() (any, error) {
+			return client.CreateAssistantFile(ctx, "", AssistantFileRequest{})
+		}},
+		{"ListAssistantFiles", func() (any, error) {
+			return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil)
+		}},
+		{"RetrieveAssistantFile", func() (any, error) {
+			return client.RetrieveAssistantFile(ctx, "", "")
+		}},
+		{"DeleteAssistantFile", func() (any, error) {
+			return nil, client.DeleteAssistantFile(ctx, "", "")
+		}},
 	}
 
 	for _, testCase := range testCases {

From 1ad6b6f53dcd9abfaf56e8adb02b5b599936580c Mon Sep 17 00:00:00 2001
From: Simone Vellei
Date: Tue, 7 Nov 2023 15:25:21 +0100
Subject: [PATCH 075/206] Feat Support tools and tools choice new fields (#526)

* feat: support tools and tools choice new fields

* fix: use value not pointers
---
 chat.go        | 41 +++++++++++++++++++++++++++++++++++++----
 chat_stream.go |  1 +
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/chat.go b/chat.go
index 88db8cf1d..04303184a 100644
--- a/chat.go
+++ b/chat.go
@@ -12,6 +12,7 @@ const (
 	ChatMessageRoleUser      = "user"
 	ChatMessageRoleAssistant = "assistant"
 	ChatMessageRoleFunction  = "function"
+	ChatMessageRoleTool      = "tool"
 )
 
 const chatCompletionsSuffix = "/chat/completions"
@@ -61,6 +62,12 @@ type ChatCompletionMessage struct {
 	Name string `json:"name,omitempty"`
 
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
+
+type ToolCall struct {
+	ID       string       `json:"id"`
+	Function FunctionCall `json:"function"`
 }
 
 type FunctionCall struct {
@@ -97,10 +104,35 @@ type ChatCompletionRequest struct {
 	// LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
 	// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
 	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
-	LogitBias map[string]int `json:"logit_bias,omitempty"`
-	User      string         `json:"user,omitempty"`
-	Functions []FunctionDefinition `json:"functions,omitempty"`
-	FunctionCall any `json:"function_call,omitempty"`
+	LogitBias map[string]int `json:"logit_bias,omitempty"`
+	User      string         `json:"user,omitempty"`
+	// Deprecated: use Tools instead.
+	Functions []FunctionDefinition `json:"functions,omitempty"`
+	// Deprecated: use ToolChoice instead.
+	FunctionCall any `json:"function_call,omitempty"`
+	Tools        []Tool `json:"tools,omitempty"`
+	// This can be either a string or a ToolChoice object.
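+	// For example, `ToolChoice: "auto"` lets the model decide, while a
+	// ToolChoice value forces a specific function (example added for clarity,
+	// not part of the original patch; the function name is a placeholder):
+	//
+	//	ToolChoice: ToolChoice{
+	//		Type:     ToolTypeFunction,
+	//		Function: ToolFunction{Name: "get_current_weather"},
+	//	}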
+ ToolChoiche any `json:"tool_choice,omitempty"` +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type Tool struct { + Type ToolType `json:"type"` + Function FunctionDefinition `json:"function,omitempty"` +} + +type ToolChoiche struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +type ToolFunction struct { + Name string `json:"name"` } type FunctionDefinition struct { @@ -123,6 +155,7 @@ const ( FinishReasonStop FinishReason = "stop" FinishReasonLength FinishReason = "length" FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonToolCalls FinishReason = "tool_calls" FinishReasonContentFilter FinishReason = "content_filter" FinishReasonNull FinishReason = "null" ) diff --git a/chat_stream.go b/chat_stream.go index f1faa3964..57cfa789f 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -9,6 +9,7 @@ type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type ChatCompletionStreamChoice struct { From a20eb08b79e5c34882888a401020b47c145357ff Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 22:30:05 +0100 Subject: [PATCH 076/206] fix: use pointer for ChatCompletionResponseFormat (#544) --- chat.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index 04303184a..609e0c311 100644 --- a/chat.go +++ b/chat.go @@ -89,18 +89,18 @@ type ChatCompletionResponseFormat struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From a0159ad2b00e4f127222814694bec68863395543 Mon Sep 17 00:00:00 2001 From: Mike Cutalo Date: Tue, 7 Nov 2023 23:16:22 -0800 Subject: [PATCH 077/206] Support new fields for /v1/images/generation API (#530) * add support for new image/generation api * fix one lint * add revised_prompt to response * fix lints * add CreateImageQualityStandard --- image.go | 26 ++++++++++++++++++++++++-- image_api_test.go | 9 ++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/image.go b/image.go index 4addcdb1e..4fe8b3a32 100644 --- a/image.go +++ b/image.go @@ -13,6 +13,9 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. + CreateImageSize1792x1024 = "1792x1024" + CreateImageSize1024x1792 = "1024x1792" ) const ( @@ -20,11 +23,29 @@ const ( CreateImageResponseFormatB64JSON = "b64_json" ) +const ( + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" +) + +const ( + CreateImageQualityHD = "hd" + CreateImageQualityStandard = "standard" +) + +const ( + CreateImageStyleVivid = "vivid" + CreateImageStyleNatural = "natural" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` ResponseFormat string `json:"response_format,omitempty"` User string `json:"user,omitempty"` } @@ -39,8 +60,9 @@ type ImageResponse struct { // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` } // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. 
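
For readers tracking the API surface, here is a minimal client-side sketch of the request this patch enables, mirroring the fields exercised in the test below. The token and prompt are placeholders, and ImageResponse carrying a Data slice of ImageResponseDataInner is assumed from the surrounding file:

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder credential

	resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
		Prompt:         "a watercolor lighthouse at dawn", // illustrative prompt
		Model:          openai.CreateImageModelDallE3,
		N:              1, // dall-e-3 generates a single image per request
		Quality:        openai.CreateImageQualityHD,
		Size:           openai.CreateImageSize1792x1024, // one of the new dall-e-3 sizes
		Style:          openai.CreateImageStyleNatural,
		ResponseFormat: openai.CreateImageResponseFormatURL,
	})
	if err != nil {
		log.Fatalf("CreateImage error: %v", err)
	}

	// dall-e-3 may rewrite the prompt; the new RevisedPrompt field surfaces that rewrite.
	fmt.Println(resp.Data[0].URL)
	fmt.Println(resp.Data[0].RevisedPrompt)
}
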
diff --git a/image_api_test.go b/image_api_test.go index 422f831fe..2eb46f2b4 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -19,7 +19,14 @@ func TestImages(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) _, err := client.CreateImage(context.Background(), openai.ImageRequest{ - Prompt: "Lorem ipsum", + Prompt: "Lorem ipsum", + Model: openai.CreateImageModelDallE3, + N: 1, + Quality: openai.CreateImageQualityHD, + Size: openai.CreateImageSize1024x1024, + Style: openai.CreateImageStyleVivid, + ResponseFormat: openai.CreateImageResponseFormatURL, + User: "user", }) checks.NoError(t, err, "CreateImage error") } From a2d2bf685122fd51d768f2a828787cae587d9ad6 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 8 Nov 2023 10:20:20 +0100 Subject: [PATCH 078/206] Fix Refactor assistant api (#545) * fix: refactor assistant API * fix * trigger build * fix: use AssistantDeleteResponse --- assistant.go | 90 ++++++++++++++++++++++++++++++---------------------- client.go | 6 ++++ 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/assistant.go b/assistant.go index d75eebef3..de49be680 100644 --- a/assistant.go +++ b/assistant.go @@ -10,46 +10,43 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" + openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` httpHeader } -type AssistantTool struct { - Type string `json:"type"` -} - -type AssistantToolCodeInterpreter struct { - AssistantTool -} +type AssistantToolType string -type AssistantToolRetrieval struct { - AssistantTool -} +const ( + AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" + AssistantToolTypeRetrieval AssistantToolType = "retrieval" + AssistantToolTypeFunction AssistantToolType = "function" +) -type AssistantToolFunction struct { - AssistantTool - Function FunctionDefinition `json:"function"` +type AssistantTool struct { + Type AssistantToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // AssistantsList is a list of assistants. 
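
The payoff of replacing []any with a typed tool slice shows up at the call site. Here is a sketch of an AssistantRequest built with the refactored types, assuming this patch is applied; the assistant name and instructions are illustrative, while the get_current_weather function and the jsonschema usage mirror the integration test fixed earlier in this series:

package main

import (
	"context"
	"log"

	openai "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
	client := openai.NewClient("your-token") // placeholder credential
	name := "weather-helper"                 // illustrative assistant name
	instructions := "Answer questions about the current weather."

	assistant, err := client.CreateAssistant(context.Background(), openai.AssistantRequest{
		Model:        openai.GPT4TurboPreview,
		Name:         &name,
		Instructions: &instructions,
		Tools: []openai.AssistantTool{
			// A built-in tool needs only its type.
			{Type: openai.AssistantToolTypeCodeInterpreter},
			// A function tool carries its definition in the new pointer field.
			{
				Type: openai.AssistantToolTypeFunction,
				Function: &openai.FunctionDefinition{
					Name: "get_current_weather",
					Parameters: jsonschema.Definition{
						Type: jsonschema.Object,
						Properties: map[string]jsonschema.Definition{
							"location": {Type: jsonschema.String},
						},
						Required: []string{"location"},
					},
				},
			},
		},
	})
	if err != nil {
		log.Fatalf("CreateAssistant error: %v", err)
	}
	log.Printf("created assistant %s", assistant.ID)
}
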
@@ -59,6 +56,14 @@ type AssistantsList struct { httpHeader } +type AssistantDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + type AssistantFile struct { ID string `json:"id"` Object string `json:"object"` @@ -80,7 +85,8 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -95,7 +101,8 @@ func (c *Client) RetrieveAssistant( assistantID string, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -111,7 +118,8 @@ func (c *Client) ModifyAssistant( request AssistantRequest, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -124,9 +132,10 @@ func (c *Client) ModifyAssistant( func (c *Client) DeleteAssistant( ctx context.Context, assistantID string, -) (response Assistant, err error) { +) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -163,7 +172,8 @@ func (c *Client) ListAssistants( } urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -180,7 +190,8 @@ func (c *Client) CreateAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(request)) + withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -196,7 +207,8 @@ func (c *Client) RetrieveAssistantFile( fileID string, ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -212,7 +224,8 @@ func (c *Client) DeleteAssistantFile( fileID string, ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -250,7 +263,8 @@ func (c *Client) ListAssistantFiles( } urlSuffix := fmt.Sprintf("%s/%s%s%s", 
assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } diff --git a/client.go b/client.go index 65ece812f..056226c61 100644 --- a/client.go +++ b/client.go @@ -83,6 +83,12 @@ func withContentType(contentType string) requestOption { } } +func withBetaAssistantV1() requestOption { + return func(args *requestOptions) { + args.header.Set("OpenAI-Beta", "assistants=v1") + } +} + func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { // Default Options args := &requestOptions{ From 08c167fecf6953619d1905ab2959ed341bfb063d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 8 Nov 2023 18:21:51 +0900 Subject: [PATCH 079/206] test: fix compile error in api integration test (#548) --- api_integration_test.go | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 6be188bc6..736040c50 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,6 +9,7 @@ import ( "os" "testing" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -20,7 +21,7 @@ func TestAPI(t *testing.T) { } var err error - c := NewClient(apiToken) + c := openai.NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") @@ -36,23 +37,23 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "GetFile error") } // else skip - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: AdaSearchQuery, + Model: openai.AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -63,11 +64,11 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Name: "John_Doe", Content: "Hello!", }, @@ -76,9 +77,9 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", - Model: GPT3Ada, + Model: openai.GPT3Ada, MaxTokens: 5, Stream: true, }) @@ -103,15 +104,15 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( context.Background(), - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + 
Role: openai.ChatMessageRoleUser, Content: "What is the weather like in Boston?", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "get_current_weather", Parameters: jsonschema.Definition{ Type: jsonschema.Object, @@ -140,12 +141,12 @@ func TestAPIError(t *testing.T) { } var err error - c := NewClient(apiToken + "_invalid") + c := openai.NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) checks.HasError(t, err, "ListEngines should fail with an invalid key") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Fatalf("Error is not an APIError: %+v", err) } From bc89139c1ddcc4f6d5b15b7e8d0491c69dda402c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 09:05:44 +0100 Subject: [PATCH 080/206] Feat Implement threads API (#536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement threads API * fix * add tests * fix * trigger£ * trigger * chore: add beta header --- client_test.go | 12 ++++++ thread.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++ thread_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 thread.go create mode 100644 thread_test.go diff --git a/client_test.go b/client_test.go index bff2597c5..b2f28f90a 100644 --- a/client_test.go +++ b/client_test.go @@ -301,6 +301,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteAssistantFile", func() (any, error) { return nil, client.DeleteAssistantFile(ctx, "", "") }}, + {"CreateThread", func() (any, error) { + return client.CreateThread(ctx, ThreadRequest{}) + }}, + {"RetrieveThread", func() (any, error) { + return client.RetrieveThread(ctx, "") + }}, + {"ModifyThread", func() (any, error) { + return client.ModifyThread(ctx, "", ModifyThreadRequest{}) + }}, + {"DeleteThread", func() (any, error) { + return client.DeleteThread(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/thread.go b/thread.go new file mode 100644 index 000000000..291f3dcab --- /dev/null +++ b/thread.go @@ -0,0 +1,107 @@ +package openai + +import ( + "context" + "net/http" +) + +const ( + threadsSuffix = "/threads" +) + +type Thread struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type ThreadRequest struct { + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ModifyThreadRequest struct { + Metadata map[string]any `json:"metadata"` +} + +type ThreadMessageRole string + +const ( + ThreadMessageRoleUser ThreadMessageRole = "user" +) + +type ThreadMessage struct { + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +// CreateThread creates a new thread. +func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveThread retrieves a thread. 
+func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyThread modifies a thread. +func (c *Client) ModifyThread( + ctx context.Context, + threadID string, + request ModifyThreadRequest, +) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteThread deletes a thread. +func (c *Client) DeleteThread( + ctx context.Context, + threadID string, +) (response ThreadDeleteResponse, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/thread_test.go b/thread_test.go new file mode 100644 index 000000000..227ab6330 --- /dev/null +++ b/thread_test.go @@ -0,0 +1,95 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestThread Tests the thread endpoint of the API using the mocked server. +func TestThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From 
e3e065deb0a190e2d3c3bbf9caf54471b32f675e Mon Sep 17 00:00:00 2001 From: Gabriel Burt Date: Thu, 9 Nov 2023 03:08:43 -0500 Subject: [PATCH 081/206] Add SystemFingerprint and chatMsg.ToolCallID field (#543) * fix ToolChoiche typo * add tool_call_id to ChatCompletionMessage * add /chat system_fingerprint response field * check empty ToolCallID JSON marshaling and add omitempty for tool_call_id * messages also required; don't omitempty * add Type to ToolCall, required by the API * fix test, omitempty for response_format ptr * fix casing of role values in comments --- chat.go | 27 +++++++++++++++++---------- chat_test.go | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index 609e0c311..9ad31c466 100644 --- a/chat.go +++ b/chat.go @@ -62,11 +62,17 @@ type ChatCompletionMessage struct { Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + ToolCallID string `json:"tool_call_id,omitempty"` } type ToolCall struct { ID string `json:"id"` + Type ToolType `json:"type"` Function FunctionCall `json:"function"` } @@ -84,7 +90,7 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -112,7 +118,7 @@ type ChatCompletionRequest struct { FunctionCall any `json:"function_call,omitempty"` Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. - ToolChoiche any `json:"tool_choice,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ToolType string @@ -126,7 +132,7 @@ type Tool struct { Function FunctionDefinition `json:"function,omitempty"` } -type ToolChoiche struct { +type ToolChoice struct { Type ToolType `json:"type"` Function ToolFunction `json:"function,omitempty"` } @@ -182,12 +188,13 @@ type ChatCompletionChoice struct { // ChatCompletionResponse represents a response structure for chat completion API. 
type ChatCompletionResponse struct {
-	ID      string                 `json:"id"`
-	Object  string                 `json:"object"`
-	Created int64                  `json:"created"`
-	Model   string                 `json:"model"`
-	Choices []ChatCompletionChoice `json:"choices"`
-	Usage   Usage                  `json:"usage"`
+	ID                string                 `json:"id"`
+	Object            string                 `json:"object"`
+	Created           int64                  `json:"created"`
+	Model             string                 `json:"model"`
+	Choices           []ChatCompletionChoice `json:"choices"`
+	Usage             Usage                  `json:"usage"`
+	SystemFingerprint string                 `json:"system_fingerprint"`

 	httpHeader
 }
diff --git a/chat_test.go b/chat_test.go
index 5bf1eaf6c..a8155edf2 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -51,6 +51,20 @@ func TestChatCompletionsWrongModel(t *testing.T) {
 	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
 }

+func TestChatRequestOmitEmpty(t *testing.T) {
+	data, err := json.Marshal(openai.ChatCompletionRequest{
+		// We set model b/c it's required, so omitempty doesn't make sense
+		Model: "gpt-4",
+	})
+	checks.NoError(t, err)
+
+	// messages is also required so isn't omitted
+	const expected = `{"model":"gpt-4","messages":null}`
+	if string(data) != expected {
+		t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data))
+	}
+}
+
 func TestChatCompletionsWithStream(t *testing.T) {
 	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"

From 81270725539980d202829528054f3fda346970db Mon Sep 17 00:00:00 2001
From: Urjit Singh Bhatia
Date: Thu, 9 Nov 2023 00:20:39 -0800
Subject: [PATCH 082/206] fix test server setup: (#549)

* fix test server setup:
- go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration goes through /foo/bar first since the regex match wasn't bound to start and end anchors
- registering handlers now converts * in routes to .* for proper regex matching
- test server route handling now tries to fully match the handler route

* add missing /v1 prefix to fine-tuning job cancel test server handler
---
 fine_tuning_job_test.go |  2 +-
 internal/test/server.go | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go
index c892ef775..d2fbcd4c7 100644
--- a/fine_tuning_job_test.go
+++ b/fine_tuning_job_test.go
@@ -42,7 +42,7 @@ func TestFineTuningJob(t *testing.T) {
 	)

 	server.RegisterHandler(
-		"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
+		"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
 		func(w http.ResponseWriter, _ *http.Request) {
 			resBytes, _ := json.Marshal(openai.FineTuningJob{})
 			fmt.Fprintln(w, string(resBytes))
diff --git a/internal/test/server.go b/internal/test/server.go
index 3813ff869..127d4c16f 100644
--- a/internal/test/server.go
+++ b/internal/test/server.go
@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"regexp"
+	"strings"
 )

 const testAPI = "this-is-my-secure-token-do-not-steal!!"
@@ -23,13 +24,16 @@ func NewTestServer() *ServerTest {
 }

 func (ts *ServerTest) RegisterHandler(path string, handler handler) {
+	// to make the registered paths friendlier to a regex match in the route handler
+	// in OpenAITestServer
+	path = strings.ReplaceAll(path, "*", ".*")
 	ts.handlers[path] = handler
 }

 // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
func (ts *ServerTest) OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("received request at path %q\n", r.URL.Path) + log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) // check auth if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { @@ -38,8 +42,10 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { } // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling for route, handler := range ts.handlers { - pattern, _ := regexp.Compile(route) + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") if pattern.MatchString(r.URL.Path) { handler(w, r) return From 78862a2798df46f6ca8bb73350b720f9c8d4a592 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 15:05:03 +0100 Subject: [PATCH 083/206] fix: add missing fields in tool_calls (#558) --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index 9ad31c466..ebdc0e24b 100644 --- a/chat.go +++ b/chat.go @@ -71,6 +71,8 @@ type ChatCompletionMessage struct { } type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` ID string `json:"id"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` From d6f3bdcdac9172ab5248d6be8c3e1761446a434c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 20:17:30 +0100 Subject: [PATCH 084/206] Feat implement Run APIs (#560) * chore: first commit * add apis * chore: add tests * feat add apis * chore: add api and tests * chore: add tests * fix * trigger build * fix * chore: formatting code * chore: add pagination type --- client_test.go | 27 ++++ run.go | 399 +++++++++++++++++++++++++++++++++++++++++++++++++ run_test.go | 237 +++++++++++++++++++++++++++++ 3 files changed, 663 insertions(+) create mode 100644 run.go create mode 100644 run_test.go diff --git a/client_test.go b/client_test.go index b2f28f90a..d5d3e2644 100644 --- a/client_test.go +++ b/client_test.go @@ -313,6 +313,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteThread", func() (any, error) { return client.DeleteThread(ctx, "") }}, + {"CreateRun", func() (any, error) { + return client.CreateRun(ctx, "", RunRequest{}) + }}, + {"RetrieveRun", func() (any, error) { + return client.RetrieveRun(ctx, "", "") + }}, + {"ModifyRun", func() (any, error) { + return client.ModifyRun(ctx, "", "", RunModifyRequest{}) + }}, + {"ListRuns", func() (any, error) { + return client.ListRuns(ctx, "", Pagination{}) + }}, + {"SubmitToolOutputs", func() (any, error) { + return client.SubmitToolOutputs(ctx, "", "", SubmitToolOutputsRequest{}) + }}, + {"CancelRun", func() (any, error) { + return client.CancelRun(ctx, "", "") + }}, + {"CreateThreadAndRun", func() (any, error) { + return client.CreateThreadAndRun(ctx, CreateThreadAndRunRequest{}) + }}, + {"RetrieveRunStep", func() (any, error) { + return client.RetrieveRunStep(ctx, "", "", "") + }}, + {"ListRunSteps", func() (any, error) { + return client.ListRunSteps(ctx, "", "", Pagination{}) + }}, } for _, testCase := range testCases { diff --git a/run.go b/run.go new file mode 100644 index 000000000..5d6ea58db --- /dev/null +++ b/run.go @@ -0,0 +1,399 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type Run struct { + ID 
string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + ThreadID string `json:"thread_id"` + AssistantID string `json:"assistant_id"` + Status RunStatus `json:"status"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools"` + FileIDS []string `json:"file_ids"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStatus string + +const ( + RunStatusQueued RunStatus = "queued" + RunStatusInProgress RunStatus = "in_progress" + RunStatusRequiresAction RunStatus = "requires_action" + RunStatusCancelling RunStatus = "cancelling" + RunStatusFailed RunStatus = "failed" + RunStatusCompleted RunStatus = "completed" + RunStatusExpired RunStatus = "expired" +) + +type RunRequiredAction struct { + Type RequiredActionType `json:"type"` + SubmitToolOutputs *SubmitToolOutputs `json:"submit_tool_outputs,omitempty"` +} + +type RequiredActionType string + +const ( + RequiredActionTypeSubmitToolOutputs RequiredActionType = "submit_tool_outputs" +) + +type SubmitToolOutputs struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +type RunLastError struct { + Code RunError `json:"code"` + Message string `json:"message"` +} + +type RunError string + +const ( + RunErrorServerError RunError = "server_error" + RunErrorRateLimitExceeded RunError = "rate_limit_exceeded" +) + +type RunRequest struct { + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any +} + +type RunModifyRequest struct { + Metadata map[string]any `json:"metadata,omitempty"` +} + +// RunList is a list of runs. 
+type RunList struct { + Runs []Run `json:"data"` + + httpHeader +} + +type SubmitToolOutputsRequest struct { + ToolOutputs []ToolOutput `json:"tool_outputs"` +} + +type ToolOutput struct { + ToolCallID string `json:"tool_call_id"` + Output any `json:"output"` +} + +type CreateThreadAndRunRequest struct { + RunRequest + Thread ThreadRequest `json:"thread"` +} + +type RunStep struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + ThreadID string `json:"thread_id"` + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + Status RunStepStatus `json:"status"` + StepDetails StepDetails `json:"step_details"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiredAt *int64 `json:"expired_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStepStatus string + +const ( + RunStepStatusInProgress RunStepStatus = "in_progress" + RunStepStatusCancelling RunStepStatus = "cancelled" + RunStepStatusFailed RunStepStatus = "failed" + RunStepStatusCompleted RunStepStatus = "completed" + RunStepStatusExpired RunStepStatus = "expired" +) + +type RunStepType string + +const ( + RunStepTypeMessageCreation RunStepType = "message_creation" + RunStepTypeToolCalls RunStepType = "tool_calls" +) + +type StepDetails struct { + Type RunStepType `json:"type"` + MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` + ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` +} + +type StepDetailsMessageCreation struct { + MessageID string `json:"message_id"` +} + +type StepDetailsToolCalls struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +// RunStepList is a list of steps. +type RunStepList struct { + RunSteps []RunStep `json:"data"` + + httpHeader +} + +type Pagination struct { + Limit *int + Order *string + After *string + Before *string +} + +// CreateRun creates a new run. +func (c *Client) CreateRun( + ctx context.Context, + threadID string, + request RunRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRun retrieves a run. +func (c *Client) RetrieveRun( + ctx context.Context, + threadID string, + runID string, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyRun modifies a run. +func (c *Client) ModifyRun( + ctx context.Context, + threadID string, + runID string, + request RunModifyRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRuns lists runs. 
+func (c *Client) ListRuns( + ctx context.Context, + threadID string, + pagination Pagination, +) (response RunList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs%s", threadID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// SubmitToolOutputs submits tool outputs. +func (c *Client) SubmitToolOutputs( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelRun cancels a run. +func (c *Client) CancelRun( + ctx context.Context, + threadID string, + runID string) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/cancel", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CreateThreadAndRun submits tool outputs. +func (c *Client) CreateThreadAndRun( + ctx context.Context, + request CreateThreadAndRunRequest) (response Run, err error) { + urlSuffix := "/threads/runs" + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRunStep retrieves a run step. +func (c *Client) RetrieveRunStep( + ctx context.Context, + threadID string, + runID string, + stepID string, +) (response RunStep, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps/%s", threadID, runID, stepID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRunSteps lists run steps. +func (c *Client) ListRunSteps( + ctx context.Context, + threadID string, + runID string, + pagination Pagination, +) (response RunStepList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" 
+ urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps%s", threadID, runID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/run_test.go b/run_test.go new file mode 100644 index 000000000..cdf99db05 --- /dev/null +++ b/run_test.go @@ -0,0 +1,237 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestRun(t *testing.T) { + assistantID := "asst_abc123" + threadID := "thread_abc123" + runID := "run_abc123" + stepID := "step_abc123" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps/"+stepID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStep{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStepList{ + RunSteps: []openai.RunStep{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/submit_tool_outputs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.RunModifyRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.RunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + 
ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunList{ + Runs: []openai.Run{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.CreateThreadAndRunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateRun(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRun error") + + _, err = client.RetrieveRun(ctx, threadID, runID) + checks.NoError(t, err, "RetrieveRun error") + + _, err = client.ModifyRun(ctx, threadID, runID, openai.RunModifyRequest{ + Metadata: map[string]any{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyRun error") + + _, err = client.ListRuns( + ctx, + threadID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRuns error") + + _, err = client.SubmitToolOutputs(ctx, threadID, runID, + openai.SubmitToolOutputsRequest{}) + checks.NoError(t, err, "SubmitToolOutputs error") + + _, err = client.CancelRun(ctx, threadID, runID) + checks.NoError(t, err, "CancelRun error") + + _, err = client.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndRun error") + + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) + checks.NoError(t, err, "RetrieveRunStep error") + + _, err = client.ListRunSteps( + ctx, + threadID, + runID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRunSteps error") +} From 35495ccd364265f37800a6fa72fed7f05705eb82 Mon Sep 17 00:00:00 2001 From: Kyle Bolton Date: Sun, 12 Nov 2023 06:09:40 -0500 Subject: [PATCH 085/206] Add `json:"metadata,omitempty"` to RunRequest struct (#561) Metadata is an optional field per the api spec https://platform.openai.com/docs/api-reference/runs/createRun --- run.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index 5d6ea58db..7ff730fea 100644 --- a/run.go +++ b/run.go @@ -70,11 +70,11 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 9fefd50e12ad138efa3f38756be5dd2ed5fefadd Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 12 Nov 
2023 20:10:00 +0900
Subject: [PATCH 086/206] Fix typo in chat_test.go (#564)

requetsts -> requests
---
 chat_test.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chat_test.go b/chat_test.go
index a8155edf2..8377809da 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -144,7 +144,7 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
 	}
 	resetRequestsTime := headers.ResetRequests.Time()
 	if resetRequestsTime.Before(time.Now()) {
-		t.Errorf("unexpected reset requetsts: %v", resetRequestsTime)
+		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
 	}

 	bs1, _ := json.Marshal(headers)

From b7cac703acb1a8be0e803c81ad3236be66be969a Mon Sep 17 00:00:00 2001
From: Urjit Singh Bhatia
Date: Mon, 13 Nov 2023 08:33:26 -0600
Subject: [PATCH 087/206] Feat/messages api (#546)

* fix test server setup:
- go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration goes through /foo/bar first since the regex match wasn't bound to start and end anchors
- registering handlers now converts * in routes to .* for proper regex matching
- test server route handling now tries to fully match the handler route

* add missing /v1 prefix to fine-tuning job cancel test server handler

* add create message call

* add messages list call

* add get message call

* add modify message call, fix return types for other message calls

* add message file retrieve call

* add list message files call

* code style fixes

* add test for list messages with pagination options

* add beta header to msg calls now that #545 is merged

* Update messages.go

Co-authored-by: Simone Vellei

* Update messages.go

Co-authored-by: Simone Vellei

* add missing object details for message, fix tests

* fix merge formatting

* minor style fixes

---------

Co-authored-by: Simone Vellei
---
 client_test.go   |  18 ++++
 messages.go      | 178 +++++++++++++++++++++++++++++++++++
 messages_test.go | 235 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 431 insertions(+)
 create mode 100644 messages.go
 create mode 100644 messages_test.go

diff --git a/client_test.go b/client_test.go
index d5d3e2644..24cb5ffa7 100644
--- a/client_test.go
+++ b/client_test.go
@@ -301,6 +301,24 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
 		{"DeleteAssistantFile", func() (any, error) {
 			return nil, client.DeleteAssistantFile(ctx, "", "")
 		}},
+		{"CreateMessage", func() (any, error) {
+			return client.CreateMessage(ctx, "", MessageRequest{})
+		}},
+		{"ListMessage", func() (any, error) {
+			return client.ListMessage(ctx, "", nil, nil, nil, nil)
+		}},
+		{"RetrieveMessage", func() (any, error) {
+			return client.RetrieveMessage(ctx, "", "")
+		}},
+		{"ModifyMessage", func() (any, error) {
+			return client.ModifyMessage(ctx, "", "", nil)
+		}},
+		{"RetrieveMessageFile", func() (any, error) {
+			return client.RetrieveMessageFile(ctx, "", "", "")
+		}},
+		{"ListMessageFiles", func() (any, error) {
+			return client.ListMessageFiles(ctx, "", "")
+		}},
 		{"CreateThread", func() (any, error) {
 			return client.CreateThread(ctx, ThreadRequest{})
 		}},
diff --git a/messages.go b/messages.go
new file mode 100644
index 000000000..4e691a8ba
--- /dev/null
+++ b/messages.go
@@ -0,0 +1,178 @@
+package openai
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+)
+
+const (
+	messagesSuffix = "messages"
+)
+
+type Message struct {
+	ID        string `json:"id"`
+	Object    string `json:"object"`
+	CreatedAt int    `json:"created_at"`
+	ThreadID  string `json:"thread_id"`
+	Role      string `json:"role"`
+	Content 
[]MessageContent `json:"content"` + FileIds []string `json:"file_ids"` + AssistantID *string `json:"assistant_id,omitempty"` + RunID *string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type MessagesList struct { + Messages []Message `json:"data"` + + httpHeader +} + +type MessageContent struct { + Type string `json:"type"` + Text *MessageText `json:"text,omitempty"` + ImageFile *ImageFile `json:"image_file,omitempty"` +} +type MessageText struct { + Value string `json:"value"` + Annotations []any `json:"annotations"` +} + +type ImageFile struct { + FileID string `json:"file_id"` +} + +type MessageRequest struct { + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type MessageFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + MessageID string `json:"message_id"` + + httpHeader +} + +type MessageFilesList struct { + MessageFiles []MessageFile `json:"data"` + + httpHeader +} + +// CreateMessage creates a new message. +func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ListMessage fetches all messages in the thread. +func (c *Client) ListMessage(ctx context.Context, threadID string, + limit *int, + order *string, + after *string, + before *string, +) (messages MessagesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &messages) + return +} + +// RetrieveMessage retrieves a Message. +func (c *Client) RetrieveMessage( + ctx context.Context, + threadID, messageID string, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ModifyMessage modifies a message. +func (c *Client) ModifyMessage( + ctx context.Context, + threadID, messageID string, + metadata map[string]any, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(metadata), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// RetrieveMessageFile fetches a message file. 
+func (c *Client) RetrieveMessageFile( + ctx context.Context, + threadID, messageID, fileID string, +) (file MessageFile, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + +// ListMessageFiles fetches all files attached to a message. +func (c *Client) ListMessageFiles( + ctx context.Context, + threadID, messageID string, +) (files MessageFilesList, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &files) + return +} diff --git a/messages_test.go b/messages_test.go new file mode 100644 index 000000000..282b1cc9d --- /dev/null +++ b/messages_test.go @@ -0,0 +1,235 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +var emptyStr = "" + +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFile{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 1699061776, + MessageID: messageID, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFilesList{MessageFiles: []openai.MessageFile{{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 0, + MessageID: messageID, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + metadata := map[string]any{} + err := json.NewDecoder(r.Body).Decode(&metadata) + checks.NoError(t, err, "unable to decode metadata in modify message call") + + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: metadata, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + 
FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + resBytes, _ := json.Marshal(openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal(openai.MessagesList{ + Messages: []openai.Message{{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + ctx := context.Background() + + // static assertion of return type + var msg openai.Message + msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{ + Role: "user", + Content: "How does AI work?", + FileIds: nil, + Metadata: nil, + }) + checks.NoError(t, err, "CreateMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + var msgs openai.MessagesList + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + // with pagination options set + limit := 1 + order := "desc" + after := "obj_foo" + before := "obj_bar" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + msg, err = client.RetrieveMessage(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + msg, err = client.ModifyMessage(ctx, threadID, messageID, + map[string]any{ + "foo": "bar", + }) + checks.NoError(t, err, "ModifyMessage error") + if msg.Metadata["foo"] != "bar" { + t.Fatalf("expected message metadata to get modified") + } + + // message files + var msgFile openai.MessageFile + msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) + checks.NoError(t, err, "RetrieveMessageFile error") + if msgFile.ID != fileID { + t.Fatalf("unexpected message file id: '%s'", msgFile.ID) + } + + var msgFiles openai.MessageFilesList + msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessageFile error") + if len(msgFiles.MessageFiles) != 1 { + t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles)) + } + if msgFiles.MessageFiles[0].ID != fileID { + t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID) + } +} From 515de0219d3b4d30351d44d8a0f508599de6c053 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Mon, 13 Nov 2023 
09:35:34 -0500 Subject: [PATCH 088/206] feat: initial TTS support (#528) * feat: initial TTS support * chore: lint, omitempty * chore: dont use pointer in struct * fix: add mocked server tests to speech_test.go Co-authored-by: Lachlan Laycock * chore: update imports * chore: fix lint * chore: add an error check * chore: ignore lint * chore: add error checks in package * chore: add test * chore: fix test --------- Co-authored-by: Lachlan Laycock --- client_test.go | 3 ++ speech.go | 87 +++++++++++++++++++++++++++++++++++++ speech_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 speech.go create mode 100644 speech_test.go diff --git a/client_test.go b/client_test.go index 24cb5ffa7..1c9084585 100644 --- a/client_test.go +++ b/client_test.go @@ -358,6 +358,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListRunSteps", func() (any, error) { return client.ListRunSteps(ctx, "", "", Pagination{}) }}, + {"CreateSpeech", func() (any, error) { + return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) + }}, } for _, testCase := range testCases { diff --git a/speech.go b/speech.go new file mode 100644 index 000000000..a3d5f5dca --- /dev/null +++ b/speech.go @@ -0,0 +1,87 @@ +package openai + +import ( + "context" + "errors" + "io" + "net/http" +) + +type SpeechModel string + +const ( + TTSModel1 SpeechModel = "tts-1" + TTsModel1HD SpeechModel = "tts-1-hd" +) + +type SpeechVoice string + +const ( + VoiceAlloy SpeechVoice = "alloy" + VoiceEcho SpeechVoice = "echo" + VoiceFable SpeechVoice = "fable" + VoiceOnyx SpeechVoice = "onyx" + VoiceNova SpeechVoice = "nova" + VoiceShimmer SpeechVoice = "shimmer" +) + +type SpeechResponseFormat string + +const ( + SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" + SpeechResponseFormatOpus SpeechResponseFormat = "opus" + SpeechResponseFormatAac SpeechResponseFormat = "aac" + SpeechResponseFormatFlac SpeechResponseFormat = "flac" +) + +var ( + ErrInvalidSpeechModel = errors.New("invalid speech model") + ErrInvalidVoice = errors.New("invalid voice") +) + +type CreateSpeechRequest struct { + Model SpeechModel `json:"model"` + Input string `json:"input"` + Voice SpeechVoice `json:"voice"` + ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 + Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 +} + +func contains[T comparable](s []T, e T) bool { + for _, v := range s { + if v == e { + return true + } + } + return false +} + +func isValidSpeechModel(model SpeechModel) bool { + return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) +} + +func isValidVoice(voice SpeechVoice) bool { + return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) +} + +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { + if !isValidSpeechModel(request.Model) { + err = ErrInvalidSpeechModel + return + } + if !isValidVoice(request.Voice) { + err = ErrInvalidVoice + return + } + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + withBody(request), + withContentType("application/json; charset=utf-8"), + ) + if err != nil { + return + } + + response, err = c.sendRequestRaw(req) + + return +} diff --git a/speech_test.go b/speech_test.go new file mode 100644 index 000000000..d9ba58b13 --- /dev/null +++ b/speech_test.go @@ -0,0 +1,115 @@ +package 
openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestSpeechIntegration(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + dir, cleanup := test.CreateTestDirectory(t) + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + defer cleanup() + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if mediaType != "application/json" { + http.Error(w, "request is not json", http.StatusBadRequest) + return + } + + // Parse the JSON body of the request + var params map[string]interface{} + err = json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + http.Error(w, "failed to parse request body", http.StatusBadRequest) + return + } + + // Check if each required field is present in the parsed JSON object + reqParams := []string{"model", "input", "voice"} + for _, param := range reqParams { + _, ok := params[param] + if !ok { + http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) + return + } + } + + // read audio file content + audioFile, err := os.ReadFile(path) + if err != nil { + http.Error(w, "failed to read audio file", http.StatusInternalServerError) + return + } + + // write audio file content to response + w.Header().Set("Content-Type", "audio/mpeg") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Connection", "keep-alive") + _, err = w.Write(audioFile) + if err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } + }) + + t.Run("happy path", func(t *testing.T) { + res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.NoError(t, err, "CreateSpeech error") + defer res.Close() + + buf, err := io.ReadAll(res) + checks.NoError(t, err, "ReadAll error") + + // save buf to file as mp3 + err = os.WriteFile("test.mp3", buf, 0644) + checks.NoError(t, err, "Create error") + }) + t.Run("invalid model", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: "invalid_model", + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") + }) + + t.Run("invalid voice", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: "invalid_voice", + }) + checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") + }) +} From fe67abb97ed472bad359cc606c2d63289277cabf Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:06:57 -0700 Subject: [PATCH 089/206] fix: add beta assistant header to CreateMessage call (#566) --- messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 4e691a8ba..3fd377fcb 100644 --- a/messages.go +++ b/messages.go @@ -71,7 +71,7 @@ type 
MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) if err != nil { return } From 71848ccf6928157d1487c5bbd5029ceaf3af53ed Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:08:48 -0700 Subject: [PATCH 090/206] feat: support direct bytes for file upload (#568) * feat: support direct bytes for file upload * add test for errors * add coverage --- client_test.go | 3 +++ files.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ files_api_test.go | 13 ++++++++++++ files_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/client_test.go b/client_test.go index 1c9084585..664f9fb92 100644 --- a/client_test.go +++ b/client_test.go @@ -247,6 +247,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateImage", func() (any, error) { return client.CreateImage(ctx, ImageRequest{}) }}, + {"CreateFileBytes", func() (any, error) { + return client.CreateFileBytes(ctx, FileBytesRequest{}) + }}, {"DeleteFile", func() (any, error) { return nil, client.DeleteFile(ctx, "") }}, diff --git a/files.go b/files.go index 9e521fbbe..371d06c69 100644 --- a/files.go +++ b/files.go @@ -15,6 +15,24 @@ type FileRequest struct { Purpose string `json:"purpose"` } +// PurposeType represents the purpose of the file when uploading. +type PurposeType string + +const ( + PurposeFineTune PurposeType = "fine-tune" + PurposeAssistants PurposeType = "assistants" +) + +// FileBytesRequest represents a file upload request. +type FileBytesRequest struct { + // the name of the uploaded file in OpenAI + Name string + // the bytes of the file + Bytes []byte + // the purpose of the file + Purpose PurposeType +} + // File struct represents an OpenAPI file. type File struct { Bytes int `json:"bytes"` @@ -36,6 +54,37 @@ type FilesList struct { httpHeader } +// CreateFileBytes uploads bytes directly to OpenAI without requiring a local file. +func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { + var b bytes.Buffer + reader := bytes.NewReader(request.Bytes) + builder := c.createFormBuilder(&b) + + err = builder.WriteField("purpose", string(request.Purpose)) + if err != nil { + return + } + + err = builder.CreateFormFileReader("file", reader, request.Name) + if err != nil { + return + } + + err = builder.Close() + if err != nil { + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + // CreateFile uploads a jsonl file to GPT3 // FilePath must be a local file path. 
func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { diff --git a/files_api_test.go b/files_api_test.go index 330b88159..6f62a3fbc 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -16,6 +16,19 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: openai.PurposeFineTune, + } + _, err := client.CreateFileBytes(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/files_test.go b/files_test.go index f588b30dc..3c1b99fb4 100644 --- a/files_test.go +++ b/files_test.go @@ -11,6 +11,53 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return mockBuilder + } + + ctx := context.Background() + req := FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: PurposeAssistants, + } + + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return nil + } + mockBuilder.mockClose = func() error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") +} + func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -55,6 +102,9 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { return mockError } _, err = client.CreateFile(ctx, req) + if err == nil { + t.Fatal("CreateFile should return error if form builder fails") + } checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") } From 464b85b6d766a53c922a15dd1138570e31ec661b Mon Sep 17 00:00:00 2001 From: Liron Levin Date: Wed, 15 Nov 2023 18:22:39 +0200 Subject: [PATCH 091/206] Pagination fields are missing from assistants list beta API (#571) curl "https://api.openai.com/v1/assistants?order=desc&limit=20" \ -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -H "OpenAI-Beta: assistants=v1" { "object": "list", "data": [], "first_id": null, "last_id": null, "has_more": false } --- assistant.go | 4 +++- assistant_test.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index 
de49be680..59f78284f 100644 --- a/assistant.go +++ b/assistant.go @@ -52,7 +52,9 @@ type AssistantRequest struct { // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` - + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` httpHeader } diff --git a/assistant_test.go b/assistant_test.go index eb6f42458..30daec2b1 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -142,6 +142,8 @@ When asked a question, write and run Python code to answer the question.` fmt.Fprintln(w, string(resBytes)) } else if r.Method == http.MethodGet { resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, Assistants: []openai.Assistant{ { ID: assistantID, From 3220f19ee209de5e4bbc6db44261adcd4bbf1df1 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 16 Nov 2023 00:23:41 +0800 Subject: [PATCH 092/206] feat(runapi): add RunStepList response args https://platform.openai.com/docs/api-reference/runs/listRunSteps (#573) --- run.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/run.go b/run.go index 7ff730fea..f95bf0e35 100644 --- a/run.go +++ b/run.go @@ -157,6 +157,10 @@ type StepDetailsToolCalls struct { type RunStepList struct { RunSteps []RunStep `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } From 18465723f7d96587045ce0a450d6874128b870cd Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Wed, 15 Nov 2023 16:25:18 +0000 Subject: [PATCH 093/206] Add missing struct properties. (#579) --- assistant.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/assistant.go b/assistant.go index 59f78284f..bd335833a 100644 --- a/assistant.go +++ b/assistant.go @@ -22,6 +22,8 @@ type Assistant struct { Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` httpHeader } From 4fd904c2927c421cdbff89249979bc6a8a371d11 Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Sat, 18 Nov 2023 06:55:58 +0000 Subject: [PATCH 094/206] Add File purposes as constants (#577) * Add purposes. * Formatting. --- files.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/files.go b/files.go index 371d06c69..a37d45f18 100644 --- a/files.go +++ b/files.go @@ -19,8 +19,10 @@ type FileRequest struct { type PurposeType string const ( - PurposeFineTune PurposeType = "fine-tune" - PurposeAssistants PurposeType = "assistants" + PurposeFineTune PurposeType = "fine-tune" + PurposeFineTuneResults PurposeType = "fine-tune-results" + PurposeAssistants PurposeType = "assistants" + PurposeAssistantsOutput PurposeType = "assistants_output" ) // FileBytesRequest represents a file upload request. 
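
The byte-upload API from PATCH 090 and the purpose constants added above combine as follows. This is a minimal, hypothetical usage sketch, not part of the patch series: it assumes a valid key in the OPENAI_API_KEY environment variable, and the file name and contents are placeholders.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// Assumption: OPENAI_API_KEY holds a valid key; this is illustrative only.
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	// CreateFileBytes (PATCH 090) uploads in-memory bytes without needing a
	// local file; the typed PurposeAssistants constant replaces a raw string.
	file, err := client.CreateFileBytes(context.Background(), openai.FileBytesRequest{
		Name:    "notes.txt",
		Bytes:   []byte("assistant reference notes"),
		Purpose: openai.PurposeAssistants,
	})
	if err != nil {
		fmt.Printf("upload error: %v\n", err)
		return
	}
	fmt.Printf("uploaded file id: %s\n", file.ID)
}
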
From 9efad284d02d90b2de3eeefc67a966743e47a2ac Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:59:01 -0800 Subject: [PATCH 095/206] Updates the tool call struct (#595) --- run.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/run.go b/run.go index f95bf0e35..dbb708a13 100644 --- a/run.go +++ b/run.go @@ -142,17 +142,13 @@ const ( type StepDetails struct { Type RunStepType `json:"type"` MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` - ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type StepDetailsMessageCreation struct { MessageID string `json:"message_id"` } -type StepDetailsToolCalls struct { - ToolCalls []ToolCall `json:"tool_calls"` -} - // RunStepList is a list of steps. type RunStepList struct { RunSteps []RunStep `json:"data"` From a130cfee26427b99ae0bf957be74e32ca8a7f567 Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 23:01:06 -0800 Subject: [PATCH 096/206] Add missing response fields for pagination (#584) --- messages.go | 5 +++++ messages_test.go | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 3fd377fcb..ead247f5b 100644 --- a/messages.go +++ b/messages.go @@ -29,6 +29,11 @@ type Message struct { type MessagesList struct { Messages []Message `json:"data"` + Object string `json:"object"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } diff --git a/messages_test.go b/messages_test.go index 282b1cc9d..9168d6ccf 100644 --- a/messages_test.go +++ b/messages_test.go @@ -142,6 +142,7 @@ func TestMessages(t *testing.T) { fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal(openai.MessagesList{ + Object: "list", Messages: []openai.Message{{ ID: messageID, Object: "thread.message", @@ -159,7 +160,11 @@ func TestMessages(t *testing.T) { AssistantID: &emptyStr, RunID: &emptyStr, Metadata: nil, - }}}) + }}, + FirstID: &messageID, + LastID: &messageID, + HasMore: false, + }) fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) From f87909596f8b0d293142ca00c4d4adc872c52ded Mon Sep 17 00:00:00 2001 From: pjuhasz Date: Fri, 24 Nov 2023 07:34:25 +0000 Subject: [PATCH 097/206] Add canary-tts to speech models (#603) Co-authored-by: Peter Juhasz --- speech.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index a3d5f5dca..f2442b921 100644 --- a/speech.go +++ b/speech.go @@ -10,8 +10,9 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTsModel1HD SpeechModel = "tts-1-hd" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" ) type SpeechVoice string @@ -57,7 +58,7 @@ func contains[T comparable](s []T, e T) bool { } func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) + return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) } func isValidVoice(voice SpeechVoice) bool { From 726099132704fd5ebc1680166f45bbd280bdb546 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 24 Nov 2023 13:36:10 +0400 Subject: [PATCH 098/206] Update PULL_REQUEST_TEMPLATE.md 
(#606)
---
 .github/PULL_REQUEST_TEMPLATE.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 44bf697ed..222c065ce 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,11 +8,14 @@ Thanks for submitting a pull request! Please provide enough information so that
 **Describe the change**
 Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds.
 
+**Provide OpenAI documentation link**
+Provide a relevant API doc from https://platform.openai.com/docs/api-reference
+
 **Describe your solution**
 Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using.
 
 **Tests**
-Briefly describe how you have tested these changes.
+Briefly describe how you have tested these changes. If possible — please add integration tests.
 
 **Additional context**
 Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it.

From 03caea89b75c4e6a5ac32f6e60e69e309d852e8b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rados=C5=82aw=20Kintzi?=
Date: Fri, 24 Nov 2023 13:17:00 +0000
Subject: [PATCH 099/206] Add support for multi part chat messages (and
 gpt-4-vision-preview model) (#580)

* Add support for multi part chat messages

OpenAI has recently introduced a new model called gpt-4-vision-preview,
which now supports images as input. The chat completion endpoint accepts
multi-part chat messages, where the content can be an array of structs
in addition to the usual string format. This commit introduces new
structures and constants to represent different types of content parts.
It also implements the json.Marshaler and json.Unmarshaler interfaces on
ChatCompletionMessage.
* Add ImageURLDetail and ChatMessagePartType types * Optimize ChatCompletionMessage deserialization * Add ErrContentFieldsMisused error --- chat.go | 91 ++++++++++++++++++++++++++++++++++++++++++++- chat_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index ebdc0e24b..5b87b6bd7 100644 --- a/chat.go +++ b/chat.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "errors" "net/http" ) @@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions" var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") ) type Hate struct { @@ -51,9 +53,36 @@ type PromptAnnotation struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type ImageURLDetail string + +const ( + ImageURLDetailHigh ImageURLDetail = "high" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailAuto ImageURLDetail = "auto" +) + +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` + Detail ImageURLDetail `json:"detail,omitempty"` +} + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in // the documentation for the official library for python: @@ -70,6 +99,64 @@ type ChatCompletionMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if 
err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/chat_test.go b/chat_test.go index 8377809da..520bf5ca4 100644 --- a/chat_test.go +++ b/chat_test.go @@ -3,6 +3,7 @@ package openai_test import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateAzureChatCompletion error") } +func TestMultipartChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello!", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "URL", + Detail: openai.ImageURLDetailLow, + }, + }, + }, + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatMessageSerialization(t *testing.T) { + jsonText := `[{"role":"system","content":"system-message"},` + + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + + var msgs []openai.ChatCompletionMessage + err := json.Unmarshal([]byte(jsonText), &msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(msgs) != 2 { + t.Errorf("unexpected number of messages") + } + if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { + t.Errorf("invalid user message: %v", msgs[0]) + } + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + t.Errorf("invalid user message") + } + parts := msgs[1].MultiContent + if parts[0].Type != "text" || parts[0].Text != "nice-text" { + t.Errorf("invalid text part: %v", parts[0]) + } + if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { + t.Errorf("invalid image_url part") + } + + s, err := json.Marshal(msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } + + invalidMsg := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "some-text", + MultiContent: []openai.ChatMessagePart{ + { + Type: "text", + Text: "nice-text", + }, + }, + }, + } + _, err = json.Marshal(invalidMsg) + if !errors.Is(err, openai.ErrContentFieldsMisused) { + t.Fatalf("Expected error: %s", err) + } + + err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs) + if err == nil { + t.Fatalf("Expected error") + } + + emptyMultiContentMsg := 
openai.ChatCompletionMessage{
+		Role:         "user",
+		MultiContent: []openai.ChatMessagePart{},
+	}
+	s, err = json.Marshal(emptyMultiContentMsg)
+	if err != nil {
+		t.Fatalf("Unexpected error")
+	}
+	res = strings.ReplaceAll(string(s), " ", "")
+	if res != `{"role":"user","content":""}` {
+		t.Fatalf("invalid message: %s", string(s))
+	}
+}
+
 // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
 func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	var err error

From a09cb0c528c110a6955a9ee9a5d021a57ed44b90 Mon Sep 17 00:00:00 2001
From: mikeb26 <83850730+mikeb26@users.noreply.github.com>
Date: Sun, 26 Nov 2023 08:45:28 +0000
Subject: [PATCH 100/206] Add completion-with-tool example (#598)

As a user of this Go SDK it was not immediately intuitive to me how to
correctly utilize the function calling capability of GPT4
(https://platform.openai.com/docs/guides/function-calling). While the
aforementioned link provides a helpful example written in Python, I
initially tripped over how to correctly translate the specification of
function arguments when using this Go SDK. To make it easier for others
in the future this commit adds a completion-with-tool example showing
how to correctly utilize the function calling capability of GPT4 using
this SDK end-to-end in a CreateChatCompletion() sequence.
---
 examples/completion-with-tool/main.go | 94 +++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 examples/completion-with-tool/main.go

diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go
new file mode 100644
index 000000000..2c7fedc5e
--- /dev/null
+++ b/examples/completion-with-tool/main.go
@@ -0,0 +1,94 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/jsonschema"
+)
+
+func main() {
+	ctx := context.Background()
+	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
+
+	// describe the function & its inputs
+	params := jsonschema.Definition{
+		Type: jsonschema.Object,
+		Properties: map[string]jsonschema.Definition{
+			"location": {
+				Type:        jsonschema.String,
+				Description: "The city and state, e.g.
San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + } + f := openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather in a given location", + Parameters: params, + } + t := openai.Tool{ + Type: openai.ToolTypeFunction, + Function: f, + } + + // simulate user asking a question that requires the function + dialogue := []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, + } + fmt.Printf("Asking OpenAI '%v' and providing it a '%v()' function...\n", + dialogue[0].Content, f.Name) + resp, err := client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("Completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) + return + } + + // simulate calling the function & responding to OpenAI + dialogue = append(dialogue, msg) + fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", + msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) + dialogue = append(dialogue, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: "Sunny and 80 degrees.", + Name: msg.ToolCalls[0].Function.Name, + ToolCallID: msg.ToolCalls[0].ID, + }) + fmt.Printf("Sending OpenAI our '%v()' function's response and requesting the reply to the original question...\n", + f.Name) + resp, err = client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + // display OpenAI's response to the original question utilizing our function + msg = resp.Choices[0].Message + fmt.Printf("OpenAI answered the original request with: %v\n", + msg.Content) +} From c9615e0cbe3b68088ee04221acdfde63d6d20766 Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Wed, 3 Jan 2024 19:42:57 +0800 Subject: [PATCH 101/206] Added support for createImage Azure models (#608) --- image.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image.go b/image.go index 4fe8b3a32..afd4e196b 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. 
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } From f10955ce090c7b0d8f38458c753c01cd9b88aca5 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:50:56 +0000 Subject: [PATCH 102/206] Log probabilities for chat completion output tokens (#625) * Add logprobs * Logprobs pointer * Move toplogporbs * Create toplogprobs struct * Remove pointers --- chat.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 5b87b6bd7..33b8755ce 100644 --- a/chat.go +++ b/chat.go @@ -200,7 +200,15 @@ type ChatCompletionRequest struct { // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + // LogProbs indicates whether to return log probabilities of the output tokens or not. + // If true, returns the log probabilities of each output token returned in the content of message. + // This option is currently not available on the gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each + // token position, each with an associated log probability. + // logprobs must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + User string `json:"user,omitempty"` // Deprecated: use Tools instead. Functions []FunctionDefinition `json:"functions,omitempty"` // Deprecated: use ToolChoice instead. @@ -244,6 +252,28 @@ type FunctionDefinition struct { // Deprecated: use FunctionDefinition instead. type FunctionDefine = FunctionDefinition +type TopLogProbs struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` +} + +// LogProb represents the probability information for a token. +type LogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. + // In rare cases, there may be fewer than the number of requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` +} + +// LogProbs is the top-level structure containing the log probability information. +type LogProbs struct { + // Content is a list of message content tokens with log probability information. + Content []LogProb `json:"content"` +} + type FinishReason string const ( @@ -273,6 +303,7 @@ type ChatCompletionChoice struct { // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. 
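
The log-probability fields added in PATCH 102 fit together as follows. A minimal sketch, not part of the patch series, assuming a valid OPENAI_API_KEY and a placeholder prompt; TopLogProbs only takes effect when LogProbs is true.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	// Request per-token log probabilities plus the two most likely
	// alternative tokens at each position.
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		LogProbs:    true,
		TopLogProbs: 2,
	})
	if err != nil {
		fmt.Printf("completion error: %v\n", err)
		return
	}

	// Each choice carries an optional *LogProbs value with one entry per
	// generated token.
	if lp := resp.Choices[0].LogProbs; lp != nil {
		for _, tok := range lp.Content {
			fmt.Printf("token %q logprob %.4f\n", tok.Token, tok.LogProb)
		}
	}
}
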
From 682b7adb0bd645f290031fbca6028feb5c22ab9c Mon Sep 17 00:00:00 2001 From: Alexander Kledal Date: Thu, 11 Jan 2024 11:45:15 +0100 Subject: [PATCH 103/206] Update README.md (#631) Ensure variables in examples are valid --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4cb77db6b..9a479c0a0 100644 --- a/README.md +++ b/README.md @@ -453,7 +453,7 @@ func main() { config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function // config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo": "your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] @@ -559,7 +559,7 @@ func main() { //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function //config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] From e01a2d7231fafec2c1cbdd176806e3be767df965 Mon Sep 17 00:00:00 2001 From: Matthew Jaffee Date: Mon, 15 Jan 2024 03:33:02 -0600 Subject: [PATCH 104/206] convert EmbeddingModel to string type (#629) This gives the user the ability to pass in models for embeddings that are not already defined in the library. Also more closely matches how the completions API works. --- embeddings.go | 120 ++++++++------------------------------------- embeddings_test.go | 22 ++------- 2 files changed, 24 insertions(+), 118 deletions(-) diff --git a/embeddings.go b/embeddings.go index 7e2aa7eb0..f79df9df5 100644 --- a/embeddings.go +++ b/embeddings.go @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} +type EmbeddingModel string const ( - Unknown EmbeddingModel = iota - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. 
- BabbageSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchText - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchText - AdaEmbeddingV2 + // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + AdaSimilarity EmbeddingModel = "text-similarity-ada-001" + BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" + CurieSimilarity EmbeddingModel = "text-similarity-curie-001" + DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001" + AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001" + AdaSearchQuery EmbeddingModel = "text-search-ada-query-001" + BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001" + BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001" + CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001" + CurieSearchQuery EmbeddingModel = "text-search-curie-query-001" + DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001" + DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001" + AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001" + AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001" + BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" + BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" + + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" ) -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": 
AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index af04d96bf..846d1995d 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) { // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqStrings) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqTokens) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } } } -func TestEmbeddingModel(t *testing.T) { - var em openai.EmbeddingModel - err := em.UnmarshalText([]byte("text-similarity-ada-001")) - checks.NoError(t, err, "Could not marshal embedding model") - - if em != openai.AdaSimilarity { - t.Errorf("Model is not equal to AdaSimilarity") - } - - err = em.UnmarshalText([]byte("some-non-existent-model")) - checks.NoError(t, err, "Could not marshal embedding model") - if em != openai.Unknown { - t.Errorf("Model is not equal to Unknown") - } -} - func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() From 09f6920ad04666f65dd86ed542e5ebf8bffc93a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E5=AE=8F=E6=95=8F?= Date: Mon, 15 Jan 2024 20:01:49 +0800 Subject: [PATCH 105/206] fixed #594 (#609) APITypeAzure dall-e3 model url Co-authored-by: HanHongmin --- image.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/image.go b/image.go index afd4e196b..665de1a74 
100644 --- a/image.go +++ b/image.go @@ -82,6 +82,7 @@ type ImageEditRequest struct { Image *os.File `json:"image,omitempty"` Mask *os.File `json:"mask,omitempty"` Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -131,7 +132,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return @@ -144,6 +145,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { Image *os.File `json:"image,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -181,7 +183,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return From 4ce03a919ae9fdcb62e8098a03500ef77eafe348 Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Tue, 16 Jan 2024 04:32:48 -0500 Subject: [PATCH 106/206] Fix Azure embeddings model detection by passing string to `fullURL` (#637) --- embeddings.go | 2 +- embeddings_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/embeddings.go b/embeddings.go index f79df9df5..c144119f8 100644 --- a/embeddings.go +++ b/embeddings.go @@ -228,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index 846d1995d..ed6384f3f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -158,6 +158,32 @@ func TestEmbeddingEndpoint(t *testing.T) { checks.HasError(t, err, "CreateEmbeddings error") } +func TestAzureEmbeddingEndpoint(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + + sampleEmbeddings := []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + server.RegisterHandler( + "/openai/deployments/text-embedding-ada-002/embeddings", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + // test create embeddings with strings (simple embedding request) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + }) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } +} + 
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string From eff8dc1118ea82a1b50ee316608e24d83df74d6b Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Thu, 18 Jan 2024 01:42:07 +0800 Subject: [PATCH 107/206] fix(audio): fix audioTextResponse decode (#638) * fix(audio): fix audioTextResponse decode * test(audio): add audioTextResponse decode test * test(audio): simplify code --- client.go | 10 +++++++--- client_test.go | 48 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 056226c61..8bbbb875a 100644 --- a/client.go +++ b/client.go @@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error { return nil } - if result, ok := v.(*string); ok { - return decodeString(body, result) + switch o := v.(type) { + case *string: + return decodeString(body, o) + case *audioTextResponse: + return decodeString(body, &o.Text) + default: + return json.NewDecoder(body).Decode(v) } - return json.NewDecoder(body).Decode(v) } func decodeString(body io.Reader, output *string) error { diff --git a/client_test.go b/client_test.go index 664f9fb92..bc5133edc 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestRequestBuilderFailed = errors.New("test request builder failed") @@ -43,23 +45,29 @@ func TestDecodeResponse(t *testing.T) { testCases := []struct { name string value interface{} + expected interface{} body io.Reader hasError bool }{ { - name: "nil input", - value: nil, - body: bytes.NewReader([]byte("")), + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + expected: nil, }, { - name: "string input", - value: &stringInput, - body: bytes.NewReader([]byte("test")), + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + expected: "test", }, { name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), + expected: map[string]interface{}{ + "test": "test", + }, }, { name: "reader return error", @@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) { body: &errorReader{err: errors.New("dummy")}, hasError: true, }, + { + name: "audio text input", + value: &audioTextResponse{}, + body: bytes.NewReader([]byte("test")), + expected: audioTextResponse{ + Text: "test", + }, + }, + } + + assertEqual := func(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + v := reflect.ValueOf(actual).Elem().Interface() + if !reflect.DeepEqual(v, expected) { + t.Fatalf("Unexpected value: %v, expected: %v", v, expected) + } } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) + if tc.hasError { + checks.HasError(t, err, "Unexpected nil error") + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) } + assertEqual(t, tc.expected, tc.value) }) } } From 4c41f24a99ad56f707df7c25b8833fb0a374c8c5 Mon Sep 17 00:00:00 2001 From: Daniil <7709243+bazuker@users.noreply.github.com> Date: Fri, 26 Jan 2024 00:41:48 -0800 Subject: [PATCH 108/206] Support January 25, 2024, models update. 
(#644) --- completion.go | 6 +++++- embeddings.go | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 2709c8b03..6326a72a8 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,9 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" - GPT4TurboPreview = "gpt-4-1106-preview" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" @@ -78,6 +80,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4TurboPreview: true, GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, GPT40314: true, GPT40613: true, GPT432K: true, diff --git a/embeddings.go b/embeddings.go index c144119f8..517027f5a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -34,7 +34,9 @@ const ( BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" - AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + SmallEmbedding3 EmbeddingModel = "text-embedding-3-small" + LargeEmbedding3 EmbeddingModel = "text-embedding-3-large" ) // Embedding is a special format of data representation that can be easily utilized by machine From 06ff541559eaf66482a89202da946644b6c96510 Mon Sep 17 00:00:00 2001 From: chenhhA <463474838@qq.com> Date: Mon, 29 Jan 2024 15:09:56 +0800 Subject: [PATCH 109/206] Add new struct field dimensions for embedding API (#645) * add new struct field dimensions for embedding API * docs: remove long single-line comments * change embedding request param Dimensions type to int --- embeddings.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/embeddings.go b/embeddings.go index 517027f5a..c5633a313 100644 --- a/embeddings.go +++ b/embeddings.go @@ -157,6 +157,9 @@ type EmbeddingRequest struct { Model EmbeddingModel `json:"model"` User string `json:"user"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -181,6 +184,9 @@ type EmbeddingRequestStrings struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -189,6 +195,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } @@ -209,6 +216,9 @@ type EmbeddingRequestTokens struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have.
+ // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -217,6 +227,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } From bc8cdd33d158ea165fcecde4a64fc5f1580f0192 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Fri, 2 Feb 2024 18:30:24 +0800 Subject: [PATCH 110/206] add GPT3Dot5Turbo0125 model (#648) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 6326a72a8..ab1dbd6c5 100644 --- a/completion.go +++ b/completion.go @@ -27,6 +27,7 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" @@ -75,6 +76,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, From bb6ed545306ba56b99d297a77da0a93b0bcfb80e Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:41:39 +0800 Subject: [PATCH 111/206] Fix: Add RunStatusCancelled (#650) Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index dbb708a13..d06756572 100644 --- a/run.go +++ b/run.go @@ -40,6 +40,7 @@ const ( RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" RunStatusExpired RunStatus = "expired" + RunStatusCancelled RunStatus = "cancelled" ) type RunRequiredAction struct { From 69e3fcbc2726d208d34e9d89089b47ebebdff01b Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:04:40 +0000 Subject: [PATCH 112/206] Fix typo assitantInstructions (#655) --- assistant_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/assistant_test.go b/assistant_test.go index 30daec2b1..9e1e3f38d 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -17,7 +17,7 @@ func TestAssistant(t *testing.T) { assistantID := "asst_abc123" assistantName := "Ambrogio" assistantDescription := "Ambrogio is a friendly assistant." - assitantInstructions := `You are a personal math tutor. + assistantInstructions := `You are a personal math tutor. 
When asked a question, write and run Python code to answer the question.` assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" limit := 20 @@ -92,7 +92,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: @@ -152,7 +152,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }, }, }) @@ -167,7 +167,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "CreateAssistant error") @@ -178,7 +178,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "ModifyAssistant error") From 6c2e3162dfe3b32cbd1d026043957f8e589e987c Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Thu, 8 Feb 2024 15:40:39 +0800 Subject: [PATCH 113/206] Added support for CreateSpeech Azure models (#657) --- speech.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speech.go b/speech.go index f2442b921..b9344ac66 100644 --- a/speech.go +++ b/speech.go @@ -74,7 +74,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) err = ErrInvalidVoice return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json; charset=utf-8"), ) From a7954c854c89f45d3f5df62aab8df688b4c20b20 Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:08:30 +0800 Subject: [PATCH 114/206] Feat: Add assistant usage (#649) * Feat: Add assistant usage --------- Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index d06756572..4befe0b44 100644 --- a/run.go +++ b/run.go @@ -26,6 +26,7 @@ type Run struct { Tools []Tool `json:"tools"` FileIDS []string `json:"file_ids"` Metadata map[string]any `json:"metadata"` + Usage Usage `json:"usage,omitempty"` httpHeader } From 11ad4b69d0f0dc61ed8777ac2d54a6787c8d2fea Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:02:48 +0400 Subject: [PATCH 115/206] make linter happy (#661) --- embeddings_test.go | 2 +- files_api_test.go | 10 +++++----- image_test.go | 10 +++++----- messages.go | 4 ++-- models_test.go | 2 +- run.go | 2 +- stream_test.go | 14 +++++++------- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/embeddings_test.go b/embeddings_test.go index ed6384f3f..438978169 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -169,7 +169,7 @@ func TestAzureEmbeddingEndpoint(t *testing.T) { server.RegisterHandler( "/openai/deployments/text-embedding-ada-002/embeddings", - func(w http.ResponseWriter, r *http.Request) { + 
func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) fmt.Fprintln(w, string(resBytes)) }, diff --git a/files_api_test.go b/files_api_test.go index 6f62a3fbc..c92162a84 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -86,7 +86,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { func TestDeleteFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {}) + server.RegisterHandler("/v1/files/deadbeef", func(http.ResponseWriter, *http.Request) {}) err := client.DeleteFile(context.Background(), "deadbeef") checks.NoError(t, err, "DeleteFile error") } @@ -94,7 +94,7 @@ func TestDeleteFile(t *testing.T) { func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) @@ -105,7 +105,7 @@ func TestListFile(t *testing.T) { func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) @@ -151,7 +151,7 @@ func TestGetFileContentReturnError(t *testing.T) { }` client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) @@ -178,7 +178,7 @@ func TestGetFileContentReturnError(t *testing.T) { func TestGetFileContentReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 81fff6cba..9332dd5cd 100644 --- a/image_test.go +++ b/image_test.go @@ -60,7 +60,7 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { if name == "mask" { return mockFailedErr } @@ -69,12 +69,12 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err = client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { 
return mockFailedErr } @@ -125,12 +125,12 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/messages.go b/messages.go index ead247f5b..861463235 100644 --- a/messages.go +++ b/messages.go @@ -18,7 +18,7 @@ type Message struct { ThreadID string `json:"thread_id"` Role string `json:"role"` Content []MessageContent `json:"content"` - FileIds []string `json:"file_ids"` + FileIds []string `json:"file_ids"` //nolint:revive //backwards-compatibility AssistantID *string `json:"assistant_id,omitempty"` RunID *string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata"` @@ -54,7 +54,7 @@ type ImageFile struct { type MessageRequest struct { Role string `json:"role"` Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/models_test.go b/models_test.go index 4a4c759dc..24a28ed23 100644 --- a/models_test.go +++ b/models_test.go @@ -64,7 +64,7 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { func TestGetModelReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/models/text-davinci-003", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/run.go b/run.go index 4befe0b44..ba09366cb 100644 --- a/run.go +++ b/run.go @@ -24,7 +24,7 @@ type Run struct { Model string `json:"model"` Instructions string `json:"instructions,omitempty"` Tools []Tool `json:"tools"` - FileIDS []string `json:"file_ids"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` diff --git a/stream_test.go b/stream_test.go index 35c52ae3b..2822a3535 100644 --- a/stream_test.go +++ b/stream_test.go @@ -34,7 +34,7 @@ func TestCompletionsStreamWrongModel(t *testing.T) { func TestCreateCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -106,7 +106,7 @@ func TestCreateCompletionStream(t *testing.T) { func TestCreateCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -151,7 +151,7 @@ func TestCreateCompletionStreamError(t *testing.T) { func 
TestCreateCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -182,7 +182,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -228,7 +228,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -263,7 +263,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -305,7 +305,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() From 66bae3ee7329619b27ba8bcb185e0d333e9b3e26 Mon Sep 17 00:00:00 2001 From: grulex Date: Thu, 15 Feb 2024 16:11:58 +0000 Subject: [PATCH 116/206] Content-type fix (#659) * charset fixes * make linter happy (#661) --------- Co-authored-by: grulex Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- client.go | 4 ++-- speech.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 8bbbb875a..55c48bd47 100644 --- a/client.go +++ b/client.go @@ -107,13 +107,13 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... 
} func (c *Client) sendRequest(req *http.Request, v Response) error { - req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Accept", "application/json") // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data contentType := req.Header.Get("Content-Type") if contentType == "" { - req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Content-Type", "application/json") } res, err := c.config.HTTPClient.Do(req) diff --git a/speech.go b/speech.go index b9344ac66..be8950218 100644 --- a/speech.go +++ b/speech.go @@ -76,7 +76,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), - withContentType("application/json; charset=utf-8"), + withContentType("application/json"), ) if err != nil { return From ff61bbb32253aad84c6cc96bf9be3884aa8cde88 Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:12:22 +0000 Subject: [PATCH 117/206] Add RunRequest field AdditionalInstructions (#656) AdditionalInstructions is an optional string field used to append additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. Also, change the Model and Instructions *string fields to string. --- run.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index ba09366cb..1f3cb7eb7 100644 --- a/run.go +++ b/run.go @@ -72,11 +72,12 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 69e3bbb1eb05a5c1b27a29fc9a83d02d0d040e27 Mon Sep 17 00:00:00 2001 From: Igor Berlenko Date: Fri, 16 Feb 2024 18:22:38 +0800 Subject: [PATCH 118/206] Update client.go - allow to skip Authorization header (#658) * Update client.go - allow to skip Authorization header * Update client.go --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 55c48bd47..7fdc36caa 100644 --- a/client.go +++ b/client.go @@ -175,7 +175,7 @@ func (c *Client) setCommonHeaders(req *http.Request) { // Azure API Key authentication if c.config.APIType == APITypeAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { + } else if c.config.authToken != "" { // OpenAI or Azure AD authentication req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) } From e8b347891b21187740d594409b1c11fb0846577e Mon Sep 17 00:00:00 2001 From: CaoPengFlying Date: Mon, 19 Feb 2024 20:26:04 +0800 Subject: [PATCH 119/206] fix:fix open ai original validation. 
modify Tool's Function to pointer (#664) Co-authored-by: caopengfei1 --- chat.go | 4 ++-- examples/completion-with-tool/main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chat.go b/chat.go index 33b8755ce..efb14fd4c 100644 --- a/chat.go +++ b/chat.go @@ -225,8 +225,8 @@ const ( ) type Tool struct { - Type ToolType `json:"type"` - Function FunctionDefinition `json:"function,omitempty"` + Type ToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type ToolChoice struct { diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go index 2c7fedc5e..26126e41b 100644 --- a/examples/completion-with-tool/main.go +++ b/examples/completion-with-tool/main.go @@ -35,7 +35,7 @@ func main() { } t := openai.Tool{ Type: openai.ToolTypeFunction, - Function: f, + Function: &f, } // simulate user asking a question that requires the function From 7381d18a75a673d569c7dc7657407381e5c84dd5 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Wed, 21 Feb 2024 07:45:15 -0500 Subject: [PATCH 120/206] Fix for broken Azure Assistants url (#665) * fix:fix url for Azure assistants api * test:add unit tests for Azure Assistants api * fix:minor liniting issue --- assistant_test.go | 190 ++++++++++++++++++++++++++++++++++++++++++++++ client.go | 2 +- 2 files changed, 191 insertions(+), 1 deletion(-) diff --git a/assistant_test.go b/assistant_test.go index 9e1e3f38d..48bc6f91d 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -202,3 +202,193 @@ When asked a question, write and run Python code to answer the question.` err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) checks.NoError(t, err, "DeleteAssistantFile error") } + +func TestAzureAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assistantInstructions := `You are a personal math tutor. 
+When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + 
Description: &assistantDescription, + Instructions: &assistantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client.go b/client.go index 7fdc36caa..e7a4d5beb 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") { + if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" From c5401e9e6417ac2b5374993ccff1f40010e03f52 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:46:35 -0500 Subject: [PATCH 121/206] Fix for broken Azure Threads url (#668) --- client.go | 11 ++++++- thread_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index e7a4d5beb..7b1a313a8 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { + if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" @@ -258,3 +258,12 @@ func (c *Client) handleErrorResp(resp *http.Response) error { errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } + +func 
containsSubstr(s []string, e string) bool { + for _, v := range s { + if strings.Contains(e, v) { + return true + } + } + return false +} diff --git a/thread_test.go b/thread_test.go index 227ab6330..1ac0f3c0e 100644 --- a/thread_test.go +++ b/thread_test.go @@ -93,3 +93,86 @@ func TestThread(t *testing.T) { _, err = client.DeleteThread(ctx, threadID) checks.NoError(t, err, "DeleteThread error") } + +// TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. +func TestAzureThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From f2204439857a1085207e74c8f05abf6c8248d336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Oester?= <56402078+raphoester@users.noreply.github.com> Date: Mon, 26 Feb 2024 10:48:09 +0200 Subject: [PATCH 122/206] Added fields for moderation (#662) --- moderation.go | 36 ++++++++++++++++++++++-------------- moderation_test.go | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/moderation.go b/moderation.go index f8d20ee51..45d05248e 100644 --- a/moderation.go +++ b/moderation.go @@ -44,24 +44,32 @@ type Result struct { // ResultCategories represents Categories of Result. 
type ResultCategories struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - SelfHarm bool `json:"self-harm"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate float32 `json:"hate"` - HateThreatening float32 `json:"hate/threatening"` - SelfHarm float32 `json:"self-harm"` - Sexual float32 `json:"sexual"` - SexualMinors float32 `json:"sexual/minors"` - Violence float32 `json:"violence"` - ViolenceGraphic float32 `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 059f0d1c7..7fdeb9baf 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -80,18 +80,49 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { resCat := openai.ResultCategories{} resCatScore := openai.ResultCategoryScores{} switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: 1} + resCatScore = openai.ResultCategoryScores{Hate: true} + + case strings.Contains(moderationReq.Input, "hate more"): + resCat = openai.ResultCategories{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: true} + + case strings.Contains(moderationReq.Input, "harass"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: true} + + case strings.Contains(moderationReq.Input, "harass hard"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: 1} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "wanna suicide"): + resCat = openai.ResultCategories{SelfHarmIntent: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "drink bleach"): + resCat = openai.ResultCategories{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + case 
strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: 1} + resCatScore = openai.ResultCategoryScores{Sexual: true} + + case strings.Contains(moderationReq.Input, "child porn"): + resCat = openai.ResultCategories{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: true} + + case strings.Contains(moderationReq.Input, "kill"): + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: true} + + case strings.Contains(moderationReq.Input, "corpse"): + resCat = openai.ResultCategories{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 41037783bc7668998900248ed697b90ec36c3f09 Mon Sep 17 00:00:00 2001 From: Guillaume Dussault <146769929+guillaume-dussault@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:48:53 -0500 Subject: [PATCH 123/206] fix: when no Assistant Tools are specified, an empty list should be sent (#669) --- assistant.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index bd335833a..7a7a7652e 100644 --- a/assistant.go +++ b/assistant.go @@ -46,7 +46,7 @@ type AssistantRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } From bb6149f64fcb22381b2ef0b5c7d8287a520dc110 Mon Sep 17 00:00:00 2001 From: Martin Heck Date: Wed, 28 Feb 2024 10:25:47 +0100 Subject: [PATCH 124/206] fix: repair json decoding of moderation response (#670) --- moderation.go | 22 +++++++++++----------- moderation_test.go | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/moderation.go b/moderation.go index 45d05248e..ae285ef83 100644 --- a/moderation.go +++ b/moderation.go @@ -59,17 +59,17 @@ type ResultCategories struct { // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - Harassment bool `json:"harassment"` - HarassmentThreatening bool `json:"harassment/threatening"` - SelfHarm bool `json:"self-harm"` - SelfHarmIntent bool `json:"self-harm/intent"` - SelfHarmInstructions bool `json:"self-harm/instructions"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate float32 `json:"hate"` + HateThreatening float32 `json:"hate/threatening"` + Harassment float32 `json:"harassment"` + HarassmentThreatening float32 `json:"harassment/threatening"` + SelfHarm float32 `json:"self-harm"` + SelfHarmIntent float32 `json:"self-harm/intent"` + SelfHarmInstructions float32 `json:"self-harm/instructions"` + Sexual float32 `json:"sexual"` + SexualMinors float32 `json:"sexual/minors"` + Violence float32 `json:"violence"` + ViolenceGraphic float32 `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. 
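The float32 score fields above are what make the response decodable at all: the moderation endpoint reports each category twice, once as a boolean flag and once as a probability-like score in [0, 1], and json.Unmarshal cannot place a number such as 0.9971 into a bool. A minimal decoding sketch against the post-patch types follows; the JSON fragment is invented for illustration and is not captured API output.

package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Invented fragment in the shape of one element of the moderation
	// response's "results" array.
	payload := []byte(`{
		"flagged": true,
		"categories": {"violence": true},
		"category_scores": {"violence": 0.9971}
	}`)

	var res openai.Result
	// With the pre-patch bool score fields this call failed with an error
	// along the lines of "cannot unmarshal number ... of type bool".
	if err := json.Unmarshal(payload, &res); err != nil {
		panic(err)
	}
	fmt.Printf("flagged=%v violence=%v score=%.4f\n",
		res.Flagged, res.Categories.Violence, res.CategoryScores.Violence)
}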
diff --git a/moderation_test.go b/moderation_test.go index 7fdeb9baf..61171c384 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -82,47 +82,47 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "hate more"): resCat = openai.ResultCategories{HateThreatening: true} - resCatScore = openai.ResultCategoryScores{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: 1} case strings.Contains(moderationReq.Input, "harass"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: 1} case strings.Contains(moderationReq.Input, "harass hard"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: 1} case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "wanna suicide"): resCat = openai.ResultCategories{SelfHarmIntent: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "drink bleach"): resCat = openai.ResultCategories{SelfHarmInstructions: true} - resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: 1} case strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} case strings.Contains(moderationReq.Input, "child porn"): resCat = openai.ResultCategories{SexualMinors: true} - resCatScore = openai.ResultCategoryScores{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: 1} case strings.Contains(moderationReq.Input, "kill"): resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "corpse"): resCat = openai.ResultCategories{ViolenceGraphic: true} - resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: 1} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 38b16a3c413a3ea076cf4082ea5cd1754b72c70f Mon Sep 17 00:00:00 2001 From: Bilal Hameed <68427058+LinuxSploit@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:56:50 +0500 Subject: [PATCH 125/206] Added 'wav' and 'pcm' Audio Formats (#671) * Added 'wav' and 'pcm' Audio Formats Added "wav" and "pcm" audio formats as per OpenAI API documentation for createSpeech endpoint. Ref: https://platform.openai.com/docs/api-reference/audio/createSpeech Supported formats are mp3, opus, aac, flac, wav, and pcm. 
* Removed Extra Newline for Sanity Check * fix: run goimports to get accepted by the linter --- speech.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/speech.go b/speech.go index be8950218..92b30b55b 100644 --- a/speech.go +++ b/speech.go @@ -33,6 +33,8 @@ const ( SpeechResponseFormatOpus SpeechResponseFormat = "opus" SpeechResponseFormatAac SpeechResponseFormat = "aac" SpeechResponseFormatFlac SpeechResponseFormat = "flac" + SpeechResponseFormatWav SpeechResponseFormat = "wav" + SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) var ( From 699f397c36d05e42210f65456436a447885cc07a Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 11 Mar 2024 15:27:48 +0800 Subject: [PATCH 126/206] Update streamReader Close() method to return error (#681) --- stream_reader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index d17412591..4210a1948 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -108,6 +108,6 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { return } -func (stream *streamReader[T]) Close() { - stream.response.Body.Close() +func (stream *streamReader[T]) Close() error { + return stream.response.Body.Close() } From 0925563e86c2fdc5011310aa616ba493989cfe0a Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Fri, 15 Mar 2024 18:59:16 +0800 Subject: [PATCH 127/206] Fix broken AssistantModify implementation (#685) * add custom marshaller, documentation and isolate tests * fix linter --- assistant.go | 30 ++++++++++++- assistant_test.go | 109 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 30 deletions(-) diff --git a/assistant.go b/assistant.go index 7a7a7652e..4ca2dda62 100644 --- a/assistant.go +++ b/assistant.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -21,7 +22,7 @@ type Assistant struct { Description *string `json:"description,omitempty"` Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` @@ -41,16 +42,41 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +// AssistantRequest provides the assistant request parameters. +// When modifying the tools the API functions as the following: +// If Tools is undefined, no changes are made to the Assistant's tools. +// If Tools is empty slice it will effectively delete all of the Assistant's tools. +// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { Model string `json:"model"` Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` + Tools []AssistantTool `json:"-"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases +// If Tools is nil, the field is omitted from the JSON. +// If Tools is an empty slice, it's included in the JSON as an empty array ([]). +// If Tools is populated, it's included in the JSON with the elements.
+func (a AssistantRequest) MarshalJSON() ([]byte, error) { + type Alias AssistantRequest + assistantAlias := &struct { + Tools *[]AssistantTool `json:"tools,omitempty"` + *Alias + }{ + Alias: (*Alias)(&a), + } + + if a.Tools != nil { + assistantAlias.Tools = &a.Tools + } + + return json.Marshal(assistantAlias) +} + // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` diff --git a/assistant_test.go b/assistant_test.go index 48bc6f91d..40de0e50f 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.` }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: - var request openai.AssistantRequest + var request openai.Assistant err := json.NewDecoder(r.Body).Decode(&request) checks.NoError(t, err, "Decode error") @@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.` ctx := context.Background() - _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("create_assistant", func(t *testing.T) { + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") }) - checks.NoError(t, err, "CreateAssistant error") - _, err = client.RetrieveAssistant(ctx, assistantID) - checks.NoError(t, err, "RetrieveAssistant error") + t.Run("retrieve_assistant", func(t *testing.T) { + _, err := client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + }) - _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("delete_assistant", func(t *testing.T) { + _, err := client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") }) - checks.NoError(t, err, "ModifyAssistant error") - _, err = client.DeleteAssistant(ctx, assistantID) - checks.NoError(t, err, "DeleteAssistant error") + t.Run("list_assistant", func(t *testing.T) { + _, err := client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + }) - _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistants error") + t.Run("create_assistant_file", func(t *testing.T) { + _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + }) - _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ - FileID: assistantFileID, + t.Run("list_assistant_files", func(t *testing.T) { + _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") }) - checks.NoError(t, err, "CreateAssistantFile error") - _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistantFiles error") + t.Run("retrieve_assistant_file", func(t *testing.T) { + _, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") 
+ }) - _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "RetrieveAssistantFile error") + t.Run("delete_assistant_file", func(t *testing.T) { + err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") + }) - err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "DeleteAssistantFile error") + t.Run("modify_assistant_no_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools != nil { + t.Errorf("expected nil got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_with_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}}, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil || len(assistant.Tools) != 1 { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_empty_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: make([]openai.AssistantTool, 0), + }) + + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) } func TestAzureAssistant(t *testing.T) { From 2646bce71c0cc907e2a3d050130b712c1e5688db Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 6 Apr 2024 03:15:54 +0800 Subject: [PATCH 128/206] feat: get header from sendRequestRaw (#694) * feat: get header from sendRequestRaw * Fix ci lint --- client.go | 15 ++++++++++++--- files.go | 6 ++---- speech.go | 7 ++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 7b1a313a8..9a1c8958d 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,12 @@ func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { return newRateLimitHeaders(h.Header()) } +type RawResponse struct { + io.ReadCloser + + httpHeader +} + // NewClient creates new OpenAI API client. 
func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -134,8 +140,8 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { return decodeResponse(res.Body, v) } -func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { - resp, err := c.config.HTTPClient.Do(req) +func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) { + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function if err != nil { return } @@ -144,7 +150,10 @@ func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err erro err = c.handleErrorResp(resp) return } - return resp.Body, nil + + response.SetHeader(resp.Header) + response.ReadCloser = resp.Body + return } func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { diff --git a/files.go b/files.go index a37d45f18..b40a44f15 100644 --- a/files.go +++ b/files.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "io" "net/http" "os" ) @@ -159,13 +158,12 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err return } -func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - content, err = c.sendRequestRaw(req) - return + return c.sendRequestRaw(req) } diff --git a/speech.go b/speech.go index 92b30b55b..7e22e755c 100644 --- a/speech.go +++ b/speech.go @@ -3,7 +3,6 @@ package openai import ( "context" "errors" - "io" "net/http" ) @@ -67,7 +66,7 @@ func isValidVoice(voice SpeechVoice) bool { return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) } -func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { if !isValidSpeechModel(request.Model) { err = ErrInvalidSpeechModel return @@ -84,7 +83,5 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) return } - response, err = c.sendRequestRaw(req) - - return + return c.sendRequestRaw(req) } From 774fc9dd12ed60c10a9f9f03319ddb9cd5f8780c Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 5 Apr 2024 23:24:30 +0400 Subject: [PATCH 129/206] make linter happy (#701) --- fine_tunes.go | 1 - 1 file changed, 1 deletion(-) diff --git a/fine_tunes.go b/fine_tunes.go index 46f89f165..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,6 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return From 187f4169f8898d78716f7944d87e5d95aa9a7c41 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Tue, 9 Apr 2024 16:22:31 +0800 Subject: [PATCH 130/206] [BREAKING_CHANGES] Fix update message payload (#699) * add custom marshaller, documentation and isolate tests * fix linter * wrap payload as expected from the API and update test * modify input to accept map[string]string only --- messages.go | 4 ++-- messages_test.go | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index 861463235..6fd0adbc9 100644 --- a/messages.go +++ b/messages.go @@ -139,11 +139,11 @@ func (c *Client) RetrieveMessage( func (c *Client) ModifyMessage( ctx context.Context, threadID, messageID string, - metadata map[string]any, + metadata map[string]string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(metadata), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) if err != nil { return } diff --git a/messages_test.go b/messages_test.go index 9168d6ccf..a18be20bd 100644 --- a/messages_test.go +++ b/messages_test.go @@ -68,6 +68,10 @@ func TestMessages(t *testing.T) { metadata := map[string]any{} err := json.NewDecoder(r.Body).Decode(&metadata) checks.NoError(t, err, "unable to decode metadata in modify message call") + payload, ok := metadata["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata payload improperly wrapped %+v", metadata) + } resBytes, _ := json.Marshal( openai.Message{ @@ -86,8 +90,9 @@ func TestMessages(t *testing.T) { FileIds: nil, AssistantID: &emptyStr, RunID: &emptyStr, - Metadata: metadata, + Metadata: payload, }) + fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal( @@ -212,7 +217,7 @@ func TestMessages(t *testing.T) { } msg, err = client.ModifyMessage(ctx, threadID, messageID, - map[string]any{ + map[string]string{ "foo": "bar", }) checks.NoError(t, err, "ModifyMessage error") From e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 11 Apr 2024 16:39:10 +0800 Subject: [PATCH 131/206] feat: add GPT4Turbo and GPT4Turbo20240409 (#703) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index ab1dbd6c5..00f43ff1c 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" GPT4Turbo1106 = "gpt-4-1106-preview" GPT4TurboPreview = "gpt-4-turbo-preview" @@ -84,6 +86,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4VisionPreview: true, GPT4Turbo1106: true, GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, GPT40314: true, GPT40613: true, GPT432K: true, From ea551f422e5f38a0afc7d938eea5cff1f69494c5 Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:32:38 +0200 Subject: [PATCH 132/206] Fixing typos (#706) --- README.md | 2 +- assistant.go | 4 ++-- client_test.go | 2 +- error.go | 2 +- 4 files changed, 5 insertions(+), 5 
deletions(-) diff --git a/README.md b/README.md index 9a479c0a0..7946f4d9b 100644 --- a/README.md +++ b/README.md @@ -636,7 +636,7 @@ FunctionDefinition{ }, "unit": { Type: jsonschema.String, - Enum: []string{"celcius", "fahrenheit"}, + Enum: []string{"celsius", "fahrenheit"}, }, }, Required: []string{"location"}, diff --git a/assistant.go b/assistant.go index 4ca2dda62..9415325f8 100644 --- a/assistant.go +++ b/assistant.go @@ -181,7 +181,7 @@ func (c *Client) ListAssistants( order *string, after *string, before *string, -) (reponse AssistantsList, err error) { +) (response AssistantsList, err error) { urlValues := url.Values{} if limit != nil { urlValues.Add("limit", fmt.Sprintf("%d", *limit)) @@ -208,7 +208,7 @@ func (c *Client) ListAssistants( return } - err = c.sendRequest(req, &reponse) + err = c.sendRequest(req, &response) return } diff --git a/client_test.go b/client_test.go index bc5133edc..a08d10f21 100644 --- a/client_test.go +++ b/client_test.go @@ -406,7 +406,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } } -func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { config := DefaultConfig(test.GetTestToken()) client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/error.go b/error.go index b2d01e22e..37959a272 100644 --- a/error.go +++ b/error.go @@ -23,7 +23,7 @@ type InnerError struct { ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } -// RequestError provides informations about generic request errors. +// RequestError provides information about generic request errors. type RequestError struct { HTTPStatusCode int Err error From 2446f08f94b2750287c40bb9593377f349f5578e Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:34:23 +0200 Subject: [PATCH 133/206] Bump GitHub workflow actions to latest versions (#707) --- .github/workflows/close-inactive-issues.yml | 2 +- .github/workflows/pr.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index bfe9b5c96..32723c4e9 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -10,7 +10,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v9 with: days-before-issue-stale: 30 days-before-issue-close: 14 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0f..a41fff92f 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,19 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.21' - name: Run vet run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . 
- name: Upload coverage reports to Codecov
-        uses: codecov/codecov-action@v3
+        uses: codecov/codecov-action@v4

From a42f51967f5c2f8462f8d8dfd25f7d6a8d7a46fc Mon Sep 17 00:00:00 2001
From: Quest Henkart
Date: Wed, 17 Apr 2024 03:26:14 +0800
Subject: [PATCH 134/206] [New_Features] Adds recently added Assistant cost saving parameters (#710)

* add cost saving parameters

* add periods at the end of comments

* shorten comment

* further lower comment length

* fix type
---
 run.go | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/run.go b/run.go
index 1f3cb7eb7..7c14779c5 100644
--- a/run.go
+++ b/run.go
@@ -28,6 +28,16 @@ type Run struct {
 	Metadata map[string]any `json:"metadata"`
 	Usage    Usage          `json:"usage,omitempty"`
 
+	Temperature *float32 `json:"temperature,omitempty"`
+	// The maximum number of prompt tokens that may be used over the course of the run.
+	// If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'.
+	MaxPromptTokens int `json:"max_prompt_tokens,omitempty"`
+	// The maximum number of completion tokens that may be used over the course of the run.
+	// If the run exceeds the number of completion tokens specified, the run will end with status 'complete'.
+	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+	// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
+	TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
+
 	httpHeader
 }
 
@@ -78,8 +88,42 @@ type RunRequest struct {
 	AdditionalInstructions string         `json:"additional_instructions,omitempty"`
 	Tools                  []Tool         `json:"tools,omitempty"`
 	Metadata               map[string]any `json:"metadata,omitempty"`
+
+	// Sampling temperature between 0 and 2. Higher values like 0.8 are more random.
+	// lower values are more focused and deterministic.
+	Temperature *float32 `json:"temperature,omitempty"`
+
+	// The maximum number of prompt tokens that may be used over the course of the run.
+	// If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'.
+	MaxPromptTokens int `json:"max_prompt_tokens,omitempty"`
+
+	// The maximum number of completion tokens that may be used over the course of the run.
+	// If the run exceeds the number of completion tokens specified, the run will end with status 'complete'.
+	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+
+	// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
+	TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
 }
 
+// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
+// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy.
+type ThreadTruncationStrategy struct {
+	// default 'auto'.
+	Type TruncationStrategy `json:"type,omitempty"`
+	// this field should be set if the truncation strategy is set to LastMessages.
+	LastMessages *int `json:"last_messages,omitempty"`
+}
+
+// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant.
+type TruncationStrategy string
+
+const (
+	// TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model.
+	TruncationStrategyAuto = TruncationStrategy("auto")
+	// TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread.
+ TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From c6a63ed19aeb0e91facc5409c5a08612db550fb2 Mon Sep 17 00:00:00 2001 From: Mike Chaykowsky Date: Tue, 16 Apr 2024 12:28:06 -0700 Subject: [PATCH 135/206] Add PromptFilterResult (#702) --- chat_stream.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 57cfa789f..6ff7078e2 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -19,13 +19,19 @@ type ChatCompletionStreamChoice struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } // ChatCompletionStream From 8d15a377ec4fa3aaf2e706cd1e2ad986dd6b8242 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Wed, 24 Apr 2024 12:59:50 +0100 Subject: [PATCH 136/206] Remove hardcoded assistants version (#719) --- assistant.go | 19 +++++++++---------- client.go | 4 ++-- config.go | 14 +++++++++----- messages.go | 17 +++++++++++------ run.go | 27 +++++++++------------------ thread.go | 8 ++++---- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/assistant.go b/assistant.go index 9415325f8..661681e83 100644 --- a/assistant.go +++ b/assistant.go @@ -11,7 +11,6 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { @@ -116,7 +115,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. 
func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +131,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +148,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +164,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,7 +202,7 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -221,7 +220,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +237,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +254,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +293,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/client.go b/client.go index 9a1c8958d..77d693226 100644 --- a/client.go +++ b/client.go @@ -89,9 +89,9 @@ func withContentType(contentType string) requestOption { } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } } diff --git a/config.go b/config.go index c58b71ec6..599fa89c0 100644 --- 
a/config.go +++ b/config.go @@ -23,6 +23,8 @@ const ( const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,7 +32,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient *http.Client @@ -39,10 +42,11 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, diff --git a/messages.go b/messages.go index 6fd0adbc9..6af118445 100644 --- a/messages.go +++ b/messages.go @@ -76,7 +76,8 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -111,7 +112,8 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +128,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +146,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +161,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +177,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files 
MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/run.go b/run.go index 7c14779c5..094b0a4db 100644 --- a/run.go +++ b/run.go @@ -226,8 +226,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -247,8 +246,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -270,8 +268,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -310,8 +307,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -332,8 +328,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -352,8 +347,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -372,8 +366,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -394,8 +387,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -435,8 +427,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/thread.go b/thread.go index 291f3dcab..900e3f2ea 100644 --- a/thread.go +++ b/thread.go @@ -51,7 +51,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. 
func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +64,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +81,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +97,7 @@ func (c *Client) DeleteThread( ) (response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } From 2d58f8f4b87be26dc0b7ba2b1f0c9496ecf1dfa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=80=E6=97=A5=E3=80=82?= Date: Wed, 24 Apr 2024 20:02:03 +0800 Subject: [PATCH 137/206] chore: add SystemFingerprint for chat completion stream response (#716) * chore: add SystemFingerprint for stream response * chore: add test * lint: format for test --- chat_stream.go | 1 + chat_stream_test.go | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 6ff7078e2..159f9f472 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -30,6 +30,7 @@ type ChatCompletionStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb48..bd1c737dd 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) 
//nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ From c84ab5f6ae8da3a78826ed2c8dc4c5cf93e30589 Mon Sep 17 00:00:00 2001 From: wurui <1009479218@qq.com> Date: Wed, 24 Apr 2024 20:08:58 +0800 Subject: [PATCH 138/206] feat: support cloudflare AI Gateway flavored azure openai (#715) * feat: support cloudflare AI Gateway flavored azure openai Signed-off-by: STRRL * test: add test for cloudflare azure fullURL --------- Signed-off-by: STRRL Co-authored-by: STRRL --- api_internal_test.go | 36 ++++++++++++++++++++++++++++++++++++ client.go | 10 ++++++++-- config.go | 7 ++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f8993..a590ec9ab 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/client.go b/client.go index 77d693226..c57ba17c7 100644 --- a/client.go +++ b/client.go @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // 
https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index 599fa89c0..bb437c97f 100644 --- a/config.go +++ b/config.go @@ -16,9 +16,10 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key" From c9953a7b051bd661254fb071029553e61c78f8bd Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Sat, 27 Apr 2024 12:55:49 +0330 Subject: [PATCH 139/206] Fixup minor copy-pasta comment typo (#728) imagess -> images --- image_api_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/image_api_test.go b/image_api_test.go index 2eb46f2b4..48416b1e2 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -36,7 +36,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -146,7 +146,7 @@ func TestImageEditWithoutMask(t *testing.T) { func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -202,7 +202,7 @@ func TestImageVariation(t *testing.T) { func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } From 3334a9c78a9d594934e33af184e4e6313c4a942b Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Tue, 7 May 2024 16:10:07 +0330 Subject: [PATCH 140/206] Add support for word-level audio transcription timestamp granularity (#733) * Add support for audio transcription timestamp_granularities word * Fixup multiple timestamp granularities --- audio.go | 31 ++++++++++++++++++++++++++----- audio_api_test.go | 4 ++++ audio_test.go | 6 +++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/audio.go b/audio.go index 4cbe4fe64..dbc26d154 100644 --- a/audio.go +++ b/audio.go @@ -27,8 +27,14 @@ const ( AudioResponseFormatVTT AudioResponseFormat = "vtt" ) +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment 
TranscriptionTimestampGranularity = "segment" +) + // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string @@ -38,10 +44,11 @@ type AudioRequest struct { // Reader is an optional io.Reader when you do not want to use an existing file. Reader io.Reader - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. @@ -62,6 +69,11 @@ type AudioResponse struct { NoSpeechProb float64 `json:"no_speech_prob"` Transient bool `json:"transient"` } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` httpHeader @@ -179,6 +191,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer return b.Close() } diff --git a/audio_api_test.go b/audio_api_test.go index a0efc7921..c24598443 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -105,6 +105,10 @@ func TestAudioWithOptionalArgs(t *testing.T) { Temperature: 0.5, Language: "zh", Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index 5346244c8..235931f36 100644 --- a/audio_test.go +++ b/audio_test.go @@ -24,6 +24,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Temperature: 0.5, Language: "en", Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, + }, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -47,7 +51,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} for _, failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) From 6af32202d1ce469674050600efa07c90ec286d03 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 7 May 2024 20:42:24 +0800 Subject: [PATCH 141/206] feat: support stream_options (#736) * feat: support stream_options * fix lint * fix lint --- chat.go | 10 ++++ chat_stream.go | 4 ++ chat_stream_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/chat.go b/chat.go index efb14fd4c..a1eb11720 100644 --- a/chat.go +++ b/chat.go @@ -216,6 +216,16 @@ type ChatCompletionRequest struct { Tools []Tool 
`json:"tools,omitempty"` // This can be either a string or an ToolChoice object. ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` } type ToolType string diff --git a/chat_stream.go b/chat_stream.go index 159f9f472..ffd512ff6 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -33,6 +33,10 @@ type ChatCompletionStreamResponse struct { SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream diff --git a/chat_stream_test.go b/chat_stream_test.go index bd1c737dd..63e45ee23 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -388,6 +388,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) 
+ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + // Helper funcs. 
func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
@@ -401,6 +515,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 			return false
 		}
 	}
+	if r1.Usage != nil || r2.Usage != nil {
+		if r1.Usage == nil || r2.Usage == nil {
+			return false
+		}
+		if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
+			r1.Usage.TotalTokens != r2.Usage.TotalTokens {
+			return false
+		}
+	}
 	return true
 }

From 3b25e09da90715681fe4049955d7c7ce645e218c Mon Sep 17 00:00:00 2001
From: Kevin Mesiab
Date: Mon, 13 May 2024 11:48:14 -0700
Subject: [PATCH 142/206] enhancement: Add new GPT4-o and alias to completion enums (#744)

---
 completion.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/completion.go b/completion.go
index 00f43ff1c..3b4f8952a 100644
--- a/completion.go
+++ b/completion.go
@@ -22,6 +22,8 @@ const (
 	GPT432K           = "gpt-4-32k"
 	GPT40613          = "gpt-4-0613"
 	GPT40314          = "gpt-4-0314"
+	GPT4o             = "gpt-4o"
+	GPT4o20240513     = "gpt-4o-2024-05-13"
 	GPT4Turbo         = "gpt-4-turbo"
 	GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09"
 	GPT4Turbo0125     = "gpt-4-0125-preview"
 	GPT4Turbo1106     = "gpt-4-1106-preview"
 	GPT4TurboPreview  = "gpt-4-turbo-preview"

From 9f19d1c93bf986f2a8925be62f35aa5c413a706a Mon Sep 17 00:00:00 2001
From: nullswan
Date: Mon, 13 May 2024 21:07:07 +0200
Subject: [PATCH 143/206] Add gpt4o (#742)

* Add gpt4o

* disabled model for endpoint seen in https://github.com/sashabaranov/go-openai/commit/e0d0801ac73cdc87d1b56ced0a0eb71e574546c3

* Update completion.go

---------

Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com>
---
 completion.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/completion.go b/completion.go
index 3b4f8952a..ced8e0606 100644
--- a/completion.go
+++ b/completion.go
@@ -84,6 +84,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 	GPT3Dot5Turbo16K:     true,
 	GPT3Dot5Turbo16K0613: true,
 	GPT4:                 true,
+	GPT4o:                true,
+	GPT4o20240513:        true,
 	GPT4TurboPreview:     true,
 	GPT4VisionPreview:    true,
 	GPT4Turbo1106:        true,
 	GPT4Turbo0125:        true,
 	GPT40314:             true,
 	GPT40613:             true,
 	GPT432K:              true,

From 4f4a85687be31607536997e924b27693f5e5211a Mon Sep 17 00:00:00 2001
From: Kshirodra Meher
Date: Tue, 14 May 2024 00:38:14 +0530
Subject: [PATCH 144/206] Added DALL.E 3 to readme.md (#741)

* Added DALL.E 3 to readme.md

Added DALL.E 3 to readme.md as it's supported now, as per issue
https://github.com/sashabaranov/go-openai/issues/494

* Update README.md

---------

Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com>
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 7946f4d9b..799dc602b 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op
 * ChatGPT
 * GPT-3, GPT-4
-* DALL·E 2
+* DALL·E 2, DALL·E 3
 * Whisper
 
 ## Installation

From 211cb49fc22766f4174fef15301c4d39aef609d3 Mon Sep 17 00:00:00 2001
From: ando-masaki
Date: Fri, 24 May 2024 16:18:47 +0900
Subject: [PATCH 145/206] Update client.go to get response header whether there is an error or not. (#751)

Update client.go to get the response header whether there is an error or not,
because the 429 Too Many Requests error response has a "Retry-After" header.
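For illustration, a minimal sketch of what this change enables for callers (the token and prompt are placeholders, and the retry handling itself is an assumption, not part of this patch):

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder token
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello"}},
	})
	if err != nil {
		// Headers are now populated even on failure paths, so the
		// Retry-After value of a 429 response is visible to the caller.
		fmt.Println("Retry-After:", resp.Header().Get("Retry-After"))
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}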
--- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c57ba17c7..7bc28e984 100644 --- a/client.go +++ b/client.go @@ -129,14 +129,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } From 30cf7b879cff5eb56f06fda19c51c9e92fce8b13 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:22 -0700 Subject: [PATCH 146/206] feat: add params to RunRequest (#754) --- run.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/run.go b/run.go index 094b0a4db..6bd3933b1 100644 --- a/run.go +++ b/run.go @@ -92,6 +92,7 @@ type RunRequest struct { // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. @@ -103,6 +104,11 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. @@ -124,6 +130,13 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From 8618492b98bb91edbb43f8080b3a68275e183663 Mon Sep 17 00:00:00 2001 From: shosato0306 <38198918+shosato0306@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:03:57 +0900 Subject: [PATCH 147/206] feat: add incomplete run status (#763) --- run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 6bd3933b1..5598f1dfb 100644 --- a/run.go +++ b/run.go @@ -30,10 +30,10 @@ type Run struct { Temperature *float32 `json:"temperature,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. 
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
 
 	// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
 	TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
@@ -50,6 +50,7 @@ const (
 	RunStatusCancelling RunStatus = "cancelling"
 	RunStatusFailed     RunStatus = "failed"
 	RunStatusCompleted  RunStatus = "completed"
+	RunStatusIncomplete RunStatus = "incomplete"
 	RunStatusExpired    RunStatus = "expired"
 	RunStatusCancelled  RunStatus = "cancelled"
 )
@@ -95,11 +96,11 @@ type RunRequest struct {
 	TopP *float32 `json:"top_p,omitempty"`
 
 	// The maximum number of prompt tokens that may be used over the course of the run.
-	// If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'.
+	// If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'.
 	MaxPromptTokens int `json:"max_prompt_tokens,omitempty"`
 
 	// The maximum number of completion tokens that may be used over the course of the run.
-	// If the run exceeds the number of completion tokens specified, the run will end with status 'complete'.
+	// If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'.
 	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
 
 	// ThreadTruncationStrategy defines the truncation strategy to use for the thread.

From fd41f7a5f49e6723d97642c186e5e090abaebfe2 Mon Sep 17 00:00:00 2001
From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com>
Date: Thu, 13 Jun 2024 06:23:07 -0700
Subject: [PATCH 148/206] Fix integration test (#762)

* added TestCompletionStream test

moved completion stream testing to separate function
added NoErrorF
fixes nil pointer reference on stream object

* update integration test models
---
 api_integration_test.go        | 64 ++++++++++++++++++++--------------
 completion.go                  | 31 ++++++++--------
 embeddings.go                  |  2 +-
 internal/test/checks/checks.go |  7 ++++
 4 files changed, 62 insertions(+), 42 deletions(-)

diff --git a/api_integration_test.go b/api_integration_test.go
index 736040c50..f34685188 100644
--- a/api_integration_test.go
+++ b/api_integration_test.go
@@ -26,7 +26,7 @@ func TestAPI(t *testing.T) {
 	_, err = c.ListEngines(ctx)
 	checks.NoError(t, err, "ListEngines error")
 
-	_, err = c.GetEngine(ctx, "davinci")
+	_, err = c.GetEngine(ctx, openai.GPT3Davinci002)
 	checks.NoError(t, err, "GetEngine error")
 
 	fileRes, err := c.ListFiles(ctx)
@@ -42,7 +42,7 @@ func TestAPI(t *testing.T) {
 			"The food was delicious and the waiter",
 			"Other examples of embedding request",
 		},
-		Model: openai.AdaSearchQuery,
+		Model: openai.AdaEmbeddingV2,
 	}
 	_, err = c.CreateEmbeddings(ctx, embeddingReq)
 	checks.NoError(t, err, "Embedding error")
@@ -77,31 +77,6 @@ func TestAPI(t *testing.T) {
 	)
 	checks.NoError(t, err, "CreateChatCompletion (with name) returned error")
 
-	stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{
-		Prompt:    "Ex falso quodlibet",
-		Model:     openai.GPT3Ada,
-		MaxTokens: 5,
-		Stream:    true,
-	})
-	checks.NoError(t, err, "CreateCompletionStream returned error")
-	defer stream.Close()
-
-	counter := 0
-	for {
-		_, err = stream.Recv()
-		if err != nil {
-			if errors.Is(err, io.EOF) {
-				break
-			}
-			t.Errorf("Stream error: %v", err)
-		} else {
-			counter++
-		}
-	}
-	if counter == 0 {
-		t.Error("Stream did not return any responses")
-	}
-
 	_, err = c.CreateChatCompletion(
 		context.Background(),
 		openai.ChatCompletionRequest{
@@ -134,6 +109,41 @@ func TestAPI(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion (with functions) returned error")
 }
 
+func TestCompletionStream(t *testing.T) {
+	apiToken := os.Getenv("OPENAI_TOKEN")
+	if apiToken == "" {
+		t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
+	}
+
+	c := openai.NewClient(apiToken)
+	ctx := context.Background()
+
+	stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{
+		Prompt:    "Ex falso quodlibet",
+		Model:     openai.GPT3Babbage002,
+		MaxTokens: 5,
+		Stream:    true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	counter := 0
+	for {
+		_, err = stream.Recv()
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			t.Errorf("Stream error: %v", err)
+		} else {
+			counter++
+		}
+	}
+	if counter == 0 {
+		t.Error("Stream did not return any responses")
+	}
+}
+
 func TestAPIError(t *testing.T) {
 	apiToken := os.Getenv("OPENAI_TOKEN")
 	if apiToken == "" {

diff --git a/completion.go b/completion.go
index ced8e0606..024f09b14 100644
--- a/completion.go
+++ b/completion.go
@@ -39,30 +39,33 @@ const (
 	GPT3Dot5Turbo16K0613  = "gpt-3.5-turbo-16k-0613"
 	GPT3Dot5Turbo         = "gpt-3.5-turbo"
 	GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextDavinci003 = "text-davinci-003"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextDavinci002 = "text-davinci-002"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextCurie001 = "text-curie-001"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextBabbage001 = "text-babbage-001"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextAda001 = "text-ada-001"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3TextDavinci001 = "text-davinci-001"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3DavinciInstructBeta = "davinci-instruct-beta"
-	GPT3Davinci             = "davinci"
-	GPT3Davinci002          = "davinci-002"
-	// Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead.
+	// Deprecated: Model is shutdown. Use davinci-002 instead.
+	GPT3Davinci    = "davinci"
+	GPT3Davinci002 = "davinci-002"
+	// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
 	GPT3CurieInstructBeta = "curie-instruct-beta"
 	GPT3Curie             = "curie"
 	GPT3Curie002          = "curie-002"
-	GPT3Ada               = "ada"
-	GPT3Ada002            = "ada-002"
-	GPT3Babbage           = "babbage"
-	GPT3Babbage002        = "babbage-002"
+	// Deprecated: Model is shutdown. Use babbage-002 instead.
+	GPT3Ada    = "ada"
+	GPT3Ada002 = "ada-002"
+	// Deprecated: Model is shutdown. Use babbage-002 instead.
+	GPT3Babbage    = "babbage"
+	GPT3Babbage002 = "babbage-002"
 )
 
 // Codex Defines the models provided by OpenAI.
diff --git a/embeddings.go b/embeddings.go index c5633a313..b513ba6a7 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { From 7e96c712cbdad50b9cf67324b1ca5ef6541b6235 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:15:27 +0400 Subject: [PATCH 149/206] run integration tests (#769) --- .github/workflows/integration-tests.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..19f158e40 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,19 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + run: go test -v -tags=integration ./api_integration_test.go From c69c3bb1d259375d5de801f890aca40c0b2a8867 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:21:25 +0400 Subject: [PATCH 150/206] integration tests: pass openai secret (#770) * pass openai secret * only run in master branch --- .github/workflows/integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 19f158e40..7260b00b4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,4 +16,6 @@ jobs: with: go-version: '1.21' - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} run: go test -v -tags=integration ./api_integration_test.go From 99cc170b5414bd21fc1c55bccba1d6c1bad04516 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 13 Jun 2024 23:24:37 +0800 Subject: [PATCH 151/206] feat: support batches api (#746) * feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const --- batch.go | 275 
++++++++++++++++++++++++++++++++++++ batch_test.go | 368 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 11 ++ files.go | 1 + 4 files changed, 655 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..4aba966bc --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. 
+func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. 
+func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..4b2261e0e --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + 
client, server, teardown := setupOpenAITestServer()
+ defer teardown()
+ server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint)
+ _, err := client.CancelBatch(context.Background(), "file-id-1")
+ checks.NoError(t, err, "CancelBatch error")
+}
+
+func TestListBatch(t *testing.T) {
+ client, server, teardown := setupOpenAITestServer()
+ defer teardown()
+ server.RegisterHandler("/v1/batches", handleBatchEndpoint)
+ after := "batch_abc123"
+ limit := 10
+ _, err := client.ListBatch(context.Background(), &after, &limit)
+ checks.NoError(t, err, "ListBatch error")
+}
+
+func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) {
+ type args struct {
+ customerID string
+ body openai.ChatCompletionRequest
+ }
+ tests := []struct {
+ name string
+ args []args
+ want []byte
+ }{
+ {"", []args{
+ {
+ customerID: "req-1",
+ body: openai.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: openai.GPT3Dot5Turbo,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ },
+ },
+ {
+ customerID: "req-2",
+ body: openai.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: openai.GPT3Dot5Turbo,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ },
+ },
+ }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := &openai.UploadBatchFileRequest{}
+ for _, arg := range tt.args {
+ r.AddChatCompletion(arg.customerID, arg.body)
+ }
+ got := r.MarshalJSONL()
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Marshal() got = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestUploadBatchFileRequest_AddCompletion(t *testing.T) {
+ type args struct {
+ customerID string
+ body openai.CompletionRequest
+ }
+ tests := []struct {
+ name string
+ args []args
+ want []byte
+ }{
+ {"", []args{
+ {
+ customerID: "req-1",
+ body: openai.CompletionRequest{
+ Model: openai.GPT3Dot5Turbo,
+ User: "Hello",
+ },
+ },
+ {
+ customerID: "req-2",
+ body: openai.CompletionRequest{
+ Model: openai.GPT3Dot5Turbo,
+ User: "Hello",
+ },
+ },
+ }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := &openai.UploadBatchFileRequest{}
+ for _, arg := range tt.args {
+ r.AddCompletion(arg.customerID, arg.body)
+ }
+ got := r.MarshalJSONL()
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Marshal() got = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) {
+ type args struct {
+ customerID string
+ body openai.EmbeddingRequest
+ }
+ tests := []struct {
+ name string
+ args []args
+ want []byte
+ }{
+ {"", []args{
+ {
+ customerID: "req-1",
+ body: openai.EmbeddingRequest{
+ Model: openai.GPT3Dot5Turbo,
+ Input: []string{"Hello", "World"},
+ },
+ },
+ {
+ customerID: "req-2",
+ body: 
openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": 
"file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f21..e49da9b3d 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/files.go b/files.go index b40a44f15..26ad6bd70 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. 
From 68acf22a43903c1b460006e7c4b883ce73e35857 Mon Sep 17 00:00:00 2001
From: Pawel Kosiec
Date: Thu, 13 Jun 2024 17:26:37 +0200
Subject: [PATCH 152/206] Support Tool Resources properties for Threads (#760)

* Support Tool Resources properties for Threads

* Add Chunking Strategy for Threads vector stores

---
 thread.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 60 insertions(+), 7 deletions(-)

diff --git a/thread.go b/thread.go
index 900e3f2ea..6f7521454 100644
--- a/thread.go
+++ b/thread.go
@@ -10,21 +10,74 @@ const (
 )

 type Thread struct {
- ID string `json:"id"`
- Object string `json:"object"`
- CreatedAt int64 `json:"created_at"`
- Metadata map[string]any `json:"metadata"`
+ ID string `json:"id"`
+ Object string `json:"object"`
+ CreatedAt int64 `json:"created_at"`
+ Metadata map[string]any `json:"metadata"`
+ ToolResources ToolResources `json:"tool_resources,omitempty"`

 httpHeader
 }

 type ThreadRequest struct {
- Messages []ThreadMessage `json:"messages,omitempty"`
- Metadata map[string]any `json:"metadata,omitempty"`
+ Messages []ThreadMessage `json:"messages,omitempty"`
+ Metadata map[string]any `json:"metadata,omitempty"`
+ ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"`
 }

+type ToolResources struct {
+ CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"`
+ FileSearch *FileSearchToolResources `json:"file_search,omitempty"`
+}
+
+type CodeInterpreterToolResources struct {
+ FileIDs []string `json:"file_ids,omitempty"`
+}
+
+type FileSearchToolResources struct {
+ VectorStoreIDs []string `json:"vector_store_ids,omitempty"`
+}
+
+type ToolResourcesRequest struct {
+ CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"`
+ FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"`
+}
+
+type CodeInterpreterToolResourcesRequest struct {
+ FileIDs []string `json:"file_ids,omitempty"`
+}
+
+type FileSearchToolResourcesRequest struct {
+ VectorStoreIDs []string `json:"vector_store_ids,omitempty"`
+ VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"`
+}
+
+type VectorStoreToolResources struct {
+ FileIDs []string `json:"file_ids,omitempty"`
+ ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"`
+ Metadata map[string]any `json:"metadata,omitempty"`
+}
+
+type ChunkingStrategy struct {
+ Type ChunkingStrategyType `json:"type"`
+ Static *StaticChunkingStrategy `json:"static,omitempty"`
+}
+
+type StaticChunkingStrategy struct {
+ MaxChunkSizeTokens int `json:"max_chunk_size_tokens"`
+ ChunkOverlapTokens int `json:"chunk_overlap_tokens"`
+}
+
+type ChunkingStrategyType string
+
+const (
+ ChunkingStrategyTypeAuto ChunkingStrategyType = "auto"
+ ChunkingStrategyTypeStatic ChunkingStrategyType = "static"
+)
+
 type ModifyThreadRequest struct {
- Metadata map[string]any `json:"metadata"`
+ Metadata map[string]any `json:"metadata"`
+ ToolResources *ToolResources `json:"tool_resources,omitempty"`
 }

 type ThreadMessageRole string

From 0a421308993425afed7796da8f8e0e1abafd4582 Mon Sep 17 00:00:00 2001
From: Peng Guan-Cheng
Date: Wed, 19 Jun 2024 16:37:21 +0800
Subject: [PATCH 153/206] feat: provide vector store (#772)

* implement vector store feature

* fix after integration testing

* fix golint error

* improve test to increase code coverage

* fix golint and code coverage problem

* add tool_resource in assistant response

* chore: code style

* feat: use pagination param

* feat: use pagination param

* test: use pagination param

* test: rm 
unused code --------- Co-authored-by: Denny Depok <61371551+kodernubie@users.noreply.github.com> Co-authored-by: eric.p --- assistant.go | 50 ++++--- config.go | 2 +- vector_store.go | 345 ++++++++++++++++++++++++++++++++++++++++++ vector_store_test.go | 349 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 18 deletions(-) create mode 100644 vector_store.go create mode 100644 vector_store_test.go diff --git a/assistant.go b/assistant.go index 661681e83..cc13a3020 100644 --- a/assistant.go +++ b/assistant.go @@ -14,16 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +35,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +43,33 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. 
type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/config.go b/config.go index bb437c97f..1347567d7 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,7 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store // ClientConfig is a configuration of a client. type ClientConfig struct { diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..5c364362a --- /dev/null +++ b/vector_store.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. 
+type VectorStoresList struct {
+ VectorStores []VectorStore `json:"data"`
+ LastID *string `json:"last_id"`
+ FirstID *string `json:"first_id"`
+ HasMore bool `json:"has_more"`
+ httpHeader
+}
+
+type VectorStoreDeleteResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Deleted bool `json:"deleted"`
+
+ httpHeader
+}
+
+type VectorStoreFile struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ CreatedAt int64 `json:"created_at"`
+ VectorStoreID string `json:"vector_store_id"`
+ UsageBytes int `json:"usage_bytes"`
+ Status string `json:"status"`
+
+ httpHeader
+}
+
+type VectorStoreFileRequest struct {
+ FileID string `json:"file_id"`
+}
+
+type VectorStoreFilesList struct {
+ VectorStoreFiles []VectorStoreFile `json:"data"`
+
+ httpHeader
+}
+
+type VectorStoreFileBatch struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ CreatedAt int64 `json:"created_at"`
+ VectorStoreID string `json:"vector_store_id"`
+ Status string `json:"status"`
+ FileCounts VectorStoreFileCount `json:"file_counts"`
+
+ httpHeader
+}
+
+type VectorStoreFileBatchRequest struct {
+ FileIDs []string `json:"file_ids"`
+}
+
+// CreateVectorStore creates a new vector store.
+func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) {
+ req, _ := c.newRequest(
+ ctx,
+ http.MethodPost,
+ c.fullURL(vectorStoresSuffix),
+ withBody(request),
+ withBetaAssistantVersion(c.config.AssistantVersion),
+ )
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// RetrieveVectorStore retrieves a vector store.
+func (c *Client) RetrieveVectorStore(
+ ctx context.Context,
+ vectorStoreID string,
+) (response VectorStore, err error) {
+ urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+ req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// ModifyVectorStore modifies a vector store.
+func (c *Client) ModifyVectorStore(
+ ctx context.Context,
+ vectorStoreID string,
+ request VectorStoreRequest,
+) (response VectorStore, err error) {
+ urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+ req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// DeleteVectorStore deletes a vector store.
+func (c *Client) DeleteVectorStore(
+ ctx context.Context,
+ vectorStoreID string,
+) (response VectorStoreDeleteResponse, err error) {
+ urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+ req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// ListVectorStores lists the currently available vector stores.
+func (c *Client) ListVectorStores(
+ ctx context.Context,
+ pagination Pagination,
+) (response VectorStoresList, err error) {
+ urlValues := url.Values{}
+
+ if pagination.After != nil {
+ urlValues.Add("after", *pagination.After)
+ }
+ if pagination.Order != nil {
+ urlValues.Add("order", *pagination.Order)
+ }
+ if pagination.Limit != nil {
+ urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit))
+ }
+ if pagination.Before != nil {
+ urlValues.Add("before", *pagination.Before)
+ }
+
+ encodedValues := ""
+ if len(urlValues) > 0 {
+ encodedValues = "?" 
+ urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. 
+func (c *Client) RetrieveVectorStoreFileBatch(
+ ctx context.Context,
+ vectorStoreID string,
+ batchID string,
+) (response VectorStoreFileBatch, err error) {
+ urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID)
+ req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// CancelVectorStoreFileBatch cancels a vector store file batch.
+func (c *Client) CancelVectorStoreFileBatch(
+ ctx context.Context,
+ vectorStoreID string,
+ batchID string,
+) (response VectorStoreFileBatch, err error) {
+ urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix,
+ vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel")
+ req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// ListVectorStoreFilesInBatch lists the currently available files for a vector store file batch.
+func (c *Client) ListVectorStoreFilesInBatch(
+ ctx context.Context,
+ vectorStoreID string,
+ batchID string,
+ pagination Pagination,
+) (response VectorStoreFilesList, err error) {
+ urlValues := url.Values{}
+ if pagination.After != nil {
+ urlValues.Add("after", *pagination.After)
+ }
+ if pagination.Limit != nil {
+ urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit))
+ }
+ if pagination.Before != nil {
+ urlValues.Add("before", *pagination.Before)
+ }
+ if pagination.Order != nil {
+ urlValues.Add("order", *pagination.Order)
+ }
+
+ encodedValues := ""
+ if len(urlValues) > 0 {
+ encodedValues = "?" + urlValues.Encode()
+ }
+
+ urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix,
+ vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues)
+ req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+ withBetaAssistantVersion(c.config.AssistantVersion))
+
+ err = c.sendRequest(req, &response)
+ return
+}
diff --git a/vector_store_test.go b/vector_store_test.go
new file mode 100644
index 000000000..58b9a857e
--- /dev/null
+++ b/vector_store_test.go
@@ -0,0 +1,349 @@
+package openai_test
+
+import (
+ "context"
+
+ openai "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test/checks"
+
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+// TestVectorStore tests the vector store endpoint of the API using the mocked server. 
+func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: 
vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := 
client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +} From e31185974c45949cc58c24a6cbf5ca969fb0f622 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:06:52 +0100 Subject: [PATCH 154/206] remove errors.Join (#778) --- batch.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/batch.go b/batch.go index 4aba966bc..a43d401ab 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -109,8 +108,6 @@ type BatchResponse struct { Batch } -var ErrUploadBatchFileFailed = errors.New("upload batch file failed") - // CreateBatch — API call to Create batch. 
func (c *Client) CreateBatch( ctx context.Context, @@ -202,7 +199,6 @@ func (c *Client) CreateBatchWithUploadFile( Lines: request.Lines, }) if err != nil { - err = errors.Join(ErrUploadBatchFileFailed, err) return } return c.CreateBatch(ctx, CreateBatchRequest{ From 03851d20327b7df5358ff9fb0ac96f476be1875a Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sun, 30 Jun 2024 17:20:10 +0200 Subject: [PATCH 155/206] allow custom voice and speech models (#691) --- speech.go | 31 ------------------------------- speech_test.go | 17 ----------------- 2 files changed, 48 deletions(-) diff --git a/speech.go b/speech.go index 7e22e755c..19b21bdf1 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,32 +43,7 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json"), diff --git a/speech_test.go b/speech_test.go index d9ba58b13..f1e405c39 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } From 727944c47886924800128d1c33df706b4159eb23 Mon Sep 17 00:00:00 2001 From: Luca Giannini <68999840+LGXerxes@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:31:11 +0200 Subject: [PATCH 156/206] feat: ParallelToolCalls to ChatCompletionRequest with helper functions (#787) * added ParallelToolCalls to ChatCompletionRequest with helper functions * added tests for coverage * changed ParallelToolCalls to any --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index a1eb11720..eb494f41f 100644 --- a/chat.go +++ b/chat.go @@ -218,6 +218,8 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // Options for streaming 
response. Only set this when you set stream: true. StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } type StreamOptions struct { From 3e47e6fef4ac861dd5e07f73a8fb240374e8cad3 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:06:27 +0800 Subject: [PATCH 157/206] fix: #790 (#798) --- files.go | 1 + 1 file changed, 1 insertion(+) diff --git a/files.go b/files.go index 26ad6bd70..edc9f2a20 100644 --- a/files.go +++ b/files.go @@ -102,6 +102,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File if err != nil { return } + defer fileData.Close() err = builder.CreateFormFile("file", fileData) if err != nil { From 27c1c56f0b50a84740425f7534c46825e227b437 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Fri, 19 Jul 2024 07:06:51 -0700 Subject: [PATCH 158/206] feat: Add GPT-4o Mini model support (#796) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index 024f09b14..4ff1123c4 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,8 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" @@ -89,6 +91,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4oMini: true, + GPT4oMini20240718: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, From 92f483055f666847f7954e148b7f46771c5581b8 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:10:17 +0800 Subject: [PATCH 159/206] fix: #794 (#797) --- client.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 7bc28e984..d5d555c3d 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" utils "github.com/sashabaranov/go-openai/internal" @@ -228,10 +229,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") + parseURL, _ := url.Parse(baseURL) + query := parseURL.Query() + query.Add("api-version", c.config.APIVersion) // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -240,9 +244,9 @@ func (c *Client) fullURL(suffix string, args ...any) string { azureDeploymentName = c.config.GetAzureDeploymentByModel(model) } } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + return fmt.Sprintf("%s/%s/%s/%s%s?%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, c.config.APIVersion, + azureDeploymentName, suffix, query.Encode(), ) } From ae903d7465c4b48654fac6103472767ee4d95e41 
Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:12:20 +0300 Subject: [PATCH 160/206] fix: Updated ThreadMessage struct with latest fields based on OpenAI docs (#792) * fix: Updated ThreadMessage struct with latest fields based on OpenAI docs * fix: Reverted FileIDs for backward compatibility of v1 --- thread.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/thread.go b/thread.go index 6f7521454..bc08e2bcb 100644 --- a/thread.go +++ b/thread.go @@ -83,14 +83,25 @@ type ModifyThreadRequest struct { type ThreadMessageRole string const ( - ThreadMessageRoleUser ThreadMessageRole = "user" + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" ) type ThreadMessage struct { - Role ThreadMessageRole `json:"role"` - Content string `json:"content"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` } type ThreadDeleteResponse struct { From a7e9f0e3880d1487fe8e06a43820f42046b5b622 Mon Sep 17 00:00:00 2001 From: Janusch Jacoby Date: Fri, 19 Jul 2024 16:13:02 +0200 Subject: [PATCH 161/206] add hyperparams (#793) --- fine_tuning_job.go | 4 +++- fine_tuning_job_test.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..5a9f54a92 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -26,7 +26,9 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs any `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize any `json:"batch_size,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index d2fbcd4c7..5f63ef24c 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -33,7 +33,9 @@ func TestFineTuningJob(t *testing.T) { ValidationFile: "", TrainingFile: "file-abc123", Hyperparameters: openai.Hyperparameters{ - Epochs: "auto", + Epochs: "auto", + LearningRateMultiplier: "auto", + BatchSize: "auto", }, TrainedTokens: 5768, }) From 966ee682b11ca580c2c2c3ac067c27b51bd6d749 Mon Sep 17 00:00:00 2001 From: VanessaMae23 <60029664+Vanessamae23@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:18:16 +0800 Subject: [PATCH 162/206] Add New Optional Parameters to `AssistantRequest` Struct (#795) * Add more parameters to support Assistant v2 * Add goimports --- assistant.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/assistant.go b/assistant.go index cc13a3020..4c89c1b2f 100644 --- a/assistant.go +++ b/assistant.go @@ -62,14 +62,17 @@ type AssistantToolResource struct { // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. 
type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases From 581da2f12d52617368bdfe2625f5b0ef1dd32758 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Mon, 29 Jul 2024 01:43:45 +0800 Subject: [PATCH 163/206] fix: #804 (#807) --- batch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batch.go b/batch.go index a43d401ab..3c1a9d0d7 100644 --- a/batch.go +++ b/batch.go @@ -65,7 +65,7 @@ type Batch struct { Endpoint BatchEndpoint `json:"endpoint"` Errors *struct { Object string `json:"object,omitempty"` - Data struct { + Data []struct { Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` Param *string `json:"param,omitempty"` From dbe726c59f6df65965a4ee25e37706c33e391dc4 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:21:38 +1000 Subject: [PATCH 164/206] Add support for `gpt-4o-2024-08-06` (#812) * feat: Add GPT-4o Mini model support * feat: Add GPT-4o-2024-08-06 model support --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 4ff1123c4..d435eb382 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,7 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -91,6 +92,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4o20240806: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 623074c14a110b97d9a7aac7896bbdccf335257f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 7 Aug 2024 21:47:48 +0800 Subject: [PATCH 165/206] feat: Support Structured Outputs (#813) * feat: Support Structured Outputs * feat: Support Structured Outputs * update imports * add integration test * update JSON schema comments --- api_integration_test.go | 61 +++++++++++++++++++++++++++++++++++++++++ chat.go | 13 ++++++++- jsonschema/json.go | 8 +++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index f34685188..a487f588a 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" @@ -178,3 +179,63 @@ func TestAPIError(t *testing.T) { t.Fatal("Empty error message occurred") } } + +func 
TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": jsonschema.Definition{Type: jsonschema.String}, + "CamelCase": jsonschema.Definition{Type: jsonschema.String}, + "KebabCase": jsonschema.Definition{Type: jsonschema.String}, + "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + Strict: true, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index eb494f41f..8bfe558b5 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -175,11 +177,20 @@ type ChatCompletionResponseFormatType string const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema jsonschema.Definition `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. 
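A minimal end-to-end sketch of the new `json_schema` response format (illustrative only, not part of the patch: the prompt text and token are placeholders, while the identifiers come from the hunks above):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
	client := openai.NewClient("your token") // placeholder token
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT4oMini,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Convert 'Hello World' to snake_case."},
		},
		ResponseFormat: &openai.ChatCompletionResponseFormat{
			Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
			JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{
				Name: "cases",
				Schema: jsonschema.Definition{
					Type: jsonschema.Object,
					Properties: map[string]jsonschema.Definition{
						"snake_case": {Type: jsonschema.String},
					},
					Required:             []string{"snake_case"},
					AdditionalProperties: false,
				},
				Strict: true,
			},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	// With Strict set, the reply content is a JSON document matching the
	// schema, e.g. {"snake_case":"hello_world"}.
	fmt.Println(resp.Choices[0].Message.Content)
}
```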
diff --git a/jsonschema/json.go b/jsonschema/json.go index cb941eb75..7fd1e11bf 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -29,11 +29,17 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties"` + Properties map[string]Definition `json:"properties,omitempty"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` } func (d Definition) MarshalJSON() ([]byte, error) { From 6439e1fcc93fc5175accf5d51358e45fa5ea9099 Mon Sep 17 00:00:00 2001 From: Tyler Gannon Date: Wed, 7 Aug 2024 12:40:45 -0700 Subject: [PATCH 166/206] Make reponse format JSONSchema optional (#820) --- chat.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 8bfe558b5..31fa887d6 100644 --- a/chat.go +++ b/chat.go @@ -182,8 +182,8 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` - JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` } type ChatCompletionResponseFormatJSONSchema struct { From 18803333812ea21c409e84d426141606b9a6e692 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Fri, 9 Aug 2024 18:30:32 +0200 Subject: [PATCH 167/206] Run integration tests for PRs (#823) * Unbreak integration tests * Update integration-tests.yml --- api_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_integration_test.go b/api_integration_test.go index a487f588a..3084268e6 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -211,7 +211,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { }, ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ Name: "cases", Schema: jsonschema.Definition{ Type: jsonschema.Object, From 2c6889e0818b93c4fd724d9528b610896f5e9421 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 11 Aug 2024 05:05:06 +0800 Subject: [PATCH 168/206] fix: #788 (#800) --- completion.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/completion.go b/completion.go index d435eb382..bc2a63795 100644 --- a/completion.go +++ b/completion.go @@ -138,25 +138,26 @@ func checkPromptType(prompt any) bool { // CompletionRequest represents a request structure for completion API. 
type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. 
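Besides alphabetizing the struct, the change above introduces a `Seed` field. A minimal sketch of pinning it on a request (illustrative only, not part of the patch; it assumes the pre-existing `GPT3Dot5TurboInstruct` constant and a placeholder token):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your token") // placeholder token
	seed := 42                               // any fixed value makes sampling more repeatable
	resp, err := client.CreateCompletion(context.Background(), openai.CompletionRequest{
		Model:     openai.GPT3Dot5TurboInstruct, // assumed pre-existing model constant
		Prompt:    "Lorem ipsum",
		MaxTokens: 5,
		Seed:      &seed, // field added by this patch
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Text)
}
```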
From dd7f5824f9a4c3860cccfaf8350d5d09e864038f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sat, 17 Aug 2024 01:11:38 +0800 Subject: [PATCH 169/206] fix: fullURL endpoint generation (#817) --- api_internal_test.go | 24 ++++++++--- audio.go | 9 ++++- chat.go | 7 +++- chat_stream.go | 7 +++- client.go | 84 ++++++++++++++++++++++++-------------- client_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 7 +++- edits.go | 7 +++- embeddings.go | 7 +++- example_test.go | 2 +- image.go | 25 +++++++++--- moderation.go | 7 +++- speech.go | 5 ++- stream.go | 8 +++- 14 files changed, 244 insertions(+), 51 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { - "CloudflareAzureBaseURLWithoutSlashOK", + "", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", }, } @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index dbc26d154..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, 
http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index 31fa887d6..826fd3bd5 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index d5d555c3d..9f547e7cb 100644 --- a/client.go +++ b/client.go @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error { return nil } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. 
-func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() - query.Add("api-version", c.config.APIVersion) - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, query.Encode(), - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } + return fmt.Sprintf("%s%s", baseURL, suffix) +} - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/client_test.go b/client_test.go index e49da9b3d..a0d3bb390 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: 
"123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + if r.(string) != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} diff --git a/completion.go b/completion.go index bc2a63795..e8e9242c9 100644 --- a/completion.go +++ b/completion.go @@ -213,7 +213,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. 
*/ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index b513ba6a7..74eb8aa57 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/example_test.go b/example_test.go index de67c57cd..1bdb8496e 100644 --- a/example_test.go +++ b/example_test.go @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } diff --git a/image.go b/image.go index 665de1a74..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/moderation.go b/moderation.go index ae285ef83..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/speech.go b/speech.go index 19b21bdf1..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ -44,7 +44,10 @@ type CreateSpeechRequest struct { } func (c *Client) 
CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } From d86425a5cfd09bb76fe2f9239a03a9dbcdca8a9c Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Fri, 16 Aug 2024 13:41:39 -0400 Subject: [PATCH 170/206] Allow structured outputs via function calling (#828) --- api_integration_test.go | 76 +++++++++++++++++++++++++++++++++++++++++ chat.go | 1 + chat_test.go | 26 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/api_integration_test.go b/api_integration_test.go index 3084268e6..57f7c40fb 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { } } } + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. 
SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index 826fd3bd5..97c89a497 100644 --- a/chat.go +++ b/chat.go @@ -264,6 +264,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. 
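A compact sketch of a strict tool built with the new `Strict` field (illustrative only, not part of the patch; the schema mirrors the integration test above):

```go
package main

import (
	"fmt"

	"github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/jsonschema"
)

// strictCaseTool returns a function tool whose arguments must match the
// declared schema exactly once Strict is set.
func strictCaseTool() openai.Tool {
	return openai.Tool{
		Type: openai.ToolTypeFunction,
		Function: &openai.FunctionDefinition{
			Name:   "display_cases",
			Strict: true, // field added by the hunk above
			Parameters: &jsonschema.Definition{
				Type: jsonschema.Object,
				Properties: map[string]jsonschema.Definition{
					"SnakeCase": {Type: jsonschema.String},
				},
				Required:             []string{"SnakeCase"},
				AdditionalProperties: false,
			},
		},
	}
}

func main() {
	fmt.Printf("%+v\n", strictCaseTool())
}
```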
diff --git a/chat_test.go b/chat_test.go index 520bf5ca4..37dc09d4d 100644 --- a/chat_test.go +++ b/chat_test.go @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) } func TestAzureChatCompletions(t *testing.T) { From 6d021190f05410a44d9401984815c55f4736b755 Mon Sep 17 00:00:00 2001 From: Yamagami ken-ichi Date: Thu, 22 Aug 2024 23:27:44 +0900 Subject: [PATCH 171/206] feat: Support Delete Message API (#799) * feat: Add DeleteMessage function to API client * fix: linter nolint : Deprecated method split function: cognitive complexity 21 * rename func name for unit-test --- client_test.go | 3 +++ fine_tunes.go | 2 +- messages.go | 24 ++++++++++++++++++++++++ messages_test.go | 36 +++++++++++++++++++++++++++++++----- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/client_test.go b/client_test.go index a0d3bb390..7119d8a7e 100644 --- a/client_test.go +++ b/client_test.go @@ -348,6 +348,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ModifyMessage", func() (any, error) { return client.ModifyMessage(ctx, "", "", nil) }}, + {"DeleteMessage", func() (any, error) { + return client.DeleteMessage(ctx, "", "") + }}, {"RetrieveMessageFile", func() (any, error) { return client.RetrieveMessageFile(ctx, "", "", "") }}, diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } diff --git a/messages.go b/messages.go index 6af118445..1fddd6314 100644 --- a/messages.go +++ b/messages.go @@ -73,6 +73,14 @@ type MessageFilesList struct { httpHeader } +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) @@ -186,3 +194,19 @@ func (c *Client) ListMessageFiles( err = c.sendRequest(req, &files) return } + +// DeleteMessage deletes a message.. 
+func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} diff --git a/messages_test.go b/messages_test.go index a18be20bd..71ceb4d3a 100644 --- a/messages_test.go +++ b/messages_test.go @@ -8,20 +8,17 @@ import ( "testing" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) var emptyStr = "" -// TestMessages Tests the messages endpoint of the API using the mocked server. -func TestMessages(t *testing.T) { +func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { threadID := "thread_abc123" messageID := "msg_abc123" fileID := "file_abc123" - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler( "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, func(w http.ResponseWriter, r *http.Request) { @@ -115,6 +112,13 @@ func TestMessages(t *testing.T) { Metadata: nil, }) fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ + ID: messageID, + Object: "thread.message.deleted", + Deleted: true, + }) + fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) } @@ -176,7 +180,18 @@ func TestMessages(t *testing.T) { } }, ) +} +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + setupServerForTestMessage(t, server) ctx := context.Background() // static assertion of return type @@ -225,6 +240,17 @@ func TestMessages(t *testing.T) { t.Fatalf("expected message metadata to get modified") } + msgDel, err := client.DeleteMessage(ctx, threadID, messageID) + checks.NoError(t, err, "DeleteMessage error") + if msgDel.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + if !msgDel.Deleted { + t.Fatalf("expected deleted is true") + } + _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") + checks.HasError(t, err, "DeleteMessage error") + // message files var msgFile openai.MessageFile msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) From 5162adbbf90cef77b8462c1f33c81f7d258a1447 Mon Sep 17 00:00:00 2001 From: Alexey Michurin Date: Fri, 23 Aug 2024 13:47:11 +0300 Subject: [PATCH 172/206] Support http client middlewareing (#830) --- config.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 1347567d7..8a9183558 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,10 @@ const AzureAPIKeyHeader = "api-key" const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // ClientConfig is a configuration of a client. 
type ClientConfig struct { authToken string @@ -36,7 +40,7 @@ type ClientConfig struct { APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func - HTTPClient *http.Client + HTTPClient HTTPDoer EmptyMessagesLimit uint } From a3bd2569ac51f1c54d704ec80dcbb91ab9f46acf Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 25 Aug 2024 01:06:08 +0800 Subject: [PATCH 173/206] Improve handling of JSON Schema in OpenAI API Response Context (#819) * feat: add jsonschema.Validate and jsonschema.Unmarshal * fix Sanity check * remove slices.Contains * fix Sanity check * add SchemaWrapper * update api_integration_test.go * update method 'reflectSchema' to support 'omitempty' in JSON tag * add GenerateSchemaForType * update json_test.go * update `Warp` to `Wrap` * fix Sanity check * fix Sanity check * update api_internal_test.go * update README.md * update README.md * remove jsonschema.SchemaWrapper * remove jsonschema.SchemaWrapper * fix Sanity check * optimize code formatting --- README.md | 64 +++++++++++++++++ api_integration_test.go | 36 +++++----- chat.go | 10 ++- example_test.go | 2 +- jsonschema/json.go | 105 +++++++++++++++++++++++++++- jsonschema/validate.go | 89 +++++++++++++++++++++++ jsonschema/validate_test.go | 136 ++++++++++++++++++++++++++++++++++++ 7 files changed, 412 insertions(+), 30 deletions(-) create mode 100644 jsonschema/validate.go create mode 100644 jsonschema/validate_test.go diff --git a/README.md b/README.md index 799dc602b..0d6aafa40 100644 --- a/README.md +++ b/README.md @@ -743,6 +743,70 @@ func main() { } ```
+ +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions diff --git a/api_integration_test.go b/api_integration_test.go index 57f7c40fb..8c9f3384f 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,7 +4,6 @@ package openai_test import ( "context" - "encoding/json" "errors" "io" "os" @@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - Schema: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "PascalCase": jsonschema.Definition{Type: jsonschema.String}, - "CamelCase": jsonschema.Definition{Type: jsonschema.String}, - "KebabCase": jsonschema.Definition{Type: jsonschema.String}, - "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, - }, - Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, - AdditionalProperties: false, - }, + Name: "cases", + Schema: schema, Strict: true, }, }, }, ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") - var result = make(map[string]string) - err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) - checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") - for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { - if _, ok := result[key]; !ok { - t.Errorf("key:%s does not exist.", key) - } + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/chat.go b/chat.go index 97c89a497..56e99a78b 100644 --- a/chat.go +++ b/chat.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "net/http" - - "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct { } type ChatCompletionResponseFormatJSONSchema struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Schema jsonschema.Definition `json:"schema"` - Strict bool `json:"strict"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. 
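Because `Schema` is now a `json.Marshaler`, the `*jsonschema.Definition` returned by the new `GenerateSchemaForType` can be assigned to it directly, and the same definition can validate the model's reply. A standalone sketch of that round trip (illustrative only, not part of the patch; the payload string stands in for a real API response):

```go
package main

import (
	"fmt"
	"log"

	"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
	var result struct {
		FinalAnswer string `json:"final_answer"`
	}
	// Reflect a schema from the Go type; fields without ",omitempty" are required.
	schema, err := jsonschema.GenerateSchemaForType(result)
	if err != nil {
		log.Fatal(err)
	}
	// Unmarshal validates the payload against the schema before decoding it.
	if err := schema.Unmarshal(`{"final_answer":"x = -3.75"}`, &result); err != nil {
		log.Fatal(err)
	}
	fmt.Println(result.FinalAnswer)
}
```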
diff --git a/example_test.go b/example_test.go index 1bdb8496e..e5dbf44bf 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() diff --git a/jsonschema/json.go b/jsonschema/json.go index 7fd1e11bf..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -42,7 +48,7 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -50,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} diff --git 
a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..f14ffd4c4 --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,89 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data any) bool { + switch schema.Type { + case Object: + return validateObject(schema, data) + case Array: + return validateArray(schema, data) + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + return false + } +} + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..c2c47a2ce --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,136 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: 
&jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + var result1 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + var result2 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &result1, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: result2, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} From 030b7cb7ed60fc4a8b2fd608f538c470b65b1131 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 24 Aug 2024 18:11:27 +0100 Subject: [PATCH 174/206] fix integration tests (#834) --- api_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/api_integration_test.go b/api_integration_test.go index 8c9f3384f..7828d9451 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" 
"io" "os" From c37cf9ab5b887fe0195d3cc6240780e9b1928a04 Mon Sep 17 00:00:00 2001 From: Tommy Mathisen Date: Sun, 1 Sep 2024 18:30:29 +0300 Subject: [PATCH 175/206] Dynamic model (#838) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index e8e9242c9..12ce4b558 100644 --- a/completion.go +++ b/completion.go @@ -25,6 +25,7 @@ const ( GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4oLatest = "chatgpt-4o-latest" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -93,6 +94,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4o: true, GPT4o20240513: true, GPT4o20240806: true, + GPT4oLatest: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 643da8d650b1f7db4706076a53b9d0acddccbd17 Mon Sep 17 00:00:00 2001 From: Arun Das <89579096+Arundas666@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:19:57 +0530 Subject: [PATCH 176/206] depricated model GPT3Ada changed to GPT3Babbage002 (#843) * depricated model GPT3Ada changed to GPT3Babbage002 * Delete test.mp3 --- README.md | 4 ++-- example_test.go | 4 ++-- examples/completion/main.go | 2 +- stream_test.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 0d6aafa40..b3ebc1471 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -174,7 +174,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/example_test.go b/example_test.go index e5dbf44bf..5910ffb84 100644 --- a/example_test.go +++ b/example_test.go @@ -82,7 +82,7 @@ func ExampleClient_CreateCompletion() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, @@ -99,7 +99,7 @@ func ExampleClient_CreateCompletionStream() { stream, err := client.CreateCompletionStream( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/examples/completion/main.go b/examples/completion/main.go index 22af1fd82..8c5cbd5ca 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -13,7 +13,7 @@ func main() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, diff --git a/stream_test.go b/stream_test.go index 2822a3535..9dd95bb5f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -169,7 +169,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { var apiErr *openai.APIError _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, Prompt: "Hello!", Stream: true, }) From 194a03e763f0d71333a6088bf613a35f65c50447 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 11 Sep 2024 22:24:49 +0200 Subject: [PATCH 177/206] Add refusal (#844) * add custom marshaller, documentation and isolate tests * fix linter * add missing field --- chat.go | 7 +++++++ 1 file changed, 7 
insertions(+) diff --git a/chat.go b/chat.go index 56e99a78b..dc60f35b9 100644 --- a/chat.go +++ b/chat.go @@ -82,6 +82,7 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in @@ -107,6 +108,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { msg := struct { Role string `json:"role"` Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content,omitempty"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -115,9 +117,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { }(m) return json.Marshal(msg) } + msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"-"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -131,12 +135,14 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` }{} + if err := json.Unmarshal(bs, &msg); err == nil { *m = ChatCompletionMessage(msg) return nil @@ -144,6 +150,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { multiMsg := struct { Role string `json:"role"` Content string + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` From a5fb55321b43aa6b31bb3ff57d43cb5a8f2e17ef Mon Sep 17 00:00:00 2001 From: Aaron Batilo Date: Tue, 17 Sep 2024 14:19:47 -0600 Subject: [PATCH 178/206] Support OpenAI reasoning models (#850) These model strings are now available for use. More info: https://openai.com/index/introducing-openai-o1-preview/ https://platform.openai.com/docs/guides/reasoning --- completion.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/completion.go b/completion.go index 12ce4b558..e1e065a8b 100644 --- a/completion.go +++ b/completion.go @@ -17,6 +17,10 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. 
const ( + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -83,6 +87,10 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, From 1ec8c24ea7ae0e31d5e8332f8a0349d2ecd5b913 Mon Sep 17 00:00:00 2001 From: Wei-An Yen Date: Sat, 21 Sep 2024 02:22:01 +0800 Subject: [PATCH 179/206] fix: jsonschema integer validation (#852) --- jsonschema/validate.go | 4 ++++ jsonschema/validate_test.go | 48 +++++++++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/jsonschema/validate.go b/jsonschema/validate.go index f14ffd4c4..49f9b8859 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -36,6 +36,10 @@ func Validate(schema Definition, data any) bool { _, ok := data.(bool) return ok case Integer: + // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer + if num, ok := data.(float64); ok { + return num == float64(int64(num)) + } _, ok := data.(int) return ok case Null: diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index c2c47a2ce..6fa30ab0c 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -86,14 +86,6 @@ func TestUnmarshal(t *testing.T) { content []byte v any } - var result1 struct { - String string `json:"string"` - Number float64 `json:"number"` - } - var result2 struct { - String string `json:"string"` - Number float64 `json:"number"` - } tests := []struct { name string args args @@ -108,7 +100,10 @@ func TestUnmarshal(t *testing.T) { }, }, content: []byte(`{"string":"abc","number":123.4}`), - v: &result1, + v: &struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, }, false}, {"", args{ schema: jsonschema.Definition{ @@ -120,7 +115,40 @@ func TestUnmarshal(t *testing.T) { Required: []string{"string", "number"}, }, content: []byte(`{"string":"abc"}`), - v: result2, + v: struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, true}, + {"validate integer", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, false}, + {"validate integer failed", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123.4}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, }, true}, } for _, tt := range tests { From 9add1c348607c14e8fde9966713c97f9a2351919 Mon Sep 17 00:00:00 2001 From: Ivan Timofeev Date: Fri, 20 Sep 2024 23:40:24 +0300 Subject: [PATCH 180/206] add max_completions_tokens for o1 series models (#857) * add max_completions_tokens for o1 series models * add validation for o1 series models validataion + beta limitations --- chat.go | 35 +++++--- 
chat_stream.go | 4 + chat_stream_test.go | 21 +++++ chat_test.go | 211 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 82 +++++++++++++++++ 5 files changed, 341 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index dc60f35b9..d47c95e4f 100644 --- a/chat.go +++ b/chat.go @@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens int `json:"max_tokens,omitempty"` + // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias @@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion( return } + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream.go b/chat_stream.go index 3f90bc019..f43d01834 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream_test.go b/chat_stream_test.go index 63e45ee23..2e7c99b45 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } +func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1/chat/completions" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + req := openai.ChatCompletionRequest{ + Model: openai.O1Preview, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err := client.CreateChatCompletionStream(ctx, req) + if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { + t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) + } +} + func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 37dc09d4d..a54dd35e0 100644 --- a/chat_test.go +++ b/chat_test.go @@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "o1-preview_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + { + name: "o1-mini_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Mini, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1BetaLimitationsLogprobs, + }, + { + name: "message_type_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + }, + }, 
+ }, + expectedError: openai.ErrO1BetaLimitationsMessageTypes, + }, + { + name: "tool_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + }, + }, + }, + expectedError: openai.ErrO1BetaLimitationsTools, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense @@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. 
+func TestO1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O1Preview, + MaxCompletionsTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() diff --git a/completion.go b/completion.go index e1e065a8b..8e3172ace 100644 --- a/completion.go +++ b/completion.go @@ -7,11 +7,20 @@ import ( ) var ( + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll ) +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific @@ -85,6 +94,15 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) +// O1SeriesModels List of new Series of OpenAI models. +// Some old api attributes not supported. +var O1SeriesModels = map[string]struct{}{ + O1Mini: {}, + O1Mini20240912: {}, + O1Preview: {}, + O1Preview20240912: {}, +} + var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { O1Mini: true, @@ -146,6 +164,70 @@ func checkPromptType(prompt any) bool { return isString || isStringSlice } +var unsupportedToolsForO1Models = map[ToolType]struct{}{ + ToolTypeFunction: {}, +} + +var availableMessageRoleForO1Models = map[string]struct{}{ + ChatMessageRoleUser: {}, + ChatMessageRoleAssistant: {}, +} + +// validateRequestForO1Models checks for deprecated fields of OpenAI models. 
+func validateRequestForO1Models(request ChatCompletionRequest) error { + if _, found := O1SeriesModels[request.Model]; !found { + return nil + } + + if request.MaxTokens > 0 { + return ErrO1MaxTokensDeprecated + } + + // Beta Limitations + // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations + // Streaming: not supported + if request.Stream { + return ErrO1BetaLimitationsStreaming + } + // Logprobs: not supported. + if request.LogProbs { + return ErrO1BetaLimitationsLogprobs + } + + // Message types: user and assistant messages only, system messages are not supported. + for _, m := range request.Messages { + if _, found := availableMessageRoleForO1Models[m.Role]; !found { + return ErrO1BetaLimitationsMessageTypes + } + } + + // Tools: tools, function calling, and response format parameters are not supported + for _, t := range request.Tools { + if _, found := unsupportedToolsForO1Models[t.Type]; found { + return ErrO1BetaLimitationsTools + } + } + + // Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0. + if request.Temperature > 0 && request.Temperature != 1 { + return ErrO1BetaLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrO1BetaLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrO1BetaLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrO1BetaLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrO1BetaLimitationsOther + } + + return nil +} + // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` From 9a4f3a7dbf8f29408848c94cf933d1530ae64526 Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Sat, 21 Sep 2024 04:49:28 +0800 Subject: [PATCH 181/206] feat: add ParallelToolCalls to RunRequest (#847) --- run.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/run.go b/run.go index 5598f1dfb..0cdec2bdc 100644 --- a/run.go +++ b/run.go @@ -37,6 +37,8 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } From e095df5325a39ed94940dbe3882d2aa14eb64ad0 Mon Sep 17 00:00:00 2001 From: floodwm Date: Fri, 20 Sep 2024 23:54:25 +0300 Subject: [PATCH 182/206] run_id string Optional (#855) Filter messages by the run ID that generated them. 
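
A minimal usage sketch (thread and run IDs are placeholders taken from the
tests below); the new trailing parameter is a *string, so pass nil to keep
the previous behavior:

	ctx := context.Background()
	runID := "run_abc123"
	msgs, err := client.ListMessage(ctx, threadID, nil, nil, nil, nil, &runID)
	if err != nil {
		// handle error
	}
	_ = msgs
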
Co-authored-by: wappi --- .zshrc | 0 client_test.go | 2 +- messages.go | 5 +++++ messages_test.go | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 .zshrc diff --git a/.zshrc b/.zshrc new file mode 100644 index 000000000..e69de29bb diff --git a/client_test.go b/client_test.go index 7119d8a7e..3f27b9dd7 100644 --- a/client_test.go +++ b/client_test.go @@ -340,7 +340,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { return client.CreateMessage(ctx, "", MessageRequest{}) }}, {"ListMessage", func() (any, error) { - return client.ListMessage(ctx, "", nil, nil, nil, nil) + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) }}, {"RetrieveMessage", func() (any, error) { return client.RetrieveMessage(ctx, "", "") diff --git a/messages.go b/messages.go index 1fddd6314..eefc29a36 100644 --- a/messages.go +++ b/messages.go @@ -100,6 +100,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, order *string, after *string, before *string, + runID *string, ) (messages MessagesList, err error) { urlValues := url.Values{} if limit != nil { @@ -114,6 +115,10 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, if before != nil { urlValues.Add("before", *before) } + if runID != nil { + urlValues.Add("run_id", *runID) + } + encodedValues := "" if len(urlValues) > 0 { encodedValues = "?" + urlValues.Encode() diff --git a/messages_test.go b/messages_test.go index 71ceb4d3a..b25755f98 100644 --- a/messages_test.go +++ b/messages_test.go @@ -208,7 +208,7 @@ func TestMessages(t *testing.T) { } var msgs openai.MessagesList - msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -219,7 +219,8 @@ func TestMessages(t *testing.T) { order := "desc" after := "obj_foo" before := "obj_bar" - msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") From 38bdc812df391bcec3d7defda2a456ea00bb54e5 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 26 Sep 2024 18:25:56 +0800 Subject: [PATCH 183/206] Optimize Client Error Return (#856) * update client error return * update client_test.go * update client_test.go * update file_api_test.go * update client_test.go * update client_test.go --- client.go | 9 ++++++ client_test.go | 76 +++++++++++++++++++++++++++++++++-------------- error.go | 6 ++-- files_api_test.go | 1 + 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index 9f547e7cb..583244fe1 100644 --- a/client.go +++ b/client.go @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } + return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) + } var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + 
HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, } @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index 3f27b9dd7..18da787a0 100644 --- a/client_test.go +++ b/client_test.go @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: "error, status code: 503, status: , message: ", + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` +413 Request Entity Too Large + +

+<center><h1>413 Request Entity Too Large</h1></center>
+<hr><center>nginx</center>
+</body>
+</html>`)),
+			expected: `error, status code: 413, status: , body: 
+<html>
+<head><title>413 Request Entity Too Large</title></head>
+<body>
+<center><h1>413 Request Entity Too Large</h1></center>
+<hr><center>nginx</center>
+</body>
+</html>
+ +`, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } - - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() - } }) } } diff --git a/error.go b/error.go index 37959a272..1f6a8971d 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type APIError struct { Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` + HTTPStatus string `json:"-"` HTTPStatusCode int `json:"-"` InnerError *InnerError `json:"innererror,omitempty"` } @@ -25,6 +26,7 @@ type InnerError struct { // RequestError provides information about generic request errors. type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error } @@ -35,7 +37,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) } func (e *RequestError) Unwrap() error { diff --git a/files_api_test.go b/files_api_test.go index c92162a84..aa4fda458 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) From 7f80303cc393edf2f6806ca37668346f8fa6247e Mon Sep 17 00:00:00 2001 From: Alex Philipp Date: Thu, 26 Sep 2024 05:26:22 -0500 Subject: [PATCH 184/206] Fix max_completion_tokens (#860) The json tag is incorrect, and results in an error from the API when using the o1 model. I didn't modify the struct field name to maintain compatibility if anyone else had started using it, but it wouldn't work for them either. 
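
For illustration, a minimal sketch of a request that benefits from the fix
(values are placeholders); the Go field keeps its MaxCompletionsTokens name
for compatibility, but now marshals to the tag the API expects:

	req := openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000, // serialized as "max_completion_tokens" after this change
	}
	_ = req
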
--- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index d47c95e4f..dd99c530e 100644 --- a/chat.go +++ b/chat.go @@ -209,7 +209,7 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` N int `json:"n,omitempty"` From e9d8485e90092b8adcce82fdd0dcd7cf10327e8d Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Thu, 26 Sep 2024 18:26:54 +0800 Subject: [PATCH 185/206] fix: ParallelToolCalls should be added to RunRequest (#861) --- run.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run.go b/run.go index 0cdec2bdc..d3e755f05 100644 --- a/run.go +++ b/run.go @@ -37,8 +37,6 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` - // Disable the default behavior of parallel tool calls by setting it: false. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } @@ -112,6 +110,8 @@ type RunRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or a ResponseFormat object. ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. From fdd59d93413154cd07b2e46a428b15eda40b26e2 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 26 Sep 2024 18:30:56 +0800 Subject: [PATCH 186/206] feat: usage struct add CompletionTokensDetails (#863) --- common.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/common.go b/common.go index cbfda4e3c..cde14154a 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,13 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. 
+type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` } From bac7d5936108965a9666a65d0d4d55bd0fe78808 Mon Sep 17 00:00:00 2001 From: Winston Liu Date: Thu, 3 Oct 2024 12:17:16 -0700 Subject: [PATCH 187/206] fix MaxCompletionTokens typo (#862) * fix spelling error * fix lint * Update chat.go * Update chat.go --- chat.go | 22 +++++++++++----------- chat_test.go | 38 +++++++++++++++++++------------------- completion.go | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/chat.go b/chat.go index dd99c530e..9adf2808d 100644 --- a/chat.go +++ b/chat.go @@ -207,18 +207,18 @@ type ChatCompletionRequest struct { // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` - // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. 
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias diff --git a/chat_test.go b/chat_test.go index a54dd35e0..134026cdb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -100,17 +100,17 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "log_probs_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - LogProbs: true, - Model: openai.O1Preview, + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, }, expectedError: openai.ErrO1BetaLimitationsLogprobs, }, { name: "message_type_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, @@ -122,8 +122,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "tool_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -143,8 +143,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_temperature_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -160,8 +160,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_top_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -178,8 +178,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_n_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -197,8 +197,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_presence_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -214,8 +214,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_frequency_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -296,8 +296,8 @@ func TestO1ModelChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: openai.O1Preview, - MaxCompletionsTokens: 1000, + Model: openai.O1Preview, + MaxCompletionTokens: 1000, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, diff --git a/completion.go b/completion.go index 8e3172ace..80c4d39ae 100644 --- a/completion.go +++ b/completion.go @@ -7,7 +7,7 @@ import ( ) var ( - ErrO1MaxTokensDeprecated = 
errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll From 7c145ebb4be68610bc3bb5377b754944307d44fd Mon Sep 17 00:00:00 2001 From: Julio Martins <89476495+juliomartinsdev@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:19:48 -0300 Subject: [PATCH 188/206] add jailbreak filter result, add ContentFilterResults on output (#864) * add jailbreak filter result * add content filter results on completion output * add profanity content filter --- chat.go | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/chat.go b/chat.go index 9adf2808d..a7dee8e03 100644 --- a/chat.go +++ b/chat.go @@ -41,11 +41,23 @@ type Violence struct { Severity string `json:"severity,omitempty"` } +type JailBreak struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type Profanity struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + type ContentFilterResults struct { - Hate Hate `json:"hate,omitempty"` - SelfHarm SelfHarm `json:"self_harm,omitempty"` - Sexual Sexual `json:"sexual,omitempty"` - Violence Violence `json:"violence,omitempty"` + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` + JailBreak JailBreak `json:"jailbreak,omitempty"` + Profanity Profanity `json:"profanity,omitempty"` } type PromptAnnotation struct { @@ -338,19 +350,21 @@ type ChatCompletionChoice struct { // function_call: The model decided to call a function // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` - LogProbs *LogProbs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. 
type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` - SystemFingerprint string `json:"system_fingerprint"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` httpHeader } From 991326480f84981b6e89032b9f9710a3a83a6f0f Mon Sep 17 00:00:00 2001 From: Isaac Seymour Date: Wed, 9 Oct 2024 10:50:27 +0100 Subject: [PATCH 189/206] Completion API: add new params (#870) * Completion API: add 'store' param This param allows you to opt a completion request in to being stored, for use in distillations and evals. * Add cached and audio tokens to usage structs These have been added to the completions API recently: https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage --- common.go | 8 ++++++++ completion.go | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/common.go b/common.go index cde14154a..8cc7289c0 100644 --- a/common.go +++ b/common.go @@ -7,10 +7,18 @@ type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` } // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { + AudioTokens int `json:"audio_tokens"` ReasoningTokens int `json:"reasoning_tokens"` } + +// PromptTokensDetails Breakdown of tokens used in the prompt. +type PromptTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + CachedTokens int `json:"cached_tokens"` +} diff --git a/completion.go b/completion.go index 80c4d39ae..afcf84671 100644 --- a/completion.go +++ b/completion.go @@ -238,18 +238,21 @@ type CompletionRequest struct { // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. 
+ // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From cfe15ffd00bb908c32cf0d9e277786a14afdd2c7 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Mon, 14 Oct 2024 18:50:39 +0530 Subject: [PATCH 190/206] return response body as byte slice for RequestError type (#873) --- client.go | 11 ++++++----- error.go | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 583244fe1..1e228a097 100644 --- a/client.go +++ b/client.go @@ -285,20 +285,21 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("error, reading response body: %w", err) - } return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) } var errRes ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errRes) + err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, + Body: body, } if errRes.Error != nil { reqErr.Err = errRes.Error diff --git a/error.go b/error.go index 1f6a8971d..fc9e7cdb9 100644 --- a/error.go +++ b/error.go @@ -29,6 +29,7 @@ type RequestError struct { HTTPStatus string HTTPStatusCode int Err error + Body []byte } type ErrorResponse struct { From 21f713457449b1ab386529b9495cbf1f27c0db5a Mon Sep 17 00:00:00 2001 From: Matt Jacobs Date: Mon, 14 Oct 2024 09:21:39 -0400 Subject: [PATCH 191/206] Adding new moderation model constants (#875) --- moderation.go | 12 ++++++++---- moderation_test.go | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/moderation.go b/moderation.go index c8652efc8..a0e09c0ee 100644 --- a/moderation.go +++ b/moderation.go @@ -14,8 +14,10 @@ import ( // If you use text-moderation-stable, we will provide advanced notice before updating the model. // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. const ( - ModerationTextStable = "text-moderation-stable" - ModerationTextLatest = "text-moderation-latest" + ModerationOmniLatest = "omni-moderation-latest" + ModerationOmni20240926 = "omni-moderation-2024-09-26" + ModerationTextStable = "text-moderation-stable" + ModerationTextLatest = "text-moderation-latest" // Deprecated: use ModerationTextStable and ModerationTextLatest instead. 
ModerationText001 = "text-moderation-001" ) @@ -25,8 +27,10 @@ var ( ) var validModerationModel = map[string]struct{}{ - ModerationTextStable: {}, - ModerationTextLatest: {}, + ModerationOmniLatest: {}, + ModerationOmni20240926: {}, + ModerationTextStable: {}, + ModerationTextLatest: {}, } // ModerationRequest represents a request structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 61171c384..a97f25bc6 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -37,6 +37,8 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), getModerationModelTestOption(openai.ModerationTextStable, nil), getModerationModelTestOption(openai.ModerationTextLatest, nil), + getModerationModelTestOption(openai.ModerationOmni20240926, nil), + getModerationModelTestOption(openai.ModerationOmniLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() From b162541513db0cf3d4d48da03be22b05861269cb Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 15 Oct 2024 20:09:34 +0100 Subject: [PATCH 192/206] Cleanup (#879) * remove obsolete files * update readme --- .zshrc | 0 Makefile | 35 ----------------------------------- README.md | 2 +- 3 files changed, 1 insertion(+), 36 deletions(-) delete mode 100644 .zshrc delete mode 100644 Makefile diff --git a/.zshrc b/.zshrc deleted file mode 100644 index e69de29bb..000000000 diff --git a/Makefile b/Makefile deleted file mode 100644 index 2e608aa0c..000000000 --- a/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -##@ General - -# The help target prints out all targets with their descriptions organized -# beneath their categories. The categories are represented by '##@' and the -# target descriptions by '##'. The awk commands is responsible for reading the -# entire set of makefiles included in this invocation, looking for lines of the -# file as xyz: ## something, and then pretty-format the target and help. Then, -# if there's a line with ##@ something, that gets pretty-printed as a category. -# More info on the usage of ANSI control characters for terminal formatting: -# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters -# More info on the awk command: -# http://linuxcommand.org/lc3_adv_awk.php - -.PHONY: help -help: ## Display this help. - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) - - -##@ Development - -.PHONY: test -TEST_ARGS ?= -v -TEST_TARGETS ?= ./... -test: ## Test the Go modules within this package. - @ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS) - go test $(TEST_ARGS) $(TEST_TARGETS) - @ echo ✅ success! - - -.PHONY: lint -LINT_TARGETS ?= ./... -lint: ## Lint Go code with the installed golangci-lint - @ echo "▶️ golangci-lint run" - golangci-lint run $(LINT_TARGETS) - @ echo "✅ golangci-lint run" diff --git a/README.md b/README.md index b3ebc1471..57d1d35bf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). 
We support: -* ChatGPT +* ChatGPT 4o, o1 * GPT-3, GPT-4 * DALL·E 2, DALL·E 3 * Whisper From 9fe2c6ce1f5b756cd172ae9a7786beea69b2956f Mon Sep 17 00:00:00 2001 From: Sander Mack-Crane <71154168+smackcrane@users.noreply.github.com> Date: Tue, 15 Oct 2024 14:16:57 -0600 Subject: [PATCH 193/206] Completion API: add Store and Metadata parameters (#878) --- chat.go | 5 +++++ completion.go | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index a7dee8e03..2b13f8dd7 100644 --- a/chat.go +++ b/chat.go @@ -255,6 +255,11 @@ type ChatCompletionRequest struct { StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` } type StreamOptions struct { diff --git a/completion.go b/completion.go index afcf84671..84ef2ad26 100644 --- a/completion.go +++ b/completion.go @@ -241,18 +241,20 @@ type CompletionRequest struct { LogitBias map[string]int `json:"logit_bias,omitempty"` // Store can be set to true to store the output of this completion request for use in distillations and evals. // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store - Store bool `json:"store,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. 
From fb15ff9dcd861e601fc2c54078aac2bbd3c06ce8 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Tue, 22 Oct 2024 02:19:34 +0530 Subject: [PATCH 194/206] Handling for non-json response (#881) * removed handling for non-json response * added response body in RequestError.Error() and updated tests * done linting --- client.go | 3 --- client_test.go | 35 ++++++++++++++++++++--------------- error.go | 5 ++++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 1e228a097..ed8595e0b 100644 --- a/client.go +++ b/client.go @@ -289,9 +289,6 @@ func (c *Client) handleErrorResp(resp *http.Response) error { if err != nil { return fmt.Errorf("error, reading response body: %w", err) } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) - } var errRes ErrorResponse err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { diff --git a/client_test.go b/client_test.go index 18da787a0..354a6b3f5 100644 --- a/client_test.go +++ b/client_test.go @@ -194,26 +194,31 @@ func TestHandleErrorResp(t *testing.T) { { "error":{} }`)), - expected: "error, status code: 503, status: , message: ", + expected: `error, status code: 503, status: , message: , body: + { + "error":{} + }`, }, { name: "413 Request Entity Too Large", httpCode: http.StatusRequestEntityTooLarge, contentType: "text/html", - body: bytes.NewReader([]byte(` -413 Request Entity Too Large - -

-<center><h1>413 Request Entity Too Large</h1></center>
-<hr><center>nginx</center>
-</body>
-</html>`)),
-			expected: `error, status code: 413, status: , body: 
-<html>
-<head><title>413 Request Entity Too Large</title></head>
-<body>
-<center><h1>413 Request Entity Too Large</h1></center>
-<hr><center>nginx</center>
-</body>
-</html>`,
+			body: bytes.NewReader([]byte(`
+	<html>
+	<head><title>413 Request Entity Too Large</title></head>
+	<body>
+	<center><h1>413 Request Entity Too Large</h1></center>
+	<hr><center>nginx</center>
+	</body>
+	</html>`)),
+			expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body: 
+	<html>
+	<head><title>413 Request Entity Too Large</title></head>
+	<body>
+	<center><h1>413 Request Entity Too Large</h1></center>
+	<hr><center>nginx</center>
+	</body>
+	</html>
+
+	`,
 		},
 		{
 			name: "errorReader",
diff --git a/error.go b/error.go
index fc9e7cdb9..8a74bd52c 100644
--- a/error.go
+++ b/error.go
@@ -104,7 +104,10 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
 }
 
 func (e *RequestError) Error() string {
-	return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err)
+	return fmt.Sprintf(
+		"error, status code: %d, status: %s, message: %s, body: %s",
+		e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body,
+	)
 }
 
 func (e *RequestError) Unwrap() error {

From 3672c0dec601f89037d8d54e7df653d7df1f0c83 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com>
Date: Mon, 21 Oct 2024 22:57:02 +0200
Subject: [PATCH 195/206] fix: Updated Assistant struct with latest fields
 based on OpenAI docs (#883)

---
 assistant.go | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/assistant.go b/assistant.go
index 4c89c1b2f..8aab5bcf0 100644
--- a/assistant.go
+++ b/assistant.go
@@ -14,17 +14,20 @@ const (
 )
 
 type Assistant struct {
-	ID            string                 `json:"id"`
-	Object        string                 `json:"object"`
-	CreatedAt     int64                  `json:"created_at"`
-	Name          *string                `json:"name,omitempty"`
-	Description   *string                `json:"description,omitempty"`
-	Model         string                 `json:"model"`
-	Instructions  *string                `json:"instructions,omitempty"`
-	Tools         []AssistantTool        `json:"tools"`
-	FileIDs       []string               `json:"file_ids,omitempty"`
-	Metadata      map[string]any         `json:"metadata,omitempty"`
-	ToolResources *AssistantToolResource `json:"tool_resources,omitempty"`
+	ID             string                 `json:"id"`
+	Object         string                 `json:"object"`
+	CreatedAt      int64                  `json:"created_at"`
+	Name           *string                `json:"name,omitempty"`
+	Description    *string                `json:"description,omitempty"`
+	Model          string                 `json:"model"`
+	Instructions   *string                `json:"instructions,omitempty"`
+	Tools          []AssistantTool        `json:"tools"`
+	ToolResources  *AssistantToolResource `json:"tool_resources,omitempty"`
+	FileIDs        []string               `json:"file_ids,omitempty"` // Deprecated in v2
+	Metadata       map[string]any         `json:"metadata,omitempty"`
+	Temperature    *float32               `json:"temperature,omitempty"`
+	TopP           *float32               `json:"top_p,omitempty"`
+	ResponseFormat any                    `json:"response_format,omitempty"`
 
 	httpHeader
 }

From 6e087322b77693e6e9227d9950a0c8d8a10a8d1a Mon Sep 17 00:00:00 2001
From: Ayush Sawant
Date: Fri, 25 Oct 2024 19:11:45 +0530
Subject: [PATCH 196/206] Updated checkPromptType function to handle prompt
 list in completions (#885)

* updated checkPromptType function to handle prompt list in completions

* removed generated test file

* added corresponding unit testcases

* Updated to use less nesting with early returns
---
 completion.go      | 18 ++++++++++-
 completion_test.go | 78 ++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/completion.go b/completion.go
index 84ef2ad26..77ea8c3ab 100644
--- a/completion.go
+++ b/completion.go
@@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
 func checkPromptType(prompt any) bool {
 	_, isString := prompt.(string)
 	_, isStringSlice := prompt.([]string)
-	return isString || isStringSlice
+	if isString || isStringSlice {
+		return true
+	}
+
+	// check if the prompt is a []string hidden under []any
+	slice, isSlice := prompt.([]any)
+	if !isSlice {
+		return false
+	}
+
+	for _, item := range slice {
+		_, itemIsString := item.(string)
+		if !itemIsString {
+			return false
+		}
+	}
+	return true // all items in the slice are strings, so it is a []string
 }
 
 var unsupportedToolsForO1Models = map[ToolType]struct{}{
diff --git a/completion_test.go b/completion_test.go
index 89950bf94..935bbe864 100644
--- a/completion_test.go
+++ b/completion_test.go
@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateCompletion error")
 }
 
+// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
+// where the completion request has a list of prompts with a wrong type.
+func TestMultiplePromptsCompletionsWrong(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+	req := openai.CompletionRequest{
+		MaxTokens: 5,
+		Model:     "ada",
+		Prompt:    []interface{}{"Lorem ipsum", 9},
+	}
+	_, err := client.CreateCompletion(context.Background(), req)
+	if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) {
+		t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err)
+	}
+}
+
+// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
+// where the completion request has a list of prompts.
+func TestMultiplePromptsCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+	req := openai.CompletionRequest{
+		MaxTokens: 5,
+		Model:     "ada",
+		Prompt:    []interface{}{"Lorem ipsum", "Lorem ipsum"},
+	}
+	_, err := client.CreateCompletion(context.Background(), req)
+	checks.NoError(t, err, "CreateCompletion error")
+}
+
 // handleCompletionEndpoint Handles the completion endpoint by the test server.
 func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	var err error
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	if n == 0 {
 		n = 1
 	}
+	// Handle different types of prompts: single string or list of strings
+	prompts := []string{}
+	switch v := completionReq.Prompt.(type) {
+	case string:
+		prompts = append(prompts, v)
+	case []interface{}:
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				prompts = append(prompts, str)
+			}
+		}
+	default:
+		http.Error(w, "Invalid prompt type", http.StatusBadRequest)
+		return
+	}
+
 	for i := 0; i < n; i++ {
-		// generate a random string of length completionReq.Length
-		completionStr := strings.Repeat("a", completionReq.MaxTokens)
-		if completionReq.Echo {
-			completionStr = completionReq.Prompt.(string) + completionStr
+		for _, prompt := range prompts {
+			// Generate a dummy completion string of length completionReq.MaxTokens
+			completionStr := strings.Repeat("a", completionReq.MaxTokens)
+			if completionReq.Echo {
+				completionStr = prompt + completionStr
+			}
+
+			res.Choices = append(res.Choices, openai.CompletionChoice{
+				Text:  completionStr,
+				Index: len(res.Choices),
+			})
 		}
-		res.Choices = append(res.Choices, openai.CompletionChoice{
-			Text:  completionStr,
-			Index: i,
-		})
 	}
-	inputTokens := numTokens(completionReq.Prompt.(string)) * n
-	completionTokens := completionReq.MaxTokens * n
+
+	inputTokens := 0
+	for _, prompt := range prompts {
+		inputTokens += numTokens(prompt)
+	}
+	inputTokens *= n
+	completionTokens := completionReq.MaxTokens * len(prompts) * n
 	res.Usage = openai.Usage{
 		PromptTokens:     inputTokens,
 		CompletionTokens: completionTokens,
		TotalTokens:      inputTokens + completionTokens,
 	}
+
+	// Serialize the response and send it back
 	resBytes, _ = json.Marshal(res)
 	fmt.Fprintln(w,
string(resBytes)) } From d10f1b81995ddce1aacacfa671d79f2784a68ef4 Mon Sep 17 00:00:00 2001 From: genglixia <62233468+Yu0u@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:22:52 +0800 Subject: [PATCH 197/206] add chatcompletion stream delta refusal and logprobs (#882) * add chatcompletion stream refusal and logprobs * fix slice to struct * add integration test * fix lint * fix lint * fix: the object should be pointer --------- Co-authored-by: genglixia --- chat_stream.go | 28 ++++- chat_stream_test.go | 265 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+), 4 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index f43d01834..58b2651c0 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -10,13 +10,33 @@ type ChatCompletionStreamChoiceDelta struct { Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` +} + +type ChatCompletionStreamChoiceLogprobs struct { + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + +type ChatCompletionTokenLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes,omitempty"` + Logprob float64 `json:"logprob,omitempty"` + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +type ChatCompletionTokenLogprobTopLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes"` + Logprob float64 `json:"logprob"` } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` - ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type PromptFilterResult struct { diff --git a/chat_stream_test.go b/chat_stream_test.go index 2e7c99b45..14684146c 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -358,6 +358,271 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithRefusal(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) 
+ + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: " World", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: 
{"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{}, + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: "Hello", + Logprob: -0.000020458236, + Bytes: []int64{72, 101, 108, 108, 111}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: " World", + 
Logprob: -0.00055303273, + Bytes: []int64{32, 87, 111, 114, 108, 100}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { wantCode := "429" wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + From f5e6e0e4fed1284bafa4805f6487e5b5f8a4ccd1 Mon Sep 17 00:00:00 2001 From: Matt Davis Date: Fri, 8 Nov 2024 08:53:02 -0500 Subject: [PATCH 198/206] Added Vector Store File List properties that allow for pagination (#891) --- vector_store.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vector_store.go b/vector_store.go index 5c364362a..682bb1cf9 100644 --- a/vector_store.go +++ b/vector_store.go @@ -83,6 +83,9 @@ type VectorStoreFileRequest struct { type VectorStoreFilesList struct { VectorStoreFiles []VectorStoreFile `json:"data"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` httpHeader } From 6d066bb12dfbaa3cefa83f204c431fb0d0ef02fa Mon Sep 17 00:00:00 2001 From: Denny Depok <61371551+kodernubie@users.noreply.github.com> Date: Fri, 8 Nov 2024 20:54:27 +0700 Subject: [PATCH 199/206] Support Attachments in MessageRequest (#890) * add attachments in MessageRequest * Move tools const to message * remove const, just use assistanttool const --- messages.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index eefc29a36..902363938 100644 --- a/messages.go +++ b/messages.go @@ -52,10 +52,11 @@ type ImageFile struct { } type MessageRequest struct { - Role string `json:"role"` - Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility - Metadata map[string]any `json:"metadata,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` } type MessageFile struct { From b3ece4d32e9416105bc2427b735448e82abd448b Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Wed, 20 Nov 2024 02:07:10 +0530 Subject: [PATCH 200/206] Updated client_test to solve lint error (#900) * updated client_test to solve lint error * modified golangci yml to solve linter issues * minor change --- .golangci.yml | 6 +++--- client_test.go | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 58fab4a20..724cb7375 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,7 +57,7 
@@ linters-settings: # Default: true skipRecvDeref: false - gomnd: + mnd: # List of function patterns to exclude from analysis. # Values always ignored: `time.Date` # Default: [] @@ -167,7 +167,7 @@ linters: - durationcheck # check for two durations multiplied together - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds + # Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers @@ -180,7 +180,6 @@ linters: - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomnd # An analyzer to detect magic numbers. - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with f at the end @@ -188,6 +187,7 @@ linters: - lll # Reports long lines - makezero # Finds slice declarations with non-zero initial length # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. 
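For readers mirroring this lint migration in their own projects, the change boils down to renaming the settings key and the linter entry. A minimal sketch of the equivalent .golangci.yml fragment (an assumption-laden illustration, not this repository's actual config: it presumes golangci-lint v1.58+, where gomnd was renamed to mnd and execinquery was deprecated upstream; the ignored-functions value is a made-up example):

linters-settings:
  mnd:                  # formerly `gomnd:`; the settings themselves carry over unchanged
    ignored-functions:
      - os.Chmod        # hypothetical entry, shown only to illustrate the key

linters:
  enable:
    - mnd               # replaces the old `gomnd` entry
    # `execinquery` dropped here: deprecated since golangci-lint v1.58
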
diff --git a/client_test.go b/client_test.go index 354a6b3f5..2ed82f13c 100644 --- a/client_test.go +++ b/client_test.go @@ -513,8 +513,14 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { } defer func() { if r := recover(); r != nil { - if r.(string) != tt.wantPanic { - t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + // Check if the panic message matches the expected panic message + if rStr, ok := r.(string); ok { + if rStr != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic) + } + } else { + // If the panic is not a string, log it + t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r) } } }() From 168761616567a1cf2645c98f6f19329877f0beaa Mon Sep 17 00:00:00 2001 From: LinYushen Date: Thu, 21 Nov 2024 04:26:10 +0800 Subject: [PATCH 201/206] o1 model support stream (#904) --- chat_stream_test.go | 21 --------------------- completion.go | 7 ------- 2 files changed, 28 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 14684146c..28a9acf67 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,27 +36,6 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } -func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { - config := openai.DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1/chat/completions" - client := openai.NewClientWithConfig(config) - ctx := context.Background() - - req := openai.ChatCompletionRequest{ - Model: openai.O1Preview, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } - _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { - t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) - } -} - func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/completion.go b/completion.go index 77ea8c3ab..9e3073694 100644 --- a/completion.go +++ b/completion.go @@ -15,7 +15,6 @@ var ( var ( ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll - ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll @@ -199,12 +198,6 @@ func validateRequestForO1Models(request ChatCompletionRequest) error { return ErrO1MaxTokensDeprecated } - // Beta Limitations - // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations - // Streaming: not supported - if request.Stream { - return ErrO1BetaLimitationsStreaming - } // Logprobs: not supported. 
if request.LogProbs { return ErrO1BetaLimitationsLogprobs From 74ed75f291f8f55d1104a541090d46c021169115 Mon Sep 17 00:00:00 2001 From: nagar-ajay Date: Thu, 21 Nov 2024 02:09:44 +0530 Subject: [PATCH 202/206] Make user field optional in embedding request (#899) * make user optional in embedding request * fix unit test --- batch_test.go | 2 +- embeddings.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/batch_test.go b/batch_test.go index 4b2261e0e..f4714f4eb 100644 --- a/batch_test.go +++ b/batch_test.go @@ -211,7 +211,7 @@ func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { Input: []string{"Hello", "World"}, }, }, - }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/embeddings.go b/embeddings.go index 74eb8aa57..4a0e682da 100644 --- a/embeddings.go +++ b/embeddings.go @@ -155,7 +155,7 @@ const ( type EmbeddingRequest struct { Input any `json:"input"` Model EmbeddingModel `json:"model"` - User string `json:"user"` + User string `json:"user,omitempty"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. 
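The effect of the omitempty change is exactly what the updated batch test above asserts: marshaling an EmbeddingRequest with an unset User no longer emits a "user" key. A minimal illustrative sketch (assuming the module is imported as openai from github.com/sashabaranov/go-openai):

package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	req := openai.EmbeddingRequest{
		Input: []string{"Hello", "World"},
		Model: openai.AdaEmbeddingV2, // "text-embedding-ada-002"
	}
	b, _ := json.Marshal(req)
	// Before this patch: {"input":["Hello","World"],"model":"text-embedding-ada-002","user":""}
	// After this patch:  {"input":["Hello","World"],"model":"text-embedding-ada-002"}
	fmt.Println(string(b))
}
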
From 21fa42c18dbafef43977ab73c403eef6d694b14a Mon Sep 17 00:00:00 2001
From: Liu Shuang
Date: Sat, 30 Nov 2024 17:39:47 +0800
Subject: [PATCH 203/206] feat: add gpt-4o-2024-11-20 model (#905)

---
 completion.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/completion.go b/completion.go
index 9e3073694..f11566081 100644
--- a/completion.go
+++ b/completion.go
@@ -37,6 +37,7 @@ const (
 	GPT4o             = "gpt-4o"
 	GPT4o20240513     = "gpt-4o-2024-05-13"
 	GPT4o20240806     = "gpt-4o-2024-08-06"
+	GPT4o20241120     = "gpt-4o-2024-11-20"
 	GPT4oLatest       = "chatgpt-4o-latest"
 	GPT4oMini         = "gpt-4o-mini"
 	GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
@@ -119,6 +120,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		GPT4o:             true,
 		GPT4o20240513:     true,
 		GPT4o20240806:     true,
+		GPT4o20241120:     true,
 		GPT4oLatest:       true,
 		GPT4oMini:         true,
 		GPT4oMini20240718: true,

From c203ca001fecd40210cfcf9923ab69235c92e321 Mon Sep 17 00:00:00 2001
From: Qiying Wang <781345688@qq.com>
Date: Sat, 30 Nov 2024 18:29:05 +0800
Subject: [PATCH 204/206] feat: add RecvRaw (#896)

---
 stream_reader.go      | 39 ++++++++++++++++++++++-----------------
 stream_reader_test.go | 13 +++++++++++++
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/stream_reader.go b/stream_reader.go
index 4210a1948..ecfa26807 100644
--- a/stream_reader.go
+++ b/stream_reader.go
@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {
-	if stream.isFinished {
-		err = io.EOF
+	rawLine, err := stream.RecvRaw()
+	if err != nil {
 		return
 	}
 
-	response, err = stream.processLines()
-	return
+	err = stream.unmarshaler.Unmarshal(rawLine, &response)
+	if err != nil {
+		return
+	}
+	return response, nil
+}
+
+func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
+	if stream.isFinished {
+		return nil, io.EOF
+	}
+
+	return stream.processLines()
 }
 
 //nolint:gocognit
-func (stream *streamReader[T]) processLines() (T, error) {
+func (stream *streamReader[T]) processLines() ([]byte, error) {
 	var (
 		emptyMessagesCount uint
 		hasErrorPrefix     bool
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		if readErr != nil || hasErrorPrefix {
 			respErr := stream.unmarshalError()
 			if respErr != nil {
-				return *new(T), fmt.Errorf("error, %w", respErr.Error)
+				return nil, fmt.Errorf("error, %w", respErr.Error)
 			}
-			return *new(T), readErr
+			return nil, readErr
 		}
 
 		noSpaceLine := bytes.TrimSpace(rawLine)
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
 			}
 			writeErr := stream.errAccumulator.Write(noSpaceLine)
 			if writeErr != nil {
-				return *new(T), writeErr
+				return nil, writeErr
 			}
 			emptyMessagesCount++
 			if emptyMessagesCount > stream.emptyMessagesLimit {
-				return *new(T), ErrTooManyEmptyStreamMessages
+				return nil, ErrTooManyEmptyStreamMessages
 			}
 
 			continue
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		}
 
 		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
 		if string(noPrefixLine) == "[DONE]" {
 			stream.isFinished = true
-			return *new(T), io.EOF
-		}
-
-		var response T
-		unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
-		if unmarshalErr != nil {
-			return *new(T), unmarshalErr
+			return nil, io.EOF
 		}
-		return response, nil
+		return noPrefixLine, nil
 	}
 }
diff --git a/stream_reader_test.go b/stream_reader_test.go
index cd6e46eff..449a14b43 100644
--- a/stream_reader_test.go
+++ b/stream_reader_test.go
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
 	_, err := stream.Recv()
 	checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
 }
+
+func TestStreamReaderRecvRaw(t *testing.T) {
+	stream := &streamReader[ChatCompletionStreamResponse]{
+		reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
+	}
+	rawLine, err := stream.RecvRaw()
+	if err != nil {
+		t.Fatalf("Did not return raw line: %v", err)
+	}
+	if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
+		t.Fatalf("Did not return raw line: %v", string(rawLine))
+	}
+}

From af5355f5b1a7701f891109e8a17b7b245ac5363b Mon Sep 17 00:00:00 2001
From: Tim Misiak
Date: Sun, 8 Dec 2024 05:12:05 -0800
Subject: [PATCH 205/206] Fix ID field to be optional (#911)

The ID field is not always present for streaming responses. Without
omitempty, the entire ToolCall struct will be missing.
---
 chat.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chat.go b/chat.go
index 2b13f8dd7..fcaf79cf7 100644
--- a/chat.go
+++ b/chat.go
@@ -179,7 +179,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 type ToolCall struct {
 	// Index is not nil only in chat completion chunk object
 	Index    *int         `json:"index,omitempty"`
-	ID       string       `json:"id"`
+	ID       string       `json:"id,omitempty"`
 	Type     ToolType     `json:"type"`
 	Function FunctionCall `json:"function"`
 }

From 56a9acf86fc3ce0e9030feafa346d64bade94027 Mon Sep 17 00:00:00 2001
From: Alex Baranov <677093+sashabaranov@users.noreply.github.com>
Date: Sun, 8 Dec 2024 13:16:48 +0000
Subject: [PATCH 206/206] Ignore test.mp3 (#913)

---
 .gitignore | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 99b40bf17..b0ac1605c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,7 @@
 # Auth token for tests
 .openai-token
 
-.idea
\ No newline at end of file
+.idea
+
+# Generated by tests
+test.mp3
\ No newline at end of file
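As a usage note for the RecvRaw addition in patch 204: because ChatCompletionStream embeds streamReader, the new method is promoted to the public stream type, which lets callers grab each SSE "data:" payload without unmarshaling, e.g. to proxy chunks verbatim. A rough sketch under that assumption (not an official example from this repository):

package main

import (
	"context"
	"errors"
	"io"
	"log"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT4oMini,
		Stream:   true,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	for {
		raw, err := stream.RecvRaw() // one SSE payload with the "data: " prefix already stripped
		if errors.Is(err, io.EOF) {
			break // the server sent "data: [DONE]"
		}
		if err != nil {
			log.Fatal(err)
		}
		os.Stdout.Write(append(raw, '\n')) // forward the raw JSON chunk verbatim
	}
}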