diff --git a/code/main.go b/code/main.go index 259d84a4..6b3eaaaf 100644 --- a/code/main.go +++ b/code/main.go @@ -27,7 +27,7 @@ func main() { config := initialization.LoadConfig(*cfg) initialization.LoadLarkClient(*config) - gpt := &services.ChatGPT{ApiKey: config.OpenaiApiKey} + gpt := services.NewChatGPT(config.OpenaiApiKey) handlers.InitHandlers(*gpt, *config) eventHandler := dispatcher.NewEventDispatcher( diff --git a/code/services/gpt3.go b/code/services/gpt3.go index cbc85205..18e006bf 100644 --- a/code/services/gpt3.go +++ b/code/services/gpt3.go @@ -5,8 +5,8 @@ import ( "encoding/json" "fmt" "io/ioutil" - "log" "net/http" + "strings" "time" ) @@ -17,26 +17,26 @@ const ( engine = "gpt-3.5-turbo" ) +type Messages struct { + Role string `json:"role"` + Content string `json:"content"` +} + // ChatGPTResponseBody 请求体 type ChatGPTResponseBody struct { ID string `json:"id"` Object string `json:"object"` Created int `json:"created"` Model string `json:"model"` - Choices []ChoiceItem `json:"choices"` + Choices []ChatGPTChoiceItem `json:"choices"` Usage map[string]interface{} `json:"usage"` } -type ChoiceItem struct { +type ChatGPTChoiceItem struct { Message Messages `json:"message"` Index int `json:"index"` FinishReason string `json:"finish_reason"` } -type Messages struct { - Role string `json:"role"` - Content string `json:"content"` -} - // ChatGPTRequestBody 响应体 type ChatGPTRequestBody struct { Model string `json:"model"` @@ -51,58 +51,6 @@ type ChatGPT struct { ApiKey string } -func (gpt ChatGPT) Completions(msg []Messages) (resp Messages, err error) { - requestBody := ChatGPTRequestBody{ - Model: engine, - Messages: msg, - MaxTokens: maxTokens, - Temperature: temperature, - TopP: 1, - FrequencyPenalty: 0, - PresencePenalty: 0, - } - requestData, err := json.Marshal(requestBody) - - if err != nil { - return resp, err - } - log.Printf("request gtp json string : %v", string(requestData)) - req, err := http.NewRequest("POST", BASEURL+"chat/completions", bytes.NewBuffer(requestData)) - if err != nil { - return resp, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+gpt.ApiKey) - client := &http.Client{Timeout: 110 * time.Second} - response, err := client.Do(req) - if err != nil { - return resp, err - } - defer response.Body.Close() - if response.StatusCode/2 != 100 { - return resp, fmt.Errorf("gtp api %s", response.Status) - } - body, err := ioutil.ReadAll(response.Body) - if err != nil { - return resp, err - } - - gptResponseBody := &ChatGPTResponseBody{} - // log.Println(string(body)) - err = json.Unmarshal(body, gptResponseBody) - if err != nil { - return resp, err - } - - resp = gptResponseBody.Choices[0].Message - return resp, nil -} - -func FormatQuestion(question string) string { - return "Answer:" + question -} - type ImageGenerationRequestBody struct { Prompt string `json:"prompt"` N int `json:"n"` @@ -117,22 +65,15 @@ type ImageGenerationResponseBody struct { } `json:"data"` } -func (gpt ChatGPT) GenerateImage(prompt string, size string, - n int) ([]string, error) { - requestBody := ImageGenerationRequestBody{ - Prompt: prompt, - N: n, - Size: size, - ResponseFormat: "b64_json", - } +func (gpt ChatGPT) sendRequest(url, method string, requestBody interface{}, responseBody interface{}) error { requestData, err := json.Marshal(requestBody) if err != nil { - return nil, err + return err } - req, err := http.NewRequest("POST", BASEURL+"images/generations", bytes.NewBuffer(requestData)) + req, err := http.NewRequest(method, url, bytes.NewBuffer(requestData)) if err != nil { - return nil, err + return err } req.Header.Set("Content-Type", "application/json") @@ -140,20 +81,55 @@ func (gpt ChatGPT) GenerateImage(prompt string, size string, client := &http.Client{Timeout: 110 * time.Second} response, err := client.Do(req) if err != nil { - return nil, err + return err } defer response.Body.Close() if response.StatusCode/2 != 100 { - return nil, fmt.Errorf("image generation api %s", - response.Status) + return fmt.Errorf("%s api %s", strings.ToUpper(method), response.Status) } body, err := ioutil.ReadAll(response.Body) if err != nil { - return nil, err + return err + } + + err = json.Unmarshal(body, responseBody) + if err != nil { + return err + } + return nil +} + +func (gpt ChatGPT) Completions(msg []Messages) (resp Messages, err error) { + requestBody := ChatGPTRequestBody{ + Model: engine, + Messages: msg, + MaxTokens: maxTokens, + Temperature: temperature, + TopP: 1, + FrequencyPenalty: 0, + PresencePenalty: 0, + } + + gptResponseBody := &ChatGPTResponseBody{} + err = gpt.sendRequest(BASEURL+"chat/completions", "POST", requestBody, gptResponseBody) + + if err == nil { + resp = gptResponseBody.Choices[0].Message + } + return resp, err +} + +func (gpt ChatGPT) GenerateImage(prompt string, size string, n int) ([]string, error) { + requestBody := ImageGenerationRequestBody{ + Prompt: prompt, + N: n, + Size: size, + ResponseFormat: "b64_json", } imageResponseBody := &ImageGenerationResponseBody{} - err = json.Unmarshal(body, imageResponseBody) + err := gpt.sendRequest(BASEURL+"images/generations", "POST", requestBody, imageResponseBody) + if err != nil { return nil, err } @@ -163,7 +139,6 @@ func (gpt ChatGPT) GenerateImage(prompt string, size string, b64Pool = append(b64Pool, data.Base64Json) } return b64Pool, nil - } func (gpt ChatGPT) GenerateOneImage(prompt string, size string) (string, error) { @@ -173,3 +148,9 @@ func (gpt ChatGPT) GenerateOneImage(prompt string, size string) (string, error) } return b64s[0], nil } + +func NewChatGPT(apiKey string) *ChatGPT { + return &ChatGPT{ + ApiKey: apiKey, + } +} diff --git a/code/services/gpt3_test.go b/code/services/gpt3_test.go index ad9563fe..4133e636 100644 --- a/code/services/gpt3_test.go +++ b/code/services/gpt3_test.go @@ -8,14 +8,16 @@ import ( func TestCompletions(t *testing.T) { config := initialization.LoadConfig("../config.yaml") - msg := []Messages{ + + msgs := []Messages{ {Role: "system", Content: "你是一个专业的翻译官,负责中英文翻译。"}, {Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."}, } + chatGpt := &ChatGPT{ApiKey: config.OpenaiApiKey} - resp, err := chatGpt.Completions(msg) + resp, err := chatGpt.Completions(msgs) if err != nil { - t.Error(err) + t.Errorf("TestCompletions failed with error: %v", err) } fmt.Println(resp.Content, resp.Role) @@ -23,14 +25,17 @@ func TestCompletions(t *testing.T) { func TestGenerateOneImage(t *testing.T) { config := initialization.LoadConfig("../config.yaml") + gpt := ChatGPT{ApiKey: config.OpenaiApiKey} prompt := "a red apple" size := "256x256" + imageURL, err := gpt.GenerateOneImage(prompt, size) if err != nil { - t.Fatalf("GenerateImage failed with error: %v", err) + t.Errorf("TestGenerateOneImage failed with error: %v", err) } + if imageURL == "" { - t.Fatalf("GenerateImage returned empty imageURL") + t.Errorf("TestGenerateOneImage returned empty imageURL") } } diff --git a/code/services/msgCache.go b/code/services/msgCache.go index 997d18c6..21bc976f 100644 --- a/code/services/msgCache.go +++ b/code/services/msgCache.go @@ -8,15 +8,17 @@ import ( type MsgService struct { cache *cache.Cache } +type MsgCacheInterface interface { + IfProcessed(msgId string) bool + TagProcessed(msgId string) + Clear(userId string) bool +} var msgService *MsgService func (u MsgService) IfProcessed(msgId string) bool { - get, b := u.cache.Get(msgId) - if !b { - return false - } - return get.(bool) + _, found := u.cache.Get(msgId) + return found } func (u MsgService) TagProcessed(msgId string) { u.cache.Set(msgId, true, time.Minute*30) @@ -27,11 +29,6 @@ func (u MsgService) Clear(userId string) bool { return true } -type MsgCacheInterface interface { - IfProcessed(msg string) bool - TagProcessed(msg string) -} - func GetMsgCache() MsgCacheInterface { if msgService == nil { msgService = &MsgService{cache: cache.New(30*time.Minute, 30*time.Minute)} diff --git a/code/services/sessionCache.go b/code/services/sessionCache.go index e059cc40..054ecb74 100644 --- a/code/services/sessionCache.go +++ b/code/services/sessionCache.go @@ -8,40 +8,46 @@ import ( ) type SessionMode string - -var ( - ModePicCreate SessionMode = "pic_create" - ModePicVary SessionMode = "pic_vary" - ModeGPT SessionMode = "gpt" -) - type SessionService struct { cache *cache.Cache } - +type PicSetting struct { + resolution Resolution +} type Resolution string +type SessionMeta struct { + Mode SessionMode `json:"mode"` + Msg []Messages `json:"msg,omitempty"` + PicSetting PicSetting `json:"pic_setting,omitempty"` +} + const ( Resolution256 Resolution = "256x256" Resolution512 Resolution = "512x512" Resolution1024 Resolution = "1024x1024" ) +const ( + ModePicCreate SessionMode = "pic_create" + ModePicVary SessionMode = "pic_vary" + ModeGPT SessionMode = "gpt" +) -type PicSetting struct { - resolution Resolution -} - -type SessionMeta struct { - Mode SessionMode - Msg []Messages - PicSetting PicSetting +type SessionServiceCacheInterface interface { + GetMsg(sessionId string) []Messages + SetMsg(sessionId string, msg []Messages) + SetMode(sessionId string, mode SessionMode) + GetMode(sessionId string) SessionMode + SetPicResolution(sessionId string, resolution Resolution) + GetPicResolution(sessionId string) string + Clear(sessionId string) } var sessionServices *SessionService -func (s *SessionService) GetMode(sessionID string) SessionMode { +func (s *SessionService) GetMode(sessionId string) SessionMode { // Get the session mode from the cache. - sessionContext, ok := s.cache.Get(sessionID) + sessionContext, ok := s.cache.Get(sessionId) if !ok { return ModeGPT } @@ -49,19 +55,17 @@ func (s *SessionService) GetMode(sessionID string) SessionMode { return sessionMeta.Mode } -func (s *SessionService) SetMode(sessionID string, mode SessionMode) { - // Update the session mode in the cache. +func (s *SessionService) SetMode(sessionId string, mode SessionMode) { maxCacheTime := time.Hour * 12 - - sessionContext, ok := s.cache.Get(sessionID) + sessionContext, ok := s.cache.Get(sessionId) if !ok { sessionMeta := &SessionMeta{Mode: mode} - s.cache.Set(sessionID, sessionMeta, maxCacheTime) + s.cache.Set(sessionId, sessionMeta, maxCacheTime) return } sessionMeta := sessionContext.(*SessionMeta) sessionMeta.Mode = mode - s.cache.Set(sessionID, sessionMeta, maxCacheTime) + s.cache.Set(sessionId, sessionMeta, maxCacheTime) } func (s *SessionService) GetMsg(sessionId string) (msg []Messages) { @@ -126,19 +130,9 @@ func (s *SessionService) GetPicResolution(sessionId string) string { } -func (s *SessionService) Clear(sessionID string) { +func (s *SessionService) Clear(sessionId string) { // Delete the session context from the cache. - s.cache.Delete(sessionID) -} - -type SessionServiceCacheInterface interface { - GetMsg(sessionId string) []Messages - SetMsg(sessionId string, msg []Messages) - SetMode(sessionId string, mode SessionMode) - GetMode(sessionId string) SessionMode - SetPicResolution(sessionId string, resolution Resolution) - GetPicResolution(sessionId string) string - Clear(sessionId string) + s.cache.Delete(sessionId) } func GetSessionCache() SessionServiceCacheInterface {