-
Notifications
You must be signed in to change notification settings - Fork 0
/
texts_ai.py
102 lines (80 loc) · 3.27 KB
/
texts_ai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import generate.gpt
import generate.llama
import multiprocessing
import os
import custom_sql
import time
import datetime
class suppress_stdout_stderr(object):
    """Context manager that silences stdout and stderr at the file-descriptor level.

    Descriptors 1 and 2 themselves are redirected to the null device (via
    ``os.dup2``), not just the Python ``sys.stdout``/``sys.stderr`` objects,
    so output written straight to the descriptors is silenced as well.

    Descriptors are allocated in ``__enter__`` and released in ``__exit__``:
    constructing an instance without entering it leaks nothing, and the same
    instance can be entered again (sequentially) after it has exited.
    Exceptions raised inside the ``with`` block are not suppressed.
    """

    def __init__(self):
        # Allocated lazily in __enter__ and cleared again in __exit__.
        self.null_fds = []
        self.save_fds = []

    @staticmethod
    def _flush_std_streams():
        """Flush Python's stdout/stderr buffers (skipping streams that are None)."""
        import sys

        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    def __enter__(self):
        # Anything already buffered in Python must reach the real terminal,
        # not the null device we are about to install.
        self._flush_std_streams()
        # Open a pair of null files.
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Save the actual stdout (1) and stderr (2) file descriptors.
        self.save_fds = [os.dup(1), os.dup(2)]
        # Point stdout and stderr at the null files.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        try:
            # Flush output produced inside the block while it still goes to
            # the null device, so it cannot leak out after restoration.
            self._flush_std_streams()
        finally:
            # Re-assign the real stdout/stderr back to (1) and (2).
            os.dup2(self.save_fds[0], 1)
            os.dup2(self.save_fds[1], 2)
            # Close all file descriptors and forget them so a later
            # __enter__ starts from a clean state.
            for fd in self.null_fds + self.save_fds:
                os.close(fd)
            self.null_fds = []
            self.save_fds = []
