From 9fd3858716fbe33119692ce5df9f0949c0199d9b Mon Sep 17 00:00:00 2001 From: Ke Chen Date: Fri, 24 May 2024 15:02:04 +0800 Subject: [PATCH] feat: yocsef api --- apis/record/routes.go | 3 + apis/record/yocsef.go | 80 +++++++++++++++++++++ config/config.go | 3 + service/yocsef.go | 160 ++++++++++++++++++++++++++++++++++++++++++ utils/tools/search.go | 2 +- utils/utils.go | 34 ++++----- 6 files changed, 262 insertions(+), 20 deletions(-) create mode 100644 apis/record/yocsef.go create mode 100644 service/yocsef.go diff --git a/apis/record/routes.go b/apis/record/routes.go index 520d417..08e4bbc 100644 --- a/apis/record/routes.go +++ b/apis/record/routes.go @@ -24,4 +24,7 @@ func RegisterRoutes(routes fiber.Router) { routes.Get("/v1/models", OpenAIListModels) routes.Get("/v1/models/:name", OpenAIRetrieveModel) routes.Post("/v1/chat/completions", OpenAICreateChatCompletion) + + // yocsef API + routes.Get("/yocsef/inference", websocket.New(InferYocsefAsyncAPI)) } diff --git a/apis/record/yocsef.go b/apis/record/yocsef.go new file mode 100644 index 0000000..478682b --- /dev/null +++ b/apis/record/yocsef.go @@ -0,0 +1,80 @@ +package record + +import ( + . "MOSS_backend/models" + "MOSS_backend/service" + . "MOSS_backend/utils" + "context" + "errors" + "fmt" + "github.com/gofiber/websocket/v2" + "go.uber.org/zap" +) + +// InferYocsefAsyncAPI +// @Summary infer without login in websocket +// @Tags Websocket +// @Router /yocsef/inference [get] +// @Param json body InferenceRequest true "json" +// @Success 200 {object} InferenceResponse +func InferYocsefAsyncAPI(c *websocket.Conn) { + var ( + err error + ) + + defer func() { + if err != nil { + Logger.Error( + "client websocket return with error", + zap.Error(err), + ) + response := InferResponseModel{Status: -1, Output: err.Error()} + var httpError *HttpError + if errors.As(err, &httpError) { + response.StatusCode = httpError.Code + } + _ = c.WriteJSON(response) + } + }() + + procedure := func() error { + + // read body + var body InferenceRequest + if err = c.ReadJSON(&body); err != nil { + return fmt.Errorf("error receive message: %v", err) + } + + if body.Request == "" { + return BadRequest("request is empty") + } + + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("procedure finished")) + + // listen to interrupt and connection close + go func() { + defer cancel(errors.New("client connection closed or interrupt")) + _, _, err := c.ReadMessage() + if err != nil { + return + } + }() + + record, err := service.InferYocsef( + ctx, + c, + body.Request, + body.Records, + ) + if err != nil { + return err + } + + DB.Create(&record) + + return nil + } + + err = procedure() +} diff --git a/config/config.go b/config/config.go index 6fd0d43..cf339f6 100644 --- a/config/config.go +++ b/config/config.go @@ -65,6 +65,9 @@ var Config struct { DefaultModelID int `env:"DEFAULT_MODEL_ID" envDefault:"1"` NoNeedInviteCodeEmailSuffix []string `env:"NO_NEED_INVITE_CODE_EMAIL_SUFFIX" envSeparator:"," envDefault:"fudan.edu.cn"` + + // yocsef + YocsefInferenceUrl string `env:"YOCSEF_INFERENCE_URL"` } func InitConfig() { diff --git a/service/yocsef.go b/service/yocsef.go new file mode 100644 index 0000000..3ecd17b --- /dev/null +++ b/service/yocsef.go @@ -0,0 +1,160 @@ +package service + +import ( + "MOSS_backend/config" + "MOSS_backend/models" + "MOSS_backend/utils" + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "strings" +) + +type InferYocsefRequest struct { + Question string `json:"question,omitempty"` + ChatHistory [][]string `json:"chat_history,omitempty"` +} + +var yocsefHttpClient = &http.Client{} + +func InferYocsef( + ctx context.Context, + w utils.JSONWriter, + prompt string, + records models.RecordModels, +) ( + model *models.DirectRecord, + err error, +) { + if config.Config.YocsefInferenceUrl == "" { + return nil, errors.New("yocsef 推理模型暂不可用") + } + + var chatHistory = make([][]string, len(records)) + for i, record := range records { + chatHistory[i] = []string{record.Request, record.Response} + } + + var request = map[string]any{ + "input": map[string]any{ + "question": prompt, + "chat_history": chatHistory, + }, + } + requestData, err := json.Marshal(request) + if err != nil { + return + } + + // server send event + req, err := http.NewRequest("POST", config.Config.YocsefInferenceUrl, bytes.NewBuffer(requestData)) + if err != nil { + return + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + res, err := yocsefHttpClient.Do(req) + if err != nil { + return + } + defer res.Body.Close() + + var reader = bufio.NewReader(res.Body) + var resultBuilder strings.Builder + var nowOutput string + var detectedOutput string + + for { + line, err := reader.ReadBytes('\n') + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + if strings.HasPrefix(string(line), "event") { + continue + } + if strings.HasPrefix(string(line), "data") { + line = line[6:] + } + line = bytes.Trim(line, " \n\r") + if len(line) == 0 { + continue + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + var response map[string]any + err = json.Unmarshal(line, &response) + if err != nil { + return nil, err + } + + var ok bool + nowOutput, ok = response["content"].(string) + if !ok { + continue + } + resultBuilder.WriteString(nowOutput) + nowOutput = resultBuilder.String() + + var endDelimiter = "<|im_end|>" + if strings.Contains(nowOutput, endDelimiter) { + nowOutput = strings.Split(nowOutput, endDelimiter)[0] + break + } + + before, _, found := utils.CutLastAny(nowOutput, ",.?!\n,。?!") + if !found || before == detectedOutput { + continue + } + detectedOutput = before + + err = w.WriteJSON(InferResponseModel{ + Status: 1, + Output: nowOutput, + Stage: "MOSS", + }) + if err != nil { + return nil, err + } + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + if nowOutput != detectedOutput { + _ = w.WriteJSON(InferResponseModel{ + Status: 1, + Output: nowOutput, + Stage: "MOSS", + }) + } + + err = w.WriteJSON(InferResponseModel{ + Status: 0, + Output: nowOutput, + Stage: "MOSS", + }) + + var record = models.DirectRecord{Request: prompt, Response: nowOutput} + return &record, nil +} + +type InferResponseModel struct { + Status int `json:"status"` // 1 for output, 0 for end, -1 for error, -2 for sensitive + StatusCode int `json:"status_code,omitempty"` + Output string `json:"output,omitempty"` + Stage string `json:"stage,omitempty"` +} diff --git a/utils/tools/search.go b/utils/tools/search.go index 7cef429..f140f2a 100644 --- a/utils/tools/search.go +++ b/utils/tools/search.go @@ -141,7 +141,7 @@ func (t *searchTask) postprocess() (r *ResultModel) { } tmpAnswer := value.(Map)["summ"].(string) tmpAnswerRune := []rune(clean(tmpAnswer)) - tmpAnswerRune = tmpAnswerRune[:utils.Min(len(tmpAnswerRune), 400)] + tmpAnswerRune = tmpAnswerRune[:min(len(tmpAnswerRune), 400)] tmpAnswer = string(tmpAnswerRune) tmpSample = append(tmpSample, fmt.Sprintf("<|%d|>: %s", t.s.searchResultsIndex, tmpAnswer)) diff --git a/utils/utils.go b/utils/utils.go index 9be4c82..6340450 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,7 +2,6 @@ package utils import ( "github.com/gofiber/fiber/v2" - "golang.org/x/exp/constraints" ) type CanPreprocess interface { @@ -27,23 +26,7 @@ func GetRealIP(c *fiber.Ctx) string { } func StripContent(content string, length int) string { - return string([]rune(content)[:Min(len([]rune(content)), length)]) -} - -func Min[T constraints.Ordered](x, y T) T { - if x < y { - return x - } else { - return y - } -} - -func Max[T constraints.Ordered](x, y T) T { - if x > y { - return x - } else { - return y - } + return string([]rune(content)[:min(len([]rune(content)), length)]) } func CutLastAny(s string, chars string) (before, after string, found bool) { @@ -58,7 +41,7 @@ func CutLastAny(s string, chars string) (before, after string, found bool) { } } if index > 0 { - maxIndex = Max(maxIndex, index) + maxIndex = min(maxIndex, index) } } if maxIndex == -1 { @@ -67,3 +50,16 @@ func CutLastAny(s string, chars string) (before, after string, found bool) { return string(sourceRunes[:maxIndex+1]), string(sourceRunes[maxIndex+1:]), true } } + +type JSONReader interface { + ReadJson(any) error +} + +type JSONWriter interface { + WriteJSON(any) error +} + +type JsonReaderWriter interface { + JSONReader + JSONWriter +}