Skip to content

Commit

Permalink
Merge pull request #74 from Leizhenpeng/refactor
Browse files Browse the repository at this point in the history
refactor: improve server code quality by reorganizing cache and test code
  • Loading branch information
Leizhenpeng authored Mar 9, 2023
2 parents b35b24c + 6b9783c commit 0ddafce
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 129 deletions.
2 changes: 1 addition & 1 deletion code/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func main() {
config := initialization.LoadConfig(*cfg)
initialization.LoadLarkClient(*config)

gpt := &services.ChatGPT{ApiKey: config.OpenaiApiKey}
gpt := services.NewChatGPT(config.OpenaiApiKey)
handlers.InitHandlers(*gpt, *config)

eventHandler := dispatcher.NewEventDispatcher(
Expand Down
135 changes: 58 additions & 77 deletions code/services/gpt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"time"
)

Expand All @@ -17,26 +17,26 @@ const (
engine = "gpt-3.5-turbo"
)

type Messages struct {
Role string `json:"role"`
Content string `json:"content"`
}

// ChatGPTResponseBody 请求体
type ChatGPTResponseBody struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []ChoiceItem `json:"choices"`
Choices []ChatGPTChoiceItem `json:"choices"`
Usage map[string]interface{} `json:"usage"`
}
type ChoiceItem struct {
// ChatGPTChoiceItem is one completion choice in the chat-completions
// response payload.
type ChatGPTChoiceItem struct {
	Message      Messages `json:"message"`
	Index        int      `json:"index"`
	FinishReason string   `json:"finish_reason"`
}

// Messages is a single chat message (role + content) in the OpenAI
// chat-completions wire format.
type Messages struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatGPTRequestBody 响应体
type ChatGPTRequestBody struct {
Model string `json:"model"`
Expand All @@ -51,58 +51,6 @@ type ChatGPT struct {
ApiKey string
}

// Completions sends msg to the OpenAI chat-completions endpoint and
// returns the first choice's message. A non-2xx status, a malformed
// response, or an empty choices array is returned as an error.
func (gpt ChatGPT) Completions(msg []Messages) (resp Messages, err error) {
	requestBody := ChatGPTRequestBody{
		Model:            engine,
		Messages:         msg,
		MaxTokens:        maxTokens,
		Temperature:      temperature,
		TopP:             1,
		FrequencyPenalty: 0,
		PresencePenalty:  0,
	}
	requestData, err := json.Marshal(requestBody)
	if err != nil {
		return resp, err
	}
	// NOTE(review): this logs the full request payload, including user
	// message content — consider redacting in production.
	log.Printf("request gpt json string : %v", string(requestData))
	req, err := http.NewRequest("POST", BASEURL+"chat/completions", bytes.NewBuffer(requestData))
	if err != nil {
		return resp, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+gpt.ApiKey)
	// Generous timeout: completions can take well over a minute.
	client := &http.Client{Timeout: 110 * time.Second}
	response, err := client.Do(req)
	if err != nil {
		return resp, err
	}
	defer response.Body.Close()
	// Accept any 2xx status; the previous StatusCode/2 != 100 check
	// only allowed 200 and 201. Also fixes the "gtp" typo.
	if response.StatusCode/100 != 2 {
		return resp, fmt.Errorf("gpt api %s", response.Status)
	}
	body, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return resp, err
	}

	gptResponseBody := &ChatGPTResponseBody{}
	err = json.Unmarshal(body, gptResponseBody)
	if err != nil {
		return resp, err
	}

	// Guard against an empty choices array instead of panicking on
	// Choices[0] (can happen with unexpected API payloads).
	if len(gptResponseBody.Choices) == 0 {
		return resp, fmt.Errorf("gpt api returned no choices")
	}
	resp = gptResponseBody.Choices[0].Message
	return resp, nil
}

// FormatQuestion prepends the "Answer:" marker to the user's question
// before it is sent to the model.
func FormatQuestion(question string) string {
	return fmt.Sprintf("Answer:%s", question)
}

type ImageGenerationRequestBody struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Expand All @@ -117,43 +65,71 @@ type ImageGenerationResponseBody struct {
} `json:"data"`
}

func (gpt ChatGPT) GenerateImage(prompt string, size string,
n int) ([]string, error) {
requestBody := ImageGenerationRequestBody{
Prompt: prompt,
N: n,
Size: size,
ResponseFormat: "b64_json",
}
// sendRequest marshals requestBody to JSON, sends it to url with the
// given HTTP method and the bearer API key, and decodes the JSON
// response into responseBody. Non-2xx statuses are returned as errors.
func (gpt ChatGPT) sendRequest(url, method string, requestBody interface{}, responseBody interface{}) error {
	requestData, err := json.Marshal(requestBody)
	if err != nil {
		return err
	}

	req, err := http.NewRequest(method, url, bytes.NewBuffer(requestData))
	if err != nil {
		return err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+gpt.ApiKey)

	// Generous timeout: image generation and completions are slow.
	client := &http.Client{Timeout: 110 * time.Second}
	response, err := client.Do(req)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// Accept any 2xx status; the previous StatusCode/2 != 100 check
	// only allowed 200 and 201 (e.g. it rejected 202/204).
	if response.StatusCode/100 != 2 {
		return fmt.Errorf("%s api %s", strings.ToUpper(method), response.Status)
	}

	body, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return err
	}

	return json.Unmarshal(body, responseBody)
}

