diff --git a/openai.go b/openai.go index a092e04..a78216f 100644 --- a/openai.go +++ b/openai.go @@ -6,8 +6,10 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" + "mime/multipart" "net/http" + "net/url" "time" ) @@ -39,25 +41,45 @@ func NewSession(apiKey string) *Session { // MakeRequest make HTTP requests and authenticates them with // session's API key. MakeRequest marshals input as the request body, // and unmarshals the response as output. -func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, output interface{}) error { +func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, output any) error { buf, err := json.Marshal(input) if err != nil { return err } - req, err := http.NewRequest("POST", endpoint, bytes.NewReader(buf)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(buf)) if err != nil { return err } - req = req.WithContext(ctx) + return s.sendRequest(req, "application/json", output) +} + +// Upload makes a multi-part form data upload them with +// session's API key. Upload combines the file with the given params +// and unmarshals the response as output. +func (s *Session) Upload(ctx context.Context, endpoint string, file io.Reader, fileExt string, params url.Values, output any) error { + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + go func() { + err := upload(mw, file, fileExt, params) + pw.CloseWithError(err) + }() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, pr) + if err != nil { + return err + } + return s.sendRequest(req, mw.FormDataContentType(), output) +} + +func (s *Session) sendRequest(req *http.Request, contentType string, output any) error { if s.apiKey != "" { req.Header.Set("Authorization", "Bearer "+s.apiKey) } if s.OrganizationID != "" { req.Header.Set("OpenAI-Organization", s.OrganizationID) } - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", contentType) resp, err := s.HTTPClient.Do(req) if err != nil { @@ -66,7 +88,7 @@ func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, outpu defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 400 { - respBody, err := ioutil.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) if err != nil { return err } @@ -78,6 +100,32 @@ func (s *Session) MakeRequest(ctx context.Context, endpoint string, input, outpu return json.NewDecoder(resp.Body).Decode(output) } +func upload(mw *multipart.Writer, file io.Reader, fileExt string, params url.Values) error { + for key := range params { + w, err := mw.CreateFormField(key) + if err != nil { + return fmt.Errorf("error creating %q field: %w", key, err) + } + _, err = fmt.Fprint(w, params.Get(key)) + if err != nil { + return fmt.Errorf("error writing %q field: %w", key, err) + } + } + w, err := mw.CreateFormFile("file", "audio."+fileExt) + if err != nil { + return fmt.Errorf("error creating file: %w", err) + } + _, err = io.Copy(w, file) + if err != nil { + return fmt.Errorf("error copying file: %w", err) + } + err = mw.Close() + if err != nil { + return fmt.Errorf("error closing multipart writer: %w", err) + } + return nil +} + // APIError is returned from API requests if the API // responds with an error. type APIError struct { diff --git a/whisper/whisper.go b/whisper/whisper.go new file mode 100644 index 0000000..81826d3 --- /dev/null +++ b/whisper/whisper.go @@ -0,0 +1,64 @@ +// Package whisper implements a client for OpenAI's Whisper +// audio transcriber. +package whisper + +import ( + "context" + "fmt" + "io" + "net/url" + + "github.com/rakyll/openai-go" +) + +const defaultCreateCompletionsEndpoint = "https://api.openai.com/v1/audio/transcriptions" + +// Client is a client to communicate with Open AI's ChatGPT APIs. +type Client struct { + s *openai.Session + model string + + // CreateCompletionsEndpoint allows overriding the default API endpoint. + // Set this field before using the client. + CreateCompletionEndpoint string +} + +// NewClient creates a new default client that uses the given session +// and defaults to the given model. +func NewClient(session *openai.Session, model string) *Client { + if model == "" { + model = "whisper-1" + } + return &Client{ + s: session, + model: model, + CreateCompletionEndpoint: defaultCreateCompletionsEndpoint, + } +} + +type CreateCompletionParams struct { + Model string + Language string + Audio io.Reader + AudioFormat string // such as "mp3" or "wav", etc. +} + +type CreateCompletionResponse struct { + Text string `json:"text,omitempty"` +} + +func (c *Client) Transcribe(ctx context.Context, p *CreateCompletionParams) (*CreateCompletionResponse, error) { + if p.AudioFormat == "" { + return nil, fmt.Errorf("audio format is required") + } + if p.Model == "" { + p.Model = c.model + } + params := url.Values{} + params.Set("model", p.Model) + if p.Language != "" { + params.Set("language", p.Language) + } + var r CreateCompletionResponse + return &r, c.s.Upload(ctx, c.CreateCompletionEndpoint, p.Audio, p.AudioFormat, params, &r) +}