def transform_date(date):
    """Convert a ``chat.db`` timestamp to whole seconds since the Unix epoch.

    ``date`` is treated as nanoseconds since 2001-01-01 00:00:00 UTC: the
    2001 epoch offset (in nanoseconds) is added and the sum is floor-divided
    by 10**9.

    Integer floor division keeps the result exact. True division (``/``)
    goes through a float, whose rounding can turn e.g. ``x.999999999`` into
    ``x + 1``.

    Returns:
        int: Unix time in whole seconds.
    """
    ns_per_second = 1_000_000_000
    # 978307200 s is 2001-01-01 00:00:00 UTC expressed as Unix time.
    epoch_2001_ns = 978307200 * ns_per_second
    return int((date + epoch_2001_ns) // ns_per_second)
def get_message(db_file, sql, delay):
    """Block until a new message shows up in the database, then return it.

    Polls ``custom_sql.query(db_file, sql)`` once per second. When the first
    row's timestamp is later than the moment this function was called, waits
    ``delay`` more seconds (so a burst of consecutive texts can arrive),
    queries again, and collects the text of every new row from that row's
    contact.

    Args:
        db_file: database path handed to ``custom_sql.query``.
        sql: query text. Rows are read as ``row[0]`` = contact,
            ``row[2]`` = timestamp (converted with ``transform_date``),
            ``row[3]`` = message text.
            NOTE(review): assumes rows are ordered newest-first; confirm
            against sql/new_messages.sql.
        delay: seconds to wait after the first new row is seen.

    Returns:
        tuple: ``(contact, text)``. ``text`` holds the collected messages
        joined with single spaces, in reverse row order (oldest first under
        the newest-first assumption).
    """
    start_time = int(datetime.datetime.now().timestamp())
    while True:
        messages = custom_sql.query(db_file, sql)
        # Guard against an empty result set before indexing messages[0].
        if not messages or transform_date(messages[0][2]) <= start_time:
            time.sleep(1)
            continue

        contact = messages[0][0]
        # Give the sender time to finish a burst of messages.
        time.sleep(delay)
        messages = custom_sql.query(db_file, sql)

        texts = []
        for message in messages:
            if transform_date(message[2]) <= start_time:
                # Assuming newest-first order, every remaining row is older.
                break
            # Ignore other contacts, and rows whose text column is NULL
            # (presumably attachment-only messages; confirm against the SQL),
            # which would otherwise make " ".join raise TypeError.
            if message[0] == contact and message[3] is not None:
                texts.append(message[3])
        texts.reverse()
        return contact, " ".join(texts)
def dispatch_reply(gen_function, model_path, contact, prompt):
    """Generate a reply to ``prompt`` with the chosen model and send it to ``contact``.

    ``monitor`` runs this in its own ``multiprocessing.Process``. Anything
    the generator writes to stdout/stderr is silenced while the reply is
    produced.

    Args:
        gen_function: callable invoked as ``gen_function(prompt, model_path)``;
            expected to return the reply string.
        model_path: forwarded unchanged to ``gen_function``.
        contact: recipient passed to ``send_message``.
        prompt: text of the received message(s).
    """
    # flush=True: this line must reach the real stdout before the descriptor
    # is pointed at the null device below. With a block-buffered stdout
    # (piped or logged) it could otherwise be flushed into /dev/null.
    print(f"Received from {contact}: {prompt}", flush=True)
    with suppress_stdout_stderr():
        reply = gen_function(prompt, model_path)
    send_message(contact, reply)
def send_message(contact, message):
    """Send ``message`` to ``contact`` by running ``osascript send_message.scpt``.

    Both values are passed to ``osascript`` as separate argv entries, with no
    shell involved. Quotes and shell metacharacters are therefore delivered
    literally: untrusted text cannot inject shell commands, and apostrophes
    do not have to be stripped from the reply.

    ``send_message.scpt`` is resolved relative to the current working
    directory.

    Prints a confirmation on success. If ``osascript`` cannot be started or
    exits non-zero, prints a failure line instead.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["osascript", "send_message.scpt", contact, message],
            check=False,
        )
    except OSError as err:
        # e.g. osascript is not available on this machine.
        print(f"Failed to send to {contact}: {err}")
        return
    if result.returncode != 0:
        print(f"Failed to send to {contact}: osascript exited with status {result.returncode}")
        return
    print(f"Sent to {contact}: {message}")
def monitor(gen_function, model_path, db_file, sql, delay):
    """Loop forever, replying to each new message in a separate worker process.

    Each ``(contact, text)`` pair from ``get_message`` is handed to
    ``dispatch_reply`` inside a fresh ``multiprocessing.Process``, so the
    next message can be picked up while a reply is still being generated.
    There is no ``break`` or ``return``; the loop ends only when an exception
    escapes (e.g. ``KeyboardInterrupt`` from Ctrl-C).
    """
    print("Monitoring... (Press Ctrl-C to exit)")
    while True:
        sender, text = get_message(db_file, sql, delay)
        worker = multiprocessing.Process(
            target=dispatch_reply,
            args=(gen_function, model_path, sender, text),
        )
        worker.start()
def main():
    """Command-line entry point: parse arguments, pick a generator, start monitoring.

    Positional arguments choose the backend (``gpt`` or ``llama``) and the
    model path. ``--delay`` (default 5) is forwarded to ``monitor``. The
    polling query is read from ``sql/new_messages.sql`` (relative to the
    current working directory) and run against ``~/Library/Messages/chat.db``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("model_type", choices=["gpt", "llama"], help="type of model to use")
    cli.add_argument("model_path", type=str, help="path/to/model")
    # NOTE(review): get_message uses this value as the wait after the first
    # new message is seen; routine polling happens every 1 s.
    cli.add_argument("--delay", type=int, default=5, help="delay between checks for new messages")
    options = cli.parse_args()

    # Backend name -> already-imported module; ``.gen`` is looked up only on
    # the selected one.
    backends = {"gpt": generate.gpt, "llama": generate.llama}
    backend = backends.get(options.model_type)
    if backend is None:
        # Not reachable in practice: argparse already enforces ``choices``.
        raise ValueError("Invalid model type: {}".format(options.model_type))
    gen_function = backend.gen

    chat_db = os.path.expanduser("~/Library/Messages/chat.db")
    with open("sql/new_messages.sql", "r") as query_file:
        # Literal backslash-n sequences in the file become real newlines.
        query = query_file.read().replace("\\n", "\n")

    monitor(gen_function, options.model_path, chat_db, query, options.delay)


if __name__ == "__main__":
    main()