textbook_generation.py
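"""Multithreaded synthetic-textbook generation via the OpenAI chat API.

Each worker thread repeatedly prompts the model with a topic list drawn from
code_synthesis_textbooks and saves every response to a timestamped JSON file.
API keys are read from extracted_keys.txt and rotated across workers; keys
that hit quota or policy errors are logged to failed_api_keys.txt and retired
from the shared key list.
"""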
import argparse
import json
import os
import time
from datetime import datetime
from threading import Thread

from code_synthesis_textbooks import generate_topics, synthesize_textbook
from openai_api_wrapper import OpenaiAPIWrapper

# Topic pool shared by all worker threads.
topics = generate_topics()
def parse_args():
    parser = argparse.ArgumentParser()
    # parser.add_argument("--api_key", type=str, required=True, help="Your OpenAI API key")
    parser.add_argument("--max_tokens", type=int, default=2048, help="Max tokens for generated output")
    parser.add_argument("--gen_nums", type=int, default=10000, help="Number of responses each worker should collect")
    parser.add_argument("--output_dir", type=str, default="./textbook", help="Directory for the generated JSON files")
    parser.add_argument("--threads_num_per_key", type=int, default=200, help="Total number of worker threads to spawn")
    return parser.parse_args()
def save_response_to_file(worker_id, response, output_dir):
    """Write a single model response to a timestamped JSON file."""
    now = datetime.now()
    timestamp = now.strftime('%Y%m%d%H%M%S')
    filename = f'response_{timestamp}_worker{worker_id}.json'
    dialogue_data = {
        "timestamp": timestamp,
        "worker": worker_id,
        "response": response,
    }
    filepath = os.path.join(output_dir, filename)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(dialogue_data, f, ensure_ascii=False, indent=4)
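# Illustrative example (hypothetical worker id and timestamp): a call such as
# save_response_to_file(3, "...", "./textbook") writes
#   ./textbook/response_20240101120000_worker3.json
# containing {"timestamp": "20240101120000", "worker": 3, "response": "..."}.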
def generate_textbooks(worker_id, args, api_keys):
    llm = OpenaiAPIWrapper()
    # 200 threads in total; each of the 4 keys is shared by 50 threads,
    # assigned round-robin by worker_id % 4.
    current_key = api_keys[worker_id % 4]
    llm.set_api_key(current_key)
    system_prompt = '''You are a helpful assistant. '''
    # Note: the prompt is sampled once per worker, so every request from this
    # worker reuses the same six topics; sampling diversity comes from
    # temperature/top_p.
    user_prompt = synthesize_textbook(topics, num_topics=6)
    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}]
    generated_num = 0
    while generated_num < args.gen_nums:
        try:
            output = llm.call_turbo_using_messages(messages, max_tokens=args.max_tokens, temperature=1.0, top_p=0.9)
            response = llm.parse_chatgpt_response(output)
            generated_num += 1
        except Exception as e:
            print('Unexpected error')
            print('Exception: {}'.format(e))
            if 'exceeded your current quota' in str(e) or 'due to violation of our policies' in str(e):
                # Log the failed key.
                with open('./failed_api_keys.txt', 'a') as f:
                    f.write('{}\n'.format(current_key))
                # Re-read the log and de-duplicate it.
                with open('./failed_api_keys.txt', 'r') as file:
                    lines = file.readlines()
                unique_lines = list(set(lines))
                with open('./failed_api_keys.txt', 'w') as file:
                    file.writelines(unique_lines)
                # Retire the failed key from the shared in-memory list.
                # api_keys is shared across threads, so another worker may
                # have removed it already.
                if current_key in api_keys:
                    api_keys.remove(current_key)
                if len(api_keys) >= 4:
                    current_key = api_keys[worker_id % 4]
                    llm.set_api_key(current_key)
                else:
                    print('Available API keys < 4; worker exiting.')
                    return None
            continue
        os.makedirs(args.output_dir, exist_ok=True)
        save_response_to_file(worker_id, response, args.output_dir)
        print('# Worker_{} collected schemas: {}'.format(
            worker_id, generated_num))
def main(args):
    start_time = time.time()
    threads = []
    # One API key per line.
    with open('extracted_keys.txt', 'r') as f:
        api_keys = [key.strip() for key in f.readlines()]
    for j in range(args.threads_num_per_key):
        t = Thread(target=generate_textbooks, args=(j, args, api_keys))
        t.start()
        print(f'Worker {j} started!')
        threads.append(t)
    for t in threads:
        t.join()
    total_time = time.time() - start_time
    print(f'Generation finished in {total_time:.2f}s!')


if __name__ == '__main__':
    args = parse_args()
    main(args)
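# Example invocation (a sketch; assumes extracted_keys.txt in the working
# directory holds at least 4 API keys, one per line, since workers index
# keys with worker_id % 4):
#   python textbook_generation.py --gen_nums 100 --output_dir ./textbook --threads_num_per_key 8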