Skip to content

Commit

Permalink
Follow naming conventions of the audio API (#7)
Browse files Browse the repository at this point in the history
Names of the packages and types should follow the names
used in https://platform.openai.com/docs/api-reference/audio
and be consistent with the naming strategy in other packages.

This is a breaking change but given the low current usage,
we are breaking it now before it's impossible to break.
  • Loading branch information
rakyll authored Mar 8, 2023
1 parent a037372 commit 5bdcb07
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
26 changes: 13 additions & 13 deletions whisper/whisper.go → audio/audio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package whisper implements a client for OpenAI's Whisper
// Package audio implements a client for OpenAI's Whisper
// audio transcriber.
package whisper
package audio

import (
"context"
Expand All @@ -11,16 +11,16 @@ import (
"github.com/rakyll/openai-go"
)

const defaultCreateCompletionsEndpoint = "https://api.openai.com/v1/audio/transcriptions"
const defaultCreateTranscriptionEndpoint = "https://api.openai.com/v1/audio/transcriptions"

// Client is a client to communicate with Open AI's ChatGPT APIs.
type Client struct {
s *openai.Session
model string

// CreateCompletionsEndpoint allows overriding the default API endpoint.
// CreateTranscriptionEndpoint allows overriding the default API endpoint.
// Set this field before using the client.
CreateCompletionEndpoint string
CreateTranscriptionEndpoint string
}

// NewClient creates a new default client that uses the given session
Expand All @@ -30,24 +30,24 @@ func NewClient(session *openai.Session, model string) *Client {
model = "whisper-1"
}
return &Client{
s: session,
model: model,
CreateCompletionEndpoint: defaultCreateCompletionsEndpoint,
s: session,
model: model,
CreateTranscriptionEndpoint: defaultCreateTranscriptionEndpoint,
}
}

type CreateCompletionParams struct {
type CreateTranscriptionParams struct {
Model string
Language string
Audio io.Reader
AudioFormat string // such as "mp3" or "wav", etc.
}

type CreateCompletionResponse struct {
type CreateTranscriptionResponse struct {
Text string `json:"text,omitempty"`
}

func (c *Client) Transcribe(ctx context.Context, p *CreateCompletionParams) (*CreateCompletionResponse, error) {
func (c *Client) CreateTranscription(ctx context.Context, p *CreateTranscriptionParams) (*CreateTranscriptionResponse, error) {
if p.AudioFormat == "" {
return nil, fmt.Errorf("audio format is required")
}
Expand All @@ -59,6 +59,6 @@ func (c *Client) Transcribe(ctx context.Context, p *CreateCompletionParams) (*Cr
if p.Language != "" {
params.Set("language", p.Language)
}
var r CreateCompletionResponse
return &r, c.s.Upload(ctx, c.CreateCompletionEndpoint, p.Audio, p.AudioFormat, params, &r)
var r CreateTranscriptionResponse
return &r, c.s.Upload(ctx, c.CreateTranscriptionEndpoint, p.Audio, p.AudioFormat, params, &r)
}
10 changes: 6 additions & 4 deletions examples/whisper/main.go → examples/audio/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"os"

"github.com/rakyll/openai-go"
"github.com/rakyll/openai-go/whisper"
"github.com/rakyll/openai-go/audio"
)

func main() {
sesh := openai.NewSession(os.Getenv("OPENAI_API_KEY"))
wc := whisper.NewClient(sesh, "")
ctx := context.Background()

s := openai.NewSession(os.Getenv("OPENAI_API_KEY"))
client := audio.NewClient(s, "")
filePath := os.Getenv("AUDIO_FILE_PATH")
if filePath == "" {
log.Fatal("must provide an AUDIO_FILE_PATH env var")
Expand All @@ -21,7 +23,7 @@ func main() {
log.Fatalf("error opening audio file: %v", err)
}
defer f.Close()
resp, err := wc.Transcribe(context.TODO(), &whisper.CreateCompletionParams{
resp, err := client.CreateTranscription(ctx, &audio.CreateTranscriptionParams{
Language: "en",
Audio: f,
AudioFormat: "mp3",
Expand Down
2 changes: 1 addition & 1 deletion examples/image/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ func main() {
if err != nil {
log.Fatalf("ReadAll error: %v", err)
}
log.Println(string(data[1:4]))
_ = data // use data
}
}

0 comments on commit 5bdcb07

Please sign in to comment.