Skip to content

Commit

Permalink
feat: yocsef api
Browse files Browse the repository at this point in the history
  • Loading branch information
JingYiJun committed May 24, 2024
1 parent 3a2c538 commit 9fd3858
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 20 deletions.
3 changes: 3 additions & 0 deletions apis/record/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ func RegisterRoutes(routes fiber.Router) {
routes.Get("/v1/models", OpenAIListModels)
routes.Get("/v1/models/:name", OpenAIRetrieveModel)
routes.Post("/v1/chat/completions", OpenAICreateChatCompletion)

// yocsef API
routes.Get("/yocsef/inference", websocket.New(InferYocsefAsyncAPI))
}
80 changes: 80 additions & 0 deletions apis/record/yocsef.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package record

import (
. "MOSS_backend/models"
"MOSS_backend/service"
. "MOSS_backend/utils"
"context"
"errors"
"fmt"
"github.com/gofiber/websocket/v2"
"go.uber.org/zap"
)

// InferYocsefAsyncAPI
// @Summary infer without login in websocket
// @Tags Websocket
// @Router /yocsef/inference [get]
// @Param json body InferenceRequest true "json"
// @Success 200 {object} InferenceResponse
func InferYocsefAsyncAPI(c *websocket.Conn) {
var (
err error
)

defer func() {
if err != nil {
Logger.Error(
"client websocket return with error",
zap.Error(err),
)
response := InferResponseModel{Status: -1, Output: err.Error()}
var httpError *HttpError
if errors.As(err, &httpError) {
response.StatusCode = httpError.Code
}
_ = c.WriteJSON(response)
}
}()

procedure := func() error {

// read body
var body InferenceRequest
if err = c.ReadJSON(&body); err != nil {
return fmt.Errorf("error receive message: %v", err)
}

if body.Request == "" {
return BadRequest("request is empty")
}

ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(errors.New("procedure finished"))

// listen to interrupt and connection close
go func() {
defer cancel(errors.New("client connection closed or interrupt"))
_, _, err := c.ReadMessage()
if err != nil {
return
}
}()

record, err := service.InferYocsef(
ctx,
c,
body.Request,
body.Records,
)
if err != nil {
return err
}

DB.Create(&record)

return nil
}

err = procedure()
}
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ var Config struct {

DefaultModelID int `env:"DEFAULT_MODEL_ID" envDefault:"1"`
NoNeedInviteCodeEmailSuffix []string `env:"NO_NEED_INVITE_CODE_EMAIL_SUFFIX" envSeparator:"," envDefault:"fudan.edu.cn"`

// yocsef
YocsefInferenceUrl string `env:"YOCSEF_INFERENCE_URL"`
}

