diff --git a/image/create.go b/image/create.go index 8b882a4..488598c 100644 --- a/image/create.go +++ b/image/create.go @@ -9,6 +9,7 @@ type CreateParams struct { Size string `json:"size,omitempty"` Format string `json:"response_format,omitempty"` User string `json:"user,omitempty"` + Model string `json:"model,omitempty"` } type CreateResponse struct { @@ -17,6 +18,9 @@ type CreateResponse struct { } func (c *Client) Create(ctx context.Context, p *CreateParams) (*CreateResponse, error) { + if p.Model == "" { + p.Model = c.model + } var r CreateResponse if err := c.s.MakeRequest(ctx, c.CreateEndpoint, p, &r); err != nil { return nil, err diff --git a/image/image.go b/image/image.go index 002ddd5..c30d499 100644 --- a/image/image.go +++ b/image/image.go @@ -12,12 +12,16 @@ import ( ) const ( + DallE2 = "dall-e-2" + DallE3 = "dall-e-3" defaultCreateEndpoint = "https://api.openai.com/v1/images/generations" + defaultModel = DallE2 ) // Client is a client to communicate with Open AI's images API. type Client struct { - s *openai.Session + s *openai.Session + model string // CreateEndpoint allows overriding the default // for the image generation API endpoint. @@ -25,9 +29,14 @@ type Client struct { CreateEndpoint string } -func NewClient(session *openai.Session) *Client { +func NewClient(session *openai.Session, model ...string) *Client { + m := defaultModel + if len(model) > 0 { + m = model[0] + } return &Client{ s: session, + model: m, CreateEndpoint: defaultCreateEndpoint, } }