Skip to content

Commit

Permalink
update txt2txt to make it openai-api compatible
Browse files Browse the repository at this point in the history
oobabooga dropped the streaming api, gotta do this instead now
  • Loading branch information
softmix committed Mar 22, 2024
1 parent 1572b72 commit fda3832
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 131 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module bot
go 1.22.1

require (
github.com/gorilla/websocket v1.5.1
github.com/mattn/go-sqlite3 v1.14.22
github.com/rs/zerolog v1.32.0
github.com/sethvargo/go-retry v0.2.4
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
Expand Down
227 changes: 99 additions & 128 deletions txt2txt.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
package main

import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"os"
"strings"

"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
"maunium.net/go/mautrix/event"
)

type Txt2txt struct {
aiCharacter AICharacter
Histories map[string]History
Histories map[string][]Message
}

type AICharacter struct {
Expand All @@ -25,106 +28,52 @@ type AICharacter struct {
}

type RequestData struct {
UserInput string `json:"user_input"`
History History `json:"history"`
Mode string `json:"mode"`
Character string `json:"character"`
InstructionTemplate string `json:"instruction_template"`
YourName string `json:"your_name"`
Regenerate bool `json:"regenerate"`
Continue bool `json:"_continue"`
StopAtNewline bool `json:"stop_at_newline"`
ChatPromptSize int `json:"chat_prompt_size"`
ChatGenerationAttempts int `json:"chat_generation_attempts"`
ChatInstructCommand string `json:"chat-instruct_command"`
MaxNewTokens int `json:"max_new_tokens"`
DoSample bool `json:"do_sample"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
TypicalP float64 `json:"typical_p"`
EpsilonCutoff int `json:"epsilon_cutoff"`
EtaCutoff int `json:"eta_cutoff"`
Tfs int `json:"tfs"`
TopA int `json:"top_a"`
RepetitionPenalty float64 `json:"repetition_penalty"`
TopK int `json:"top_k"`
MinLength int `json:"min_length"`
NoRepeatNgramSize int `json:"no_repeat_ngram_size"`
NumBeams int `json:"num_beams"`
PenaltyAlpha int `json:"penalty_alpha"`
Preset string `json:"preset"`
LengthPenalty int `json:"length_penalty"`
EarlyStopping bool `json:"early_stopping"`
MirostatMode int `json:"mirostat_mode"`
MirostatTau int `json:"mirostat_tau"`
MirostatEta float64 `json:"mirostat_eta"`
Seed int `json:"seed"`
AddBOSToken bool `json:"add_bos_token"`
TruncationLength int `json:"truncation_length"`
BanEOSToken bool `json:"ban_eos_token"`
SkipSpecialTokens bool `json:"skip_special_tokens"`
StoppingStrings []string `json:"stopping_strings"`
Messages []Message `json:"messages"`
Mode string `json:"mode"`
Character string `json:"character,omitempty"`
Stream bool `json:"stream"`
}

type History struct {
Internal [][]string `json:"internal"`
Visible [][]string `json:"visible"`
type IncomingData struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []Resp `json:"choices"`
Usage Usage `json:"usage"`
}

type IncomingData struct {
Event string `json:"event"`
History History `json:"history,omitempty"`
type Resp struct {
Index int `json:"index"`
FinishReason *string `json:"finish_reason"`
Delta Message `json:"delta"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

func dataForPrompt(username, user_input string, history History) RequestData {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

func dataForPrompt(username, user_input string, history []Message) RequestData {
return RequestData{
UserInput: user_input,
History: history,

Mode: "chat",
Character: Bot.txt2txt.aiCharacter.name,
InstructionTemplate: "None",
YourName: username,

Regenerate: false,
Continue: false,
StopAtNewline: false,
ChatPromptSize: 2048,
ChatGenerationAttempts: 1,
ChatInstructCommand: "",

MaxNewTokens: 500,
DoSample: true,
Temperature: 0.98,
TopP: 0.37,
TypicalP: 0.19,
EpsilonCutoff: 0,
EtaCutoff: 0,
Tfs: 1,
TopA: 0,
RepetitionPenalty: 1.18,
TopK: 100,
MinLength: 0,
NoRepeatNgramSize: 0,
NumBeams: 1,
PenaltyAlpha: 0,
Preset: "None",
LengthPenalty: 1,
EarlyStopping: false,
MirostatMode: 0,
MirostatTau: 5,
MirostatEta: 0.1,
Seed: -1,
AddBOSToken: true,
TruncationLength: 2048,
BanEOSToken: false,
SkipSpecialTokens: true,
StoppingStrings: []string{"END_OF_DIALOG"},
Messages: append(history, Message{
Role: "user",
Content: user_input,
}),
Mode: "chat",
Stream: true,
//Character: Bot.txt2txt.aiCharacter.name,
}
}

func NewTxt2txt() *Txt2txt {
instructions_body, err := ioutil.ReadFile("prompts_instructions.md")
instructions_body, err := os.ReadFile("prompts_instructions.md")
if err != nil {
log.Fatal().Msg("Couldn't read prompts_instructions.md")
}
Expand All @@ -134,7 +83,7 @@ func NewTxt2txt() *Txt2txt {
name: "TavernAI-Gray", // TODO
instructions: string(instructions_body),
},
Histories: make(map[string]History),
Histories: make(map[string][]Message),
}
}

Expand All @@ -156,7 +105,7 @@ func (b *Txt2txt) LoadHistories() error {
data, err := ioutil.ReadFile(Bot.configuration.Txt2TxtHistoryFile)
if err != nil {
if os.IsNotExist(err) {
b.Histories = map[string]History{}
b.Histories = map[string][]Message{}
b.SaveHistories()
return nil
}
Expand All @@ -171,11 +120,8 @@ func (b *Txt2txt) LoadHistories() error {

func (b *Txt2txt) GetPredictionForPrompt(event *event.Event, prompt string) (string, error) {
history := b.Histories[string(event.RoomID)]
if len(history.Visible) == 0 {
history.Visible = [][]string{}
}
if len(history.Internal) == 0 {
history.Internal = [][]string{}
if len(history) == 0 {
history = []Message{}
}

username, err := Bot.client.GetDisplayName(context.Background(), event.Sender)
Expand All @@ -188,57 +134,82 @@ func (b *Txt2txt) GetPredictionForPrompt(event *event.Event, prompt string) (str
return prompt, err
}

if len(reply.Visible[0]) > 0 {
if len(reply) > 0 {
b.Histories[string(event.RoomID)] = reply
b.SaveHistories()
}

log.Debug().Msgf("Bot response: %s", reply)
return reply.Visible[len(reply.Visible)-1][1], err
return reply[len(reply)-1].Content, nil
}

func run(requestData RequestData) (History, error) {
conn, _, err := websocket.DefaultDialer.Dial(Bot.configuration.Txt2TxtAPIURL, nil)
func run(requestData RequestData) ([]Message, error) {
// Marshal the request data to JSON
requestDataBytes, err := json.Marshal(requestData)
if err != nil {
return requestData.History, err
return requestData.Messages, err
}
defer conn.Close()

messageBytes, err := json.Marshal(requestData)
// Create a new request
req, err := http.NewRequest("POST", Bot.configuration.Txt2TxtAPIURL, bytes.NewBuffer(requestDataBytes))
if err != nil {
return requestData.History, err
return requestData.Messages, err
}
req.Header.Set("Content-Type", "application/json")

log.Debug().Msgf("Sending: %s", messageBytes)
err = conn.WriteMessage(websocket.TextMessage, messageBytes)
// Execute the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return requestData.History, err
return requestData.Messages, err
}
defer resp.Body.Close()

var incomingData IncomingData
var result History
curLen := 0
// Ensure we only accept a 200 OK response indicating that the SSE stream is established
if resp.StatusCode != http.StatusOK {
return requestData.Messages, fmt.Errorf("received non-200 status code: %d", resp)

Check failure on line 170 in txt2txt.go

View workflow job for this annotation

GitHub Actions / build

fmt.Errorf format %d has arg resp of wrong type *net/http.Response
}

// Create a buffered reader for the response body to read line by line
reader := bufio.NewReader(resp.Body)
var result []Message

var currentMessageContent string
processLoop:
for {
_, message, err := conn.ReadMessage()
if err != nil {
return requestData.History, err
}
log.Debug().Msgf("Received: %s", message)

err = json.Unmarshal(message, &incomingData)
line, err := reader.ReadBytes('\n')
if err != nil {
return requestData.History, err
if err.Error() == "unexpected EOF" {
log.Error().Err(err).Msg("Unexpected EOF")
break
}
log.Error().Err(err).Msg("Error reading line")
return requestData.Messages, err
}

switch incomingData.Event {
case "text_stream":
curMessage := incomingData.History.Visible[len(incomingData.History.Visible)-1][1][curLen:]
curLen += len(curMessage)
fmt.Print(curMessage)
result = incomingData.History
case "stream_end":
break processLoop
// Process only lines starting with "data: ", which contains the actual message
if strings.HasPrefix(string(line), "data: ") {
var incomingData IncomingData
dataBytes := line[5:] // Remove "data: " prefix and trim the line
dataBytes = bytes.TrimSpace(dataBytes)

err = json.Unmarshal(dataBytes, &incomingData)
if err != nil {
return requestData.Messages, err
}

currentMessageContent += incomingData.Choices[0].Delta.Content

fmt.Print(incomingData.Choices[0].Delta.Content)

if incomingData.Choices[0].FinishReason != nil {
result = append(requestData.Messages, Message{
Role: incomingData.Choices[0].Delta.Role,
Content: currentMessageContent,
})
currentMessageContent = ""
break processLoop
}
}
}
return result, nil
Expand Down

0 comments on commit fda3832

Please sign in to comment.