-
Notifications
You must be signed in to change notification settings - Fork 1
/
chatgpt.py
152 lines (130 loc) · 6.52 KB
/
chatgpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import os
import time
import traceback
from functools import wraps
from multiprocessing import Pool, TimeoutError
import openai
from openai.error import APIError, RateLimitError, ServiceUnavailableError
from schema import And, Schema, Use
from tqdm import tqdm
openai.api_key = "YOUR_API_KEY"  # NOTE(review): hardcoded placeholder — load from an env var instead; never commit a real key
MODEL = "gpt-4"  # chat model name passed to openai.ChatCompletion.create
BATCH = 4  # BioASQ task 11b test-set batch number; selects the input file below
SUBMIT = 5  # submission run number; used only to name the output folder
MAX_SNIPPET_LEN = 300  # snippets longer than this (in characters) get summarized first
TEMPERATURE = 1  # sampling temperature for every API call
dataset_path = f"data/BioASQ-task11bPhaseB-testset{BATCH}.txt"
# Loaded at import time; each question dict carries "id", "type", "body" and "snippets".
with open(dataset_path, "r", encoding="utf-8") as f:
    QUESTIONS = json.load(f)["questions"]
# Validation/normalization schemas, keyed by BioASQ question type.
# Shared rule: ideal_answer is coerced to str, must be non-empty, and at most 200 words.
_IDEAL_ANSWER = And(Use(str), lambda s: s.strip() != "", lambda s: len(s.split()) <= 200)
SCHEMA = {
    "yesno": Schema({
        "exact_answer": And(
            str,
            Use(str.lower),  # normalize YES/NO casing
            lambda s: s in ('yes', 'no')
        ),
        "ideal_answer": _IDEAL_ANSWER
    }),
    "list": Schema({
        "exact_answer": And(
            list,
            lambda x: 100 >= len(x) > 0,  # 1..100 entries
            # isinstance is checked FIRST so a non-string item fails validation
            # cleanly instead of len() raising TypeError (e.g. on an int item)
            lambda x: all(isinstance(item, str) and 100 >= len(item) > 0 for item in x),
            Use(lambda x: [[i] for i in x])  # convert to list of lists (submission format)
        ),
        "ideal_answer": _IDEAL_ANSWER
    }),
    "summary": Schema({
        "ideal_answer": _IDEAL_ANSWER
    }),
    "factoid": Schema({
        "exact_answer": And(
            list,
            lambda x: 5 >= len(x) > 0,  # 1..5 entries
            Use(lambda x: [[i] for i in x])  # convert to list of lists (submission format)
        ),
        "ideal_answer": _IDEAL_ANSWER
    })
}
# Per-question-type instruction prefixes. get_question_answer sends the prefix
# and the question body as two separate "user" messages, after the snippet
# context. All non-summary types demand a strict JSON reply that is later
# json.loads-ed and validated against SCHEMA.
PROMPT = {
    "yesno": """You can only use JSON format to answer my questions. The format must be {"exact_answer":"", "ideal_answer":""}, where exact_answer should be "yes" or "no", and ideal_answer is a short conversational response starting with yes/no then follow on the explain. You should read the chat history's content before answer the question. The first question is: """,
    "list": """You can only use JSON format to answer my questions. The format must be {"exact_answer":[], "ideal_answer":""}, where exact_answer is a list of precise key entities to answer the question, and ideal_answer is a short conversational response containing an explanation. You should read the chat history's content before answer the question. The first question is: """,
    "summary": """Reply to the answer clearly and easily in less than 3 sentences. You should read the chat history's content before answer the question. The first question is: """,
    "factoid": """You can only use JSON format to answer my questions. The format must be {"exact_answer":[], "ideal_answer":""}. where exact_answer is a list of precise key entities to answer the question. ideal_answer is a short conversational response containing an explanation. You should read the chat history's content before answer the question. The first question is: """
}
def make_message(role, content):
    """Build one chat message dict in the OpenAI {"role": ..., "content": ...} format."""
    message = dict(role=role, content=content)
    return message
def gpt_api_retry(func):
    """Decorator: retry transient OpenAI/API-timeout failures up to 5 times.

    Sleeps 3 seconds between attempts and accumulates each traceback so the
    final TimeoutError carries the full failure history. Raises TimeoutError
    if every attempt fails.
    """
    @wraps(func)
    def wrap_func(*args, **kwargs):
        # Fix: the original accepted **kwargs only, breaking positional calls.
        log = ""
        for attempt in range(5):
            try:
                return func(*args, **kwargs)
            except (RateLimitError, APIError, ServiceUnavailableError, TimeoutError):
                message = f"Retry: {attempt+1} times, error: {traceback.format_exc()}\n\n"
                tqdm.write(message)
                log += message
                time.sleep(3)
        raise TimeoutError(f"Failed to get response from OpenAI API after 5 retries.\n\n{log}")
    return wrap_func
