Skip to content

Commit

Permalink
Refine the final cost stat.
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Jan 11, 2024
1 parent 85b7fac commit 004f371
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 19 deletions.
28 changes: 22 additions & 6 deletions backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ var robots []*Robot
var asrService ASRService
var ttsService TTSService

// ASRResult is the outcome of a single speech-recognition request.
type ASRResult struct {
	// Text is the transcription produced by the ASR service.
	Text string

	// Duration is the length of the audio file that was transcribed.
	Duration time.Duration
}

// ASRService converts a recorded audio file into text.
type ASRService interface {
	// RequestASR transcribes the audio file at filepath. language selects the
	// recognition language and prompt is passed to the service as a hint (the
	// caller supplies the previous transcription). On success it returns the
	// recognized text together with the duration of the audio; on failure the
	// result is nil and err is non-nil.
	RequestASR(ctx context.Context, filepath, language, prompt string) (*ASRResult, error)
}

type TTSService interface {
Expand Down Expand Up @@ -114,8 +119,14 @@ type Stage struct {
lastUploadAudio time.Time
// The time for last request ASR result.
lastRequestASR time.Time
// The last request ASR text.
lastRequestAsrText string
// The ASR duration of audio file.
lastAsrDuration time.Duration
// The time for last request Chat result, the first segment.
lastRequestChat time.Time
// The last response text of robot.
lastRobotFirstText string
// The time for last request TTS result, the first segment.
lastRequestTTS time.Time
// The time for last download the TTS result, the first segment.
Expand Down Expand Up @@ -580,12 +591,14 @@ func handleUploadQuestionAudio(ctx context.Context, w http.ResponseWriter, r *ht

// Do ASR, convert to text.
var asrText string
if respText, err := asrService.RequestASR(ctx, inputFile, robot.asrLanguage, stage.previousAsrText); err != nil {
if resp, err := asrService.RequestASR(ctx, inputFile, robot.asrLanguage, stage.previousAsrText); err != nil {
return errors.Wrapf(err, "transcription")
} else {
asrText = strings.TrimSpace(respText)
asrText = strings.TrimSpace(resp.Text)
stage.previousAsrText = asrText
stage.lastRequestASR = time.Now()
stage.lastAsrDuration = resp.Duration
stage.lastRequestAsrText = asrText
}
logger.Tf(ctx, "ASR ok, robot=%v(%v), lang=%v, prompt=<%v>, resp is <%v>",
robot.uuid, robot.label, robot.asrLanguage, stage.previousAsrText, asrText)
Expand Down Expand Up @@ -626,8 +639,9 @@ func handleUploadQuestionAudio(ctx context.Context, w http.ResponseWriter, r *ht

// Do chat, get the response in stream.
chatService := &openaiChatService{
onFirstResponse: func(ctx context.Context) {
onFirstResponse: func(ctx context.Context, text string) {
stage.lastRequestChat = time.Now()
stage.lastRobotFirstText = text
},
}
if err := chatService.RequestChat(ctx, rid, stage, robot); err != nil {
Expand Down Expand Up @@ -753,8 +767,10 @@ func handleDownloadAnswerTTS(ctx context.Context, w http.ResponseWriter, r *http

if !segment.logged && segment.first {
stage.lastDownloadAudio = time.Now()
logger.Tf(ctx, "Report cost total=%.1fs, upload=%.1fs, asr=%.1fs, chat=%.1fs, tts=%.1fs, download=%.1fs",
stage.total(), stage.upload(), stage.asr(), stage.chat(), stage.tts(), stage.download())
speech := float64(stage.lastAsrDuration) / float64(time.Second)
logger.Tf(ctx, "Report cost total=%.1fs, steps=[upload=%.1fs,asr=%.1fs,chat=%.1fs,tts=%.1fs,download=%.1fs], ask=%v, speech=%.1fs, answer=%v",
stage.total(), stage.upload(), stage.asr(), stage.chat(), stage.tts(), stage.download(),
stage.lastRequestAsrText, speech, stage.lastRobotFirstText)
}

// Important trace log. Note that browser may request multiple times, so we only log for the first
Expand Down
76 changes: 68 additions & 8 deletions backend/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"encoding/json"
errors_std "errors"
"fmt"
"github.com/ossrs/go-oryx-lib/errors"
Expand All @@ -12,6 +13,7 @@ import (
"os/exec"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
Expand Down Expand Up @@ -72,7 +74,7 @@ func NewOpenAIASRService() ASRService {
return &openaiASRService{}
}

func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (string, error) {
func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (*ASRResult, error) {
outputFile := fmt.Sprintf("%v.m4a", inputFile)

// Transcode input audio in opus or aac, to aac in m4a format.
Expand All @@ -87,11 +89,17 @@ func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language,
).Run()

if err != nil {
return "", errors.Errorf("Error converting the file")
return nil, errors.Errorf("Error converting the file")
}
logger.Tf(ctx, "Convert audio %v to %v ok", inputFile, outputFile)
}

duration, _, err := ffprobeAudio(ctx, outputFile)
if err != nil {
return nil, errors.Wrapf(err, "ffprobe")
}

// Request ASR.
client := openai.NewClientWithConfig(asrAIConfig)
resp, err := client.CreateTranscription(
ctx,
Expand All @@ -104,14 +112,65 @@ func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language,
},
)
if err != nil {
return "", errors.Wrapf(err, "asr")
return nil, errors.Wrapf(err, "asr")
}

return &ASRResult{Text: resp.Text, Duration: time.Duration(duration * float64(time.Second))}, nil
}

// ffprobeAudio runs the ffprobe command on filename and extracts the
// container-level duration and bitrate from its JSON output.
//
// duration is the value of ffprobe's format.duration field, in seconds.
// bitrate is the value of format.bit_rate (presumably bits per second, as
// reported by ffprobe). The command is started with ctx, so cancelling ctx
// stops the probe.
//
// An error is returned when ffprobe fails to run, when its output is not valid
// JSON, or when either the duration or the bitrate cannot be parsed as a
// number. On error the numeric results are zero.
func ffprobeAudio(ctx context.Context, filename string) (duration float64, bitrate int, err error) {
	args := []string{
		"-show_error", "-show_private_data", "-v", "quiet", "-find_stream_info", "-print_format", "json",
		"-show_format",
		"-i", filename,
	}

	stdout, err := exec.CommandContext(ctx, "ffprobe", args...).Output()
	if err != nil {
		return 0, 0, errors.Wrapf(err, "probe %v", filename)
	}

	// Subset of the "format" object printed by ffprobe -show_format. All numeric
	// values we care about are emitted as JSON strings, so they are parsed below.
	type VLiveFileFormat struct {
		Starttime string `json:"start_time"`
		Duration  string `json:"duration"`
		Bitrate   string `json:"bit_rate"`
		Streams   int32  `json:"nb_streams"`
		Score     int32  `json:"probe_score"`
		HasVideo  bool   `json:"has_video"`
		HasAudio  bool   `json:"has_audio"`
	}

	var probe struct {
		Format VLiveFileFormat `json:"format"`
	}
	if err = json.Unmarshal(stdout, &probe); err != nil {
		// Convert to string so the error shows the ffprobe output as text rather
		// than a list of byte values.
		return 0, 0, errors.Wrapf(err, "parse format %v", string(stdout))
	}

	if duration, err = strconv.ParseFloat(probe.Format.Duration, 64); err != nil {
		return 0, 0, errors.Wrapf(err, "parse duration %v", probe.Format.Duration)
	}

	iv, err := strconv.ParseInt(probe.Format.Bitrate, 10, 64)
	if err != nil {
		return 0, 0, errors.Wrapf(err, "parse bitrate %v", probe.Format.Bitrate)
	}
	bitrate = int(iv)

	logger.Tf(ctx, "FFprobe input=%v, duration=%v, bitrate=%v", filename, duration, bitrate)
	return duration, bitrate, nil
}

// openaiChatService requests chat completions for a stage and robot.
type openaiChatService struct {
	// onFirstResponse, when non-nil, is invoked with the text of the first
	// sentence of the response. Callers use it to record when the first
	// response arrived and what it said.
	onFirstResponse func(ctx context.Context, text string)
}

func (v *openaiChatService) RequestChat(ctx context.Context, rid string, stage *Stage, robot *Robot) error {
Expand Down Expand Up @@ -223,6 +282,7 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
strings.ContainsRune(dc, ',') {
newSentence = true
}
//logger.Tf(ctx, "AI response: text=%v, new=%v", dc, newSentence)
}
}

Expand All @@ -243,7 +303,7 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
if isEnglish(sentence) {
maxWords, minWords := 30, 3
if !firstSentense {
maxWords, minWords = 50, 10
maxWords, minWords = 50, 5
}

if nn := strings.Count(sentence, " "); nn >= maxWords {
Expand All @@ -254,7 +314,7 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
} else {
maxWords, minWords := 50, 3
if !firstSentense {
maxWords, minWords = 100, 10
maxWords, minWords = 100, 5
}

if nn := utf8.RuneCount([]byte(sentence)); nn >= maxWords {
Expand All @@ -277,7 +337,7 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
sentence = fmt.Sprintf("%v %v", robot.prefix, sentence)
}
if v.onFirstResponse != nil {
v.onFirstResponse(ctx)
v.onFirstResponse(ctx, sentence)
}
}

Expand Down
15 changes: 10 additions & 5 deletions backend/tencent.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewTencentASRService() ASRService {
return &tencentASRService{}
}

func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (string, error) {
func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (*ASRResult, error) {
outputFile := fmt.Sprintf("%v.wav", inputFile)

// Transcode input audio in opus or aac, to a WAV file.
Expand All @@ -60,11 +60,16 @@ func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language,
).Run()

if err != nil {
return "", errors.Errorf("Error converting the file")
return nil, errors.Errorf("Error converting the file")
}
logger.Tf(ctx, "Convert audio %v to %v ok", inputFile, outputFile)
}

duration, _, err := ffprobeAudio(ctx, outputFile)
if err != nil {
return nil, errors.Wrapf(err, "ffprobe")
}

// Request ASR.
EngineModelType := "16k_zh"
if language == "en" {
Expand All @@ -77,7 +82,7 @@ func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language,

data, err := ioutil.ReadFile(outputFile)
if err != nil {
return "", errors.Wrapf(err, "read wav file %v", outputFile)
return nil, errors.Wrapf(err, "read wav file %v", outputFile)
}

req := new(asr.FlashRecognitionRequest)
Expand All @@ -93,7 +98,7 @@ func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language,

resp, err := recognizer.Recognize(req, data)
if err != nil {
return "", errors.Wrapf(err, "recognize error")
return nil, errors.Wrapf(err, "recognize error")
}

var sb strings.Builder
Expand All @@ -102,7 +107,7 @@ func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language,
sb.WriteString(" ")
}

return strings.TrimSpace(sb.String()), nil
return &ASRResult{Text: strings.TrimSpace(sb.String()), Duration: time.Duration(duration * float64(time.Second))}, nil
}

type tencentTTSService struct {
Expand Down

0 comments on commit 004f371

Please sign in to comment.