-
Notifications
You must be signed in to change notification settings - Fork 3
/
example.py
71 lines (55 loc) · 2.02 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import pickle
from engram import build_engram, sort_engrams
from transformer import get_transformer, get_generator
# Load GPT Neo
model, tokenizer = get_transformer()
generate = get_generator(model, tokenizer)

# Change as needed; this works with the Shakespeare example dataset in encode.py.
memory_file = "shakespeare.pkl"
speaker_name = "JULIET"
GPT_name = "ROMEO"

print("Welcome to GPT Chat!")
print(f"You are chatting as {speaker_name}, and GPT is chatting as {GPT_name}.")

# Raw lines of the current session, newest last (appended by add_engram).
context = []
# Engram dicts, newest last; optionally preloaded from memory_file.
memories = []
if os.path.exists(memory_file):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load memory files you created yourself (presumably via encode.py).
    with open(memory_file, 'rb') as handle:
        memories = pickle.load(handle)
def add_engram(text, add_context=True):
    """Record ``text`` as a new memory and return the created entry.

    The entry is appended to the module-level ``memories`` list and linked to
    the entry before it through integer indices: the previous entry's
    ``"next"`` is set to the new entry's index, and the new entry's
    ``"previous"`` points back. ``-1`` marks "no neighbour" in either field.

    Args:
        text: The full line to remember, e.g. ``"JULIET: hello"``.
        add_context: When true (the default), also append ``text`` to the
            module-level ``context`` list of recent lines.

    Returns:
        The newly appended memory dict with keys ``"text"``, ``"engram"``,
        ``"next"``, ``"previous"`` and ``"distance"``.
    """
    memory_count = len(memories)
    # Compute the engram before touching any shared state, so a failure here
    # does not leave a dangling "next" link or an orphaned context line.
    # NOTE(review): .cuda() hard-codes a CUDA device for the token ids.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
    engram = {
        "text": text,
        "engram": build_engram(model.forward, input_ids),
        "next": -1,
        "previous": memory_count - 1,  # -1 when this is the first memory
        "distance": 0,
    }
    # Link the previous memory forward to this one. Skip this when there is no
    # previous memory (fresh start without a memory file); the original
    # unconditional memories[-1] raised IndexError in that case.
    if memories:
        memories[-1]["next"] = memory_count
    memories.append(engram)
    if add_context:
        context.append(text)
    return engram
def build_context(now, short_term=10):
    """Assemble the prompt text that precedes the model's next reply.

    All memories except the ``short_term`` most recent ones are ranked against
    ``now`` in three successive ``sort_engrams`` passes (top 600, then top 150
    at depth 2, then top 42 at depth 3). Each retrieved memory is emitted
    together with its partner line:

    - a line not starting with ``speaker_name`` is preceded by the memory
      before it;
    - a line starting with ``speaker_name`` is followed by the memory after it.

    The ``short_term`` most recent lines of ``context`` are appended last.

    Args:
        now: Query engram dict for the current message, as returned by
            ``add_engram``.
        short_term: Number of most recent turns kept out of retrieval and
            appended verbatim instead. ``0`` means "retrieve from all memories,
            append no recent lines".

    Returns:
        The assembled lines, each terminated by a newline (an empty string
        when there is nothing to include).
    """
    # Slice by explicit length rather than a negative index: for
    # short_term == 0, memories[:-0] is empty and context[-0:] is everything.
    # The clamp at 0 matches the negative-slice behaviour when fewer than
    # short_term items exist.
    pool = memories[:max(len(memories) - short_term, 0)]
    recent = context[max(len(context) - short_term, 0):]

    retrieved = []
    if pool:
        retrieved = sort_engrams(now, pool, top_k=600)
        retrieved = sort_engrams(now, retrieved, top_k=150, do_distance=False, depth=2)
        retrieved = sort_engrams(now, retrieved, top_k=42, do_distance=False, depth=3)
        # NOTE(review): presumably sort_engrams returns the best match first,
        # so reversing puts the most relevant memories nearest the prompt end.
        retrieved.reverse()

    lines = []
    for memory in retrieved:
        is_speaker = memory["text"].startswith(speaker_name)
        # "previous"/"next" of -1 mean "no neighbour"; using -1 as an index
        # would wrongly pull in the last memory.
        if not is_speaker and memory["previous"] >= 0:
            lines.append(memories[memory["previous"]]["text"])
        lines.append(memory["text"])
        if is_speaker and memory["next"] >= 0:
            lines.append(memories[memory["next"]]["text"])
    lines.extend(recent)
    return "".join(f"{line}\n" for line in lines)
while True:
    # Let the user input a message. End-of-input (Ctrl-D) or Ctrl-C ends the
    # chat cleanly instead of aborting with a traceback.
    try:
        message = input(speaker_name + ": ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    engram = add_engram(speaker_name + ": " + message)
    # Prompt = retrieved + recent history, ending with the GPT speaker's tag so
    # the model continues as that character.
    text = build_context(engram) + GPT_name + ":"
    # Keep only the first line of the generator's output as the reply.
    reply = generate(text).split("\n")[0]
    print(GPT_name + ":" + reply)
    add_engram(GPT_name + ":" + reply)