-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
262 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package record | ||
|
||
import ( | ||
. "MOSS_backend/models" | ||
"MOSS_backend/service" | ||
. "MOSS_backend/utils" | ||
"context" | ||
"errors" | ||
"fmt" | ||
"github.com/gofiber/websocket/v2" | ||
"go.uber.org/zap" | ||
) | ||
|
||
// InferYocsefAsyncAPI | ||
// @Summary infer without login in websocket | ||
// @Tags Websocket | ||
// @Router /yocsef/inference [get] | ||
// @Param json body InferenceRequest true "json" | ||
// @Success 200 {object} InferenceResponse | ||
func InferYocsefAsyncAPI(c *websocket.Conn) { | ||
var ( | ||
err error | ||
) | ||
|
||
defer func() { | ||
if err != nil { | ||
Logger.Error( | ||
"client websocket return with error", | ||
zap.Error(err), | ||
) | ||
response := InferResponseModel{Status: -1, Output: err.Error()} | ||
var httpError *HttpError | ||
if errors.As(err, &httpError) { | ||
response.StatusCode = httpError.Code | ||
} | ||
_ = c.WriteJSON(response) | ||
} | ||
}() | ||
|
||
procedure := func() error { | ||
|
||
// read body | ||
var body InferenceRequest | ||
if err = c.ReadJSON(&body); err != nil { | ||
return fmt.Errorf("error receive message: %v", err) | ||
} | ||
|
||
if body.Request == "" { | ||
return BadRequest("request is empty") | ||
} | ||
|
||
ctx, cancel := context.WithCancelCause(context.Background()) | ||
defer cancel(errors.New("procedure finished")) | ||
|
||
// listen to interrupt and connection close | ||
go func() { | ||
defer cancel(errors.New("client connection closed or interrupt")) | ||
_, _, err := c.ReadMessage() | ||
if err != nil { | ||
return | ||
} | ||
}() | ||
|
||
record, err := service.InferYocsef( | ||
ctx, | ||
c, | ||
body.Request, | ||
body.Records, | ||
) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
DB.Create(&record) | ||
|
||
return nil | ||
} | ||
|
||
err = procedure() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package service | ||
|
||
import ( | ||
"MOSS_backend/config" | ||
"MOSS_backend/models" | ||
"MOSS_backend/utils" | ||
"bufio" | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"io" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
type InferYocsefRequest struct { | ||
Question string `json:"question,omitempty"` | ||
ChatHistory [][]string `json:"chat_history,omitempty"` | ||
} | ||
|
||
var yocsefHttpClient = &http.Client{} | ||
|
||
func InferYocsef( | ||
ctx context.Context, | ||
w utils.JSONWriter, | ||
prompt string, | ||
records models.RecordModels, | ||
) ( | ||
model *models.DirectRecord, | ||
err error, | ||
) { | ||
if config.Config.YocsefInferenceUrl == "" { | ||
return nil, errors.New("yocsef 推理模型暂不可用") | ||
} | ||
|
||
var chatHistory = make([][]string, len(records)) | ||
for i, record := range records { | ||
chatHistory[i] = []string{record.Request, record.Response} | ||
} | ||
|
||
var request = map[string]any{ | ||
"input": map[string]any{ | ||
"question": prompt, | ||
"chat_history": chatHistory, | ||
}, | ||
} | ||
requestData, err := json.Marshal(request) | ||
if err != nil { | ||
return | ||
} | ||
|
||
// server send event | ||
req, err := http.NewRequest("POST", config.Config.YocsefInferenceUrl, bytes.NewBuffer(requestData)) | ||
if err != nil { | ||
return | ||
} | ||
|
||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", "text/event-stream") | ||
req.Header.Set("Cache-Control", "no-cache") | ||
req.Header.Set("Connection", "keep-alive") | ||
|
||
res, err := yocsefHttpClient.Do(req) | ||
if err != nil { | ||
return | ||
} | ||
defer res.Body.Close() | ||
|
||
var reader = bufio.NewReader(res.Body) | ||
var resultBuilder strings.Builder | ||
var nowOutput string | ||
var detectedOutput string | ||
|
||
for { | ||
line, err := reader.ReadBytes('\n') | ||
if errors.Is(err, io.EOF) { | ||
break | ||
} | ||
if err != nil { | ||
return nil, err | ||
} | ||
if strings.HasPrefix(string(line), "event") { | ||
continue | ||
} | ||
if strings.HasPrefix(string(line), "data") { | ||
line = line[6:] | ||
} | ||
line = bytes.Trim(line, " \n\r") | ||
if len(line) == 0 { | ||
continue | ||
} | ||
|
||
if ctx.Err() != nil { | ||
return nil, ctx.Err() | ||
} | ||
|
||
var response map[string]any | ||
err = json.Unmarshal(line, &response) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var ok bool | ||
nowOutput, ok = response["content"].(string) | ||
if !ok { | ||
continue | ||
} | ||
resultBuilder.WriteString(nowOutput) | ||
nowOutput = resultBuilder.String() | ||
|
||
var endDelimiter = "<|im_end|>" | ||
if strings.Contains(nowOutput, endDelimiter) { | ||
nowOutput = strings.Split(nowOutput, endDelimiter)[0] | ||
break | ||
} | ||
|
||
before, _, found := utils.CutLastAny(nowOutput, ",.?!\n,。?!") | ||
if !found || before == detectedOutput { | ||
continue | ||
} | ||
detectedOutput = before | ||
|
||
err = w.WriteJSON(InferResponseModel{ | ||
Status: 1, | ||
Output: nowOutput, | ||
Stage: "MOSS", | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
if ctx.Err() != nil { | ||
return nil, ctx.Err() | ||
} | ||
if nowOutput != detectedOutput { | ||
_ = w.WriteJSON(InferResponseModel{ | ||
Status: 1, | ||
Output: nowOutput, | ||
Stage: "MOSS", | ||
}) | ||
} | ||
|
||
err = w.WriteJSON(InferResponseModel{ | ||
Status: 0, | ||
Output: nowOutput, | ||
Stage: "MOSS", | ||
}) | ||
|
||
var record = models.DirectRecord{Request: prompt, Response: nowOutput} | ||
return &record, nil | ||
} | ||
|
||
type InferResponseModel struct { | ||
Status int `json:"status"` // 1 for output, 0 for end, -1 for error, -2 for sensitive | ||
StatusCode int `json:"status_code,omitempty"` | ||
Output string `json:"output,omitempty"` | ||
Stage string `json:"stage,omitempty"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters