"""
STORM Wiki pipeline powered by DeepSeek models and the You.com or Bing search engine.

You need to set up the following environment variables to run this script:
    - DEEPSEEK_API_KEY: DeepSeek API key
    - DEEPSEEK_API_BASE: DeepSeek API base URL (defaults to https://api.deepseek.com)
    - YDC_API_KEY: You.com API key; or BING_SEARCH_API_KEY: Bing Search API key
      (depending on the --retriever you select)

Output will be structured as below:

args.output_dir/
    topic_name/  # underscore-connected topic name without spaces or slashes
        conversation_log.json           # Log of the information-seeking conversation
        raw_search_results.json         # Raw search results from the search engine
        direct_gen_outline.txt          # Outline directly generated with the LM's parametric knowledge
        storm_gen_outline.txt           # Outline refined with the collected information
        url_to_info.json                # Sources used in the final article
        storm_gen_article.txt           # Final article generated
        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
"""
import os
import re
import logging
from argparse import ArgumentParser

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import DeepSeekModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.utils import load_api_key

# Module-level logger; referenced in main()'s exception handler below.
logger = logging.getLogger(__name__)


def sanitize_topic(topic):
"""
Sanitize the topic name for use in file names.
Remove or replace characters that are not allowed in file names.
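    Example (illustrative): "Large Language Models (LLMs)" -> "Large_Language_Models_LLMs"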
"""
# Replace spaces with underscores
topic = topic.replace(' ', '_')
# Remove any character that isn't alphanumeric, underscore, or hyphen
topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic)
# Ensure the topic isn't empty after sanitization
if not topic:
topic = "unnamed_topic"
return topic


def main(args):
load_api_key(toml_file_path='secrets.toml')
lm_configs = STORMWikiLMConfigs()
# Ensure DEEPSEEK_API_KEY is set
if not os.getenv("DEEPSEEK_API_KEY"):
raise ValueError("DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.")
deepseek_kwargs = {
'api_key': os.getenv("DEEPSEEK_API_KEY"),
'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"),
'temperature': args.temperature,
'top_p': args.top_p,
}
    # DeepSeek offers two main models: 'deepseek-chat' for general tasks and
    # 'deepseek-coder' for coding tasks. Users can choose the appropriate model
    # via the --model flag.
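    # The max_tokens budgets below are intentionally uneven: conversation simulation,
    # question asking, and outlining only need short responses, while polishing
    # rewrites the full article and therefore gets a much larger budget.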
conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs)
question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs)
outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs)
article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs)
article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs)
lm_configs.set_conv_simulator_lm(conv_simulator_lm)
lm_configs.set_question_asker_lm(question_asker_lm)
lm_configs.set_outline_gen_lm(outline_gen_lm)
lm_configs.set_article_gen_lm(article_gen_lm)
lm_configs.set_article_polish_lm(article_polish_lm)
    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        retrieve_top_k=args.retrieve_top_k,
        max_thread_num=args.max_thread_num,
    )
# STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.
if args.retriever == 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
else:
raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.")
runner = STORMWikiRunner(engine_args, lm_configs, rm)
topic = input('Topic: ')
sanitized_topic = sanitize_topic(topic)
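    # Note: the sanitized topic is passed to the runner, so it serves both as the
    # research topic given to the LMs and as the output directory name.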
try:
runner.run(
topic=sanitized_topic,
do_research=args.do_research,
do_generate_outline=args.do_generate_outline,
do_generate_article=args.do_generate_article,
do_polish_article=args.do_polish_article,
remove_duplicate=args.remove_duplicate,
)
runner.post_run()
runner.summary()
except Exception as e:
        logger.exception(f"An error occurred: {e}")
raise


if __name__ == '__main__':
parser = ArgumentParser()
# global arguments
parser.add_argument('--output-dir', type=str, default='./results/deepseek',
help='Directory to store the outputs.')
    parser.add_argument('--max-thread-num', type=int, default=3,
                        help='Maximum number of threads to use. The information seeking part and the article generation '
                             'part can be sped up by using multiple threads. Consider reducing it if you keep getting '
                             '"Exceed rate limit" errors when calling the LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True,
help='The search engine API to use for retrieving information.')
parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat',
help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.')
parser.add_argument('--temperature', type=float, default=1.0,
help='Sampling temperature to use.')
    parser.add_argument('--top-p', type=float, default=0.9,
                        help='Top-p sampling parameter.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
help='If True, simulate conversation to research the topic; otherwise, load the results.')
parser.add_argument('--do-generate-outline', action='store_true',
help='If True, generate an outline for the topic; otherwise, load the results.')
parser.add_argument('--do-generate-article', action='store_true',
help='If True, generate an article for the topic; otherwise, load the results.')
parser.add_argument('--do-polish-article', action='store_true',
help='If True, polish the article by adding a summarization section and (optionally) removing '
'duplicate content.')
# hyperparameters for the pre-writing stage
parser.add_argument('--max-conv-turn', type=int, default=3,
help='Maximum number of questions in conversational question asking.')
parser.add_argument('--max-perspective', type=int, default=3,
help='Maximum number of perspectives to consider in perspective-guided question asking.')
parser.add_argument('--search-top-k', type=int, default=3,
help='Top k search results to consider for each search query.')
# hyperparameters for the writing stage
parser.add_argument('--retrieve-top-k', type=int, default=3,
help='Top k collected references for each section title.')
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')
main(parser.parse_args())
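
# Example invocation (illustrative; adjust flags as needed):
#   python run_storm_wiki_deepseek.py --retriever you --model deepseek-chat \
#       --do-research --do-generate-outline --do-generate-article --do-polish-article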