@gpt_api_retry
def completions_with_backoff(**kwargs):
    # Run the API call in a single-worker pool purely to enforce a hard 60s
    # timeout: AsyncResult.get raises multiprocessing.TimeoutError on expiry,
    # which the @gpt_api_retry decorator catches and retries.
    with Pool(processes=1) as pool:
        process = pool.apply_async(openai.ChatCompletion.create, kwds=kwargs)
        return process.get(timeout=60)
def summary_snippet(snippet):
    """Ask the model to condense *snippet* to under MAX_SNIPPET_LEN letters.

    Returns the model's summary string; raises ValueError if the API returns
    a non-string or blank response.
    """
    messages = [make_message("user", f'Conclusion and summarize this context in less than {MAX_SNIPPET_LEN} letters:"""{snippet}"""')]
    completion = completions_with_backoff(model=MODEL, messages=messages, temperature=TEMPERATURE)
    resp = completion.choices[0].message.content
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if not isinstance(resp, str) or resp.strip() == "":
        raise ValueError(f"summary_snippet failed: {resp}")
    return resp
def get_question_answer(q):
    """Answer one BioASQ question via GPT and return the schema-validated result.

    q: question dict with "id", "type", "body" and "snippets" keys.
    On any failure the error (plus the raw response, if any) is written to
    error.txt and the exception is re-raised.
    """
    resp = None
    try:
        # Feed each snippet as prior "assistant" context, summarizing long ones.
        messages = []
        for sni in q["snippets"]:
            text = sni["text"]
            snippet = summary_snippet(text) if len(text) > MAX_SNIPPET_LEN else text
            messages.append(make_message("assistant", snippet))
        # Then the type-specific format instruction followed by the question itself.
        messages += [make_message("user", PROMPT[q["type"]]), make_message("user", q["body"])]
        completion = completions_with_backoff(model=MODEL, messages=messages, temperature=TEMPERATURE)
        # Strip stray wrapping periods/quotes before JSON parsing.
        resp = completion.choices[0].message.content.strip(".。\"'")
        # "summary" answers are free text; everything else must be JSON.
        result = json.loads(resp) if q["type"] != "summary" else {"ideal_answer": resp}
        result = SCHEMA[q["type"]].validate(result)
        return result
    except Exception as e:
        tqdm.write("====Error, check error.txt====")
        # NOTE(review): "w" keeps only the most recent error; "a" would keep a history.
        with open("error.txt", "w", encoding="utf-8") as f:
            f.write(f"{str(e)}\nid:{q['id']}\ntype:{q['type']}\nresp:{resp}\n")
        raise  # bare raise: re-raise the original exception unchanged
def request_gpt(specific_q=None):
    """Query GPT for every question (or only *specific_q*) and save one JSON file each.

    Questions whose result file already exists are skipped, so an interrupted
    run can be resumed without re-spending API calls.
    """
    loop = tqdm(QUESTIONS) if specific_q is None else [specific_q]
    folder = f"gpt_result/11b_batch_{BATCH}_submit_{SUBMIT}"
    # makedirs also creates the missing "gpt_result" parent; bare mkdir would
    # raise FileNotFoundError on a fresh checkout.
    os.makedirs(folder, mode=0o755, exist_ok=True)
    for question in loop:
        file_path = f"{folder}/{question['id']}.json"
        if os.path.exists(file_path):  # skip if already exists (resume support)
            continue
        result = get_question_answer(question)
        # Prepend the question id and persist the validated answer.
        result = {"id": question["id"], **result}
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
def merge_result_and_question():
    """Merge every per-question result file back into its question and write submit.json.

    Asks for confirmation before overwriting an existing submit.json.
    Raises FileNotFoundError if any question's result file is missing.
    """
    submit_file = {"questions": []}
    folder = f"gpt_result/11b_batch_{BATCH}_submit_{SUBMIT}"
    for q in QUESTIONS:
        with open(f"{folder}/{q['id']}.json", "r", encoding="utf-8") as f:
            result = json.load(f)
        # Result keys (id, exact_answer, ideal_answer) override/extend the question.
        merge_q = {**q, **result}
        submit_file["questions"].append(merge_q)
    submit_file_path = f"{folder}/submit.json"
    if os.path.exists(submit_file_path):
        input(f"Overwriting {submit_file_path}. Press enter to continue.....")
    # encoding="utf-8" is required: ensure_ascii=False emits non-ASCII text,
    # which fails on locales with a non-UTF-8 default encoding (e.g. Windows).
    with open(submit_file_path, "w", encoding="utf-8") as f:
        json.dump(submit_file, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
    # Full pipeline: answer every question (resumable), then build submit.json.
    request_gpt()
    merge_result_and_question()