Skip to content

Commit

Permalink
Add Whisper client (#2)
Browse files Browse the repository at this point in the history
* Add Whisper client

* Fix merge conflicts + abstract request sending
  • Loading branch information
Marwan Sulaiman authored Mar 3, 2023
1 parent aa9a03b commit e332e66
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
60 changes: 54 additions & 6 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"mime/multipart"
"net/http"
"net/url"
"time"
)

Expand Down Expand Up @@ -39,25 +41,45 @@ func NewSession(apiKey string) *Session {
// MakeRequest make HTTP requests and authenticates them with
// session's API key. MakeRequest marshals input as the request body,
// and unmarshals the response as output.
func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, output interface{}) error {
func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, output any) error {
buf, err := json.Marshal(input)
if err != nil {
return err
}

req, err := http.NewRequest("POST", endpoint, bytes.NewReader(buf))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(buf))
if err != nil {
return err
}
req = req.WithContext(ctx)

return s.sendRequest(req, "application/json", output)
}

// Upload makes a multi-part form data upload them with
// session's API key. Upload combines the file with the given params
// and unmarshals the response as output.
func (s *Session) Upload(ctx context.Context, endpoint string, file io.Reader, fileExt string, params url.Values, output any) error {
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
go func() {
err := upload(mw, file, fileExt, params)
pw.CloseWithError(err)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, pr)
if err != nil {
return err
}
return s.sendRequest(req, mw.FormDataContentType(), output)
}

func (s *Session) sendRequest(req *http.Request, contentType string, output any) error {
if s.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+s.apiKey)
}
if s.OrganizationID != "" {
req.Header.Set("OpenAI-Organization", s.OrganizationID)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", contentType)

resp, err := s.HTTPClient.Do(req)
if err != nil {
Expand All @@ -66,7 +88,7 @@ func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, outpu
defer resp.Body.Close()

if resp.StatusCode < 200 || resp.StatusCode >= 400 {
respBody, err := ioutil.ReadAll(resp.Body)
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
Expand All @@ -78,6 +100,32 @@ func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, outpu
return json.NewDecoder(resp.Body).Decode(output)
}

func upload(mw *multipart.Writer, file io.Reader, fileExt string, params url.Values) error {
for key := range params {
w, err := mw.CreateFormField(key)
if err != nil {
return fmt.Errorf("error creating %q field: %w", key, err)
}
_, err = fmt.Fprint(w, params.Get(key))
if err != nil {
return fmt.Errorf("error writing %q field: %w", key, err)
}
}
w, err := mw.CreateFormFile("file", "audio."+fileExt)
if err != nil {
return fmt.Errorf("error creating file: %w", err)
}
_, err = io.Copy(w, file)
if err != nil {
return fmt.Errorf("error copying file: %w", err)
}
err = mw.Close()
if err != nil {
return fmt.Errorf("error closing multipart writer: %w", err)
}
return nil
}

// APIError is returned from API requests if the API
// responds with an error.
type APIError struct {
Expand Down
64 changes: 64 additions & 0 deletions whisper/whisper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Package whisper implements a client for OpenAI's Whisper
// audio transcriber.
package whisper

import (
"context"
"fmt"
"io"
"net/url"

"github.com/rakyll/openai-go"
)

const defaultCreateCompletionsEndpoint = "https://api.openai.com/v1/audio/transcriptions"

// Client is a client to communicate with Open AI's ChatGPT APIs.
type Client struct {
s *openai.Session
model string

// CreateCompletionsEndpoint allows overriding the default API endpoint.
// Set this field before using the client.
CreateCompletionEndpoint string
}

// NewClient creates a new default client that uses the given session
// and defaults to the given model.
func NewClient(session *openai.Session, model string) *Client {
if model == "" {
model = "whisper-1"
}
return &Client{
s: session,
model: model,
CreateCompletionEndpoint: defaultCreateCompletionsEndpoint,
}
}

type CreateCompletionParams struct {
Model string
Language string
Audio io.Reader
AudioFormat string // such as "mp3" or "wav", etc.
}

type CreateCompletionResponse struct {
Text string `json:"text,omitempty"`
}

func (c *Client) Transcribe(ctx context.Context, p *CreateCompletionParams) (*CreateCompletionResponse, error) {
if p.AudioFormat == "" {
return nil, fmt.Errorf("audio format is required")
}
if p.Model == "" {
p.Model = c.model
}
params := url.Values{}
params.Set("model", p.Model)
if p.Language != "" {
params.Set("language", p.Language)
}
var r CreateCompletionResponse
return &r, c.s.Upload(ctx, c.CreateCompletionEndpoint, p.Audio, p.AudioFormat, params, &r)
}

0 comments on commit e332e66

Please sign in to comment.