-
Notifications
You must be signed in to change notification settings - Fork 5
/
ai.go
118 lines (100 loc) · 2.8 KB
/
ai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"strings"
openai "github.com/sashabaranov/go-openai"
)
// Package-wide defaults for a new AI session.
// NOTE(review): nothing in this file reads these — confirm they are
// referenced by callers elsewhere in the package.
var (
	defaultModel       = "gpt-4"   // chat model requested by default
	defaultTemperature = 0.1       // sampling temperature
	defaultLang        = "English" // target output language
)
// AI bundles an OpenAI chat client with the per-session settings used
// when requesting completions.
type AI struct {
	model       string  // chat model name, e.g. "gpt-4"
	temperature float64 // sampling temperature; stored at construction
	lang        string  // output language — not read in this file; TODO confirm use elsewhere
	client      *openai.Client
}
// NewAI constructs an AI backed by the OpenAI (or Azure OpenAI) API.
// Credentials come from the environment: OPENAI_API_KEY supplies the
// key, and OPENAI_API_BASE, when set, switches the client into Azure
// mode. In OpenAI mode, if the requested model is not available for the
// provided key, the instance falls back to gpt-3.5-turbo and prints a
// notice.
func NewAI(model string, temperature float64, lang string) *AI {
	var clientConfig openai.ClientConfig
	apiKey := os.Getenv("OPENAI_API_KEY")
	apiBase := os.Getenv("OPENAI_API_BASE")
	if apiBase == "" {
		clientConfig = openai.DefaultConfig(apiKey)
	} else {
		clientConfig = openai.DefaultAzureConfig(apiKey, apiBase)
	}
	client := openai.NewClientWithConfig(clientConfig)
	ai := &AI{model: model, client: client, temperature: temperature, lang: lang}
	// Azure uses deployments, which the models endpoint does not expose,
	// so the configured model is assumed to be valid.
	if clientConfig.APIType == openai.APITypeAzure {
		return ai
	}
	if _, err := client.GetModel(context.Background(), model); err != nil {
		// BUG FIX: the old message hard-coded "gpt-4" even though the model
		// is a parameter, and misspelled the fallback as "gpt-3.5.turbo".
		fmt.Printf("Model %s not available for provided api key, reverting to gpt-3.5-turbo. Sign up for the gpt-4 wait list here: https://openai.com/waitlist/gpt-4-api\n", model)
		ai.model = "gpt-3.5-turbo"
	}
	return ai
}
// Start opens a conversation seeded with the given system and user
// prompts and immediately requests the first assistant reply.
func (ai *AI) Start(system, user string) []openai.ChatCompletionMessage {
	var conversation []openai.ChatCompletionMessage
	conversation = append(conversation, ai.SystemMessage(system), ai.UserMessage(user))
	return ai.Next(conversation, "")
}
// SystemMessage wraps message in a chat message carrying the system role.
func (ai *AI) SystemMessage(message string) openai.ChatCompletionMessage {
	msg := openai.ChatCompletionMessage{Role: openai.ChatMessageRoleSystem}
	msg.Content = message
	return msg
}
// UserMessage wraps message in a chat message carrying the user role.
func (ai *AI) UserMessage(message string) openai.ChatCompletionMessage {
	msg := openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser}
	msg.Content = message
	return msg
}
// AssistantMessage wraps message in a chat message carrying the assistant role.
func (ai *AI) AssistantMessage(message string) openai.ChatCompletionMessage {
	msg := openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant}
	msg.Content = message
	return msg
}
// Next appends prompt (when non-empty) to the conversation as a user
// message, streams the chat completion to stdout as it arrives, and
// returns the conversation extended with the assistant's full reply.
// It panics if the stream cannot be opened.
func (ai *AI) Next(messages []openai.ChatCompletionMessage, prompt string) []openai.ChatCompletionMessage {
	if prompt != "" {
		messages = append(messages, ai.UserMessage(prompt))
	}
	request := openai.ChatCompletionRequest{
		Model:    ai.model,
		Messages: messages,
		// BUG FIX: the configured temperature was stored on AI but never
		// sent with the request, so the API default was always used.
		Temperature: float32(ai.temperature),
	}
	stream, err := ai.client.CreateChatCompletionStream(context.Background(), request)
	if err != nil {
		// NOTE(review): panicking here matches the original behavior, but
		// returning an error would be more idiomatic for callers.
		panic(err)
	}
	defer stream.Close()
	var reply strings.Builder
	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			log.Println("\nStream finished")
			break
		}
		if err != nil {
			fmt.Printf("\nStream error: %v\n", err)
			break
		}
		// Guard against an empty Choices slice to avoid an
		// index-out-of-range panic on malformed chunks.
		if len(response.Choices) == 0 {
			continue
		}
		content := response.Choices[0].Delta.Content
		fmt.Printf("%s", content)
		reply.WriteString(content)
	}
	// Reuse the AssistantMessage helper for consistency with the rest of
	// the file (the original built the struct literal inline).
	return append(messages, ai.AssistantMessage(reply.String()))
}