// Completions sends the accumulated chat messages to the OpenAI
// chat-completions endpoint and returns the first choice's message.
// A transport/API error or an empty choices array yields an error
// instead of a panic.
func (gpt ChatGPT) Completions(msg []Messages) (resp Messages, err error) {
	requestBody := ChatGPTRequestBody{
		Model:            engine,
		Messages:         msg,
		MaxTokens:        maxTokens,
		Temperature:      temperature,
		TopP:             1,
		FrequencyPenalty: 0,
		PresencePenalty:  0,
	}

	gptResponseBody := &ChatGPTResponseBody{}
	err = gpt.sendRequest(BASEURL+"chat/completions", "POST", requestBody, gptResponseBody)
	if err != nil {
		return resp, err
	}

	// Guard against an empty choices array instead of panicking on
	// Choices[0] (can happen with unexpected API payloads).
	if len(gptResponseBody.Choices) == 0 {
		return resp, fmt.Errorf("gpt api returned no choices")
	}
	return gptResponseBody.Choices[0].Message, nil
}

func (gpt ChatGPT) GenerateImage(prompt string, size string, n int) ([]string, error) {
requestBody := ImageGenerationRequestBody{
Prompt: prompt,
N: n,
Size: size,
ResponseFormat: "b64_json",
}

imageResponseBody := &ImageGenerationResponseBody{}
err = json.Unmarshal(body, imageResponseBody)
err := gpt.sendRequest(BASEURL+"images/generations", "POST", requestBody, imageResponseBody)

if err != nil {
return nil, err
}
Expand All @@ -163,7 +139,6 @@ func (gpt ChatGPT) GenerateImage(prompt string, size string,
b64Pool = append(b64Pool, data.Base64Json)
}
return b64Pool, nil

}

func (gpt ChatGPT) GenerateOneImage(prompt string, size string) (string, error) {
Expand All @@ -173,3 +148,9 @@ func (gpt ChatGPT) GenerateOneImage(prompt string, size string) (string, error)
}
return b64s[0], nil
}

// NewChatGPT builds a ChatGPT client that authenticates with the given
// OpenAI API key.
func NewChatGPT(apiKey string) *ChatGPT {
	gpt := ChatGPT{ApiKey: apiKey}
	return &gpt
}
15 changes: 10 additions & 5 deletions code/services/gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,34 @@ import (

func TestCompletions(t *testing.T) {
config := initialization.LoadConfig("../config.yaml")
msg := []Messages{

msgs := []Messages{
{Role: "system", Content: "你是一个专业的翻译官,负责中英文翻译。"},
{Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."},
}

chatGpt := &ChatGPT{ApiKey: config.OpenaiApiKey}
resp, err := chatGpt.Completions(msg)
resp, err := chatGpt.Completions(msgs)
if err != nil {
t.Error(err)
t.Errorf("TestCompletions failed with error: %v", err)
}

fmt.Println(resp.Content, resp.Role)
}

// TestGenerateOneImage does a live image-generation round-trip against
// the OpenAI API using the key from ../config.yaml (integration test).
func TestGenerateOneImage(t *testing.T) {
	config := initialization.LoadConfig("../config.yaml")

	// Use the constructor for consistency with main.go's refactor.
	gpt := NewChatGPT(config.OpenaiApiKey)
	prompt := "a red apple"
	size := "256x256"

	imageURL, err := gpt.GenerateOneImage(prompt, size)
	if err != nil {
		t.Errorf("TestGenerateOneImage failed with error: %v", err)
	}

	if imageURL == "" {
		t.Errorf("TestGenerateOneImage returned empty imageURL")
	}
}
17 changes: 7 additions & 10 deletions code/services/msgCache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ import (
// MsgService tracks which message IDs have already been processed,
// backed by an in-memory TTL cache.
type MsgService struct {
	cache *cache.Cache
}
// MsgCacheInterface is the consumer-facing contract for the
// processed-message cache.
type MsgCacheInterface interface {
	// IfProcessed reports whether msgId was already handled.
	IfProcessed(msgId string) bool
	// TagProcessed marks msgId as handled.
	TagProcessed(msgId string)
	// Clear drops cached state for userId; presumably always returns
	// true — TODO confirm against the implementation.
	Clear(userId string) bool
}

// msgService is the lazily-initialized singleton returned by GetMsgCache.
var msgService *MsgService

// IfProcessed reports whether msgId has already been handled, i.e. is
// still present in the TTL cache. Mere presence is enough — the stored
// value is never inspected.
func (u MsgService) IfProcessed(msgId string) bool {
	_, found := u.cache.Get(msgId)
	return found
}
func (u MsgService) TagProcessed(msgId string) {
u.cache.Set(msgId, true, time.Minute*30)
Expand All @@ -27,11 +29,6 @@ func (u MsgService) Clear(userId string) bool {
return true
}

type MsgCacheInterface interface {
IfProcessed(msg string) bool
TagProcessed(msg string)
}

func GetMsgCache() MsgCacheInterface {
if msgService == nil {
msgService = &MsgService{cache: cache.New(30*time.Minute, 30*time.Minute)}
Expand Down
Loading

0 comments on commit 0ddafce

Please sign in to comment.