func InitConfig() {
Expand Down
160 changes: 160 additions & 0 deletions service/yocsef.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package service

import (
"MOSS_backend/config"
"MOSS_backend/models"
"MOSS_backend/utils"
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
)

type InferYocsefRequest struct {
Question string `json:"question,omitempty"`
ChatHistory [][]string `json:"chat_history,omitempty"`
}

var yocsefHttpClient = &http.Client{}

func InferYocsef(
ctx context.Context,
w utils.JSONWriter,
prompt string,
records models.RecordModels,
) (
model *models.DirectRecord,
err error,
) {
if config.Config.YocsefInferenceUrl == "" {
return nil, errors.New("yocsef 推理模型暂不可用")
}

var chatHistory = make([][]string, len(records))
for i, record := range records {
chatHistory[i] = []string{record.Request, record.Response}
}

var request = map[string]any{
"input": map[string]any{
"question": prompt,
"chat_history": chatHistory,
},
}
requestData, err := json.Marshal(request)
if err != nil {
return
}

// server send event
req, err := http.NewRequest("POST", config.Config.YocsefInferenceUrl, bytes.NewBuffer(requestData))
if err != nil {
return
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

res, err := yocsefHttpClient.Do(req)
if err != nil {
return
}
defer res.Body.Close()

var reader = bufio.NewReader(res.Body)
var resultBuilder strings.Builder
var nowOutput string
var detectedOutput string

for {
line, err := reader.ReadBytes('\n')
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
if strings.HasPrefix(string(line), "event") {
continue
}
if strings.HasPrefix(string(line), "data") {
line = line[6:]
}
line = bytes.Trim(line, " \n\r")
if len(line) == 0 {
continue
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

var response map[string]any
err = json.Unmarshal(line, &response)
if err != nil {
return nil, err
}

var ok bool
nowOutput, ok = response["content"].(string)
if !ok {
continue
}
resultBuilder.WriteString(nowOutput)
nowOutput = resultBuilder.String()

var endDelimiter = "<|im_end|>"
if strings.Contains(nowOutput, endDelimiter) {
nowOutput = strings.Split(nowOutput, endDelimiter)[0]
break
}

before, _, found := utils.CutLastAny(nowOutput, ",.?!\n,。?!")
if !found || before == detectedOutput {
continue
}
detectedOutput = before

err = w.WriteJSON(InferResponseModel{
Status: 1,
Output: nowOutput,
Stage: "MOSS",
})
if err != nil {
return nil, err
}
}

if ctx.Err() != nil {
return nil, ctx.Err()
}
if nowOutput != detectedOutput {
_ = w.WriteJSON(InferResponseModel{
Status: 1,
Output: nowOutput,
Stage: "MOSS",
})
}

err = w.WriteJSON(InferResponseModel{
Status: 0,
Output: nowOutput,
Stage: "MOSS",
})

var record = models.DirectRecord{Request: prompt, Response: nowOutput}
return &record, nil
}

type InferResponseModel struct {
Status int `json:"status"` // 1 for output, 0 for end, -1 for error, -2 for sensitive
StatusCode int `json:"status_code,omitempty"`
Output string `json:"output,omitempty"`
Stage string `json:"stage,omitempty"`
}
2 changes: 1 addition & 1 deletion utils/tools/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (t *searchTask) postprocess() (r *ResultModel) {
}
tmpAnswer := value.(Map)["summ"].(string)
tmpAnswerRune := []rune(clean(tmpAnswer))
tmpAnswerRune = tmpAnswerRune[:utils.Min(len(tmpAnswerRune), 400)]
tmpAnswerRune = tmpAnswerRune[:min(len(tmpAnswerRune), 400)]
tmpAnswer = string(tmpAnswerRune)
tmpSample = append(tmpSample, fmt.Sprintf("<|%d|>: %s", t.s.searchResultsIndex, tmpAnswer))

Expand Down
34 changes: 15 additions & 19 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package utils

import (
"github.com/gofiber/fiber/v2"
"golang.org/x/exp/constraints"
)

type CanPreprocess interface {
Expand All @@ -27,23 +26,7 @@ func GetRealIP(c *fiber.Ctx) string {
}

func StripContent(content string, length int) string {
return string([]rune(content)[:Min(len([]rune(content)), length)])
}

func Min[T constraints.Ordered](x, y T) T {
if x < y {
return x
} else {
return y
}
}

func Max[T constraints.Ordered](x, y T) T {
if x > y {
return x
} else {
return y
}
return string([]rune(content)[:min(len([]rune(content)), length)])
}

func CutLastAny(s string, chars string) (before, after string, found bool) {
Expand All @@ -58,7 +41,7 @@ func CutLastAny(s string, chars string) (before, after string, found bool) {
}
}
if index > 0 {
maxIndex = Max(maxIndex, index)
maxIndex = min(maxIndex, index)
}
}
if maxIndex == -1 {
Expand All @@ -67,3 +50,16 @@ func CutLastAny(s string, chars string) (before, after string, found bool) {
return string(sourceRunes[:maxIndex+1]), string(sourceRunes[maxIndex+1:]), true
}
}

type JSONReader interface {
ReadJson(any) error
}

type JSONWriter interface {
WriteJSON(any) error
}

type JsonReaderWriter interface {
JSONReader
JSONWriter
}

0 comments on commit 9fd3858

Please sign in to comment.