-
Notifications
You must be signed in to change notification settings - Fork 217
/
wikienv.py
168 lines (146 loc) · 5.4 KB
/
wikienv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import ast
import json
import time
import gym
import requests
from bs4 import BeautifulSoup
# import wikipedia
def clean_str(p):
    """Interpret literal backslash escapes in ``p`` and return the repaired text.

    The string is encoded to UTF-8, decoded with the ``unicode-escape`` codec
    (which turns literal sequences such as ``\\n`` into the characters they
    denote), re-encoded as Latin-1 to get the raw bytes back, and finally
    decoded as UTF-8 so multi-byte characters survive the round trip.

    Args:
        p: Text to clean.

    Returns:
        The cleaned text, or ``p`` unchanged when the round trip is not
        possible for this input.
    """
    try:
        return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
    except UnicodeError:
        # The round trip fails on inputs such as a trailing backslash, a
        # ``\uXXXX`` escape above U+00FF (not representable in Latin-1), or a
        # ``\xNN`` escape that does not form valid UTF-8. Returning the text
        # untouched is preferable to aborting the caller.
        return p
class textSpace(gym.spaces.Space):
    """Gym space whose members are Python strings."""

    def contains(self, x) -> bool:
        """Tell whether ``x`` belongs to this space, i.e. whether it is a ``str``."""
        is_text = isinstance(x, str)
        return is_text
class WikiEnv(gym.Env):
    """Text environment that lets an agent browse English Wikipedia.

    Actions are strings of the form ``search[entity]``, ``lookup[keyword]``,
    ``finish[answer]`` or ``think[...]``; observations are strings as well.
    """

    # Endpoint queried by ``search_step``.
    SEARCH_URL = "https://en.wikipedia.org/w/index.php"
    # Seconds to wait for Wikipedia before ``requests`` gives up, so a stalled
    # connection cannot block the environment forever.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """Initialize the environment with no page loaded and zeroed counters."""
        super().__init__()
        self.page = None  # current Wikipedia page
        self.obs = None  # current observation
        self.lookup_keyword = None  # current lookup keyword
        self.lookup_list = None  # list of sentences containing current lookup keyword
        self.lookup_cnt = None  # current lookup index
        self.steps = 0  # current number of steps
        self.answer = None  # current answer from the agent
        self.result_titles = None  # titles suggested by the last unmatched search
        self.observation_space = self.action_space = textSpace()
        self.search_time = 0  # cumulative seconds spent in search requests
        self.num_searches = 0  # number of search requests issued

    def _get_obs(self):
        """Return the current observation string."""
        return self.obs

    def _get_info(self):
        """Return the info dict holding the step count and the agent's answer."""
        return {"steps": self.steps, "answer": self.answer}

    def reset(self, seed=None, return_info=False, options=None):
        """Start a new episode and return the initial observation.

        Args:
            seed: Accepted for API compatibility; not used here.
            return_info: When true, return ``(observation, info)`` instead of
                just the observation.
            options: Accepted for API compatibility; not used here.

        Returns:
            The instruction string, optionally paired with the info dict.
        """
        # NOTE: ``super().reset(seed=seed)`` is deliberately not called, so
        # ``self.np_random`` is not seeded by this method.
        self.obs = ("Interact with Wikipedia using search[], lookup[], and "
                    "finish[].\n")
        self.page = None
        self.lookup_keyword = None
        self.lookup_list = None
        self.lookup_cnt = None
        self.steps = 0
        self.answer = None
        observation = self._get_obs()
        info = self._get_info()
        return (observation, info) if return_info else observation

    @staticmethod
    def _split_sentences(text):
        """Split ``text`` into stripped, non-empty sentences, each suffixed with '.'.

        Paragraphs are separated on newlines and sentences on ``'. '``.
        """
        paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
        sentences = []
        for paragraph in paragraphs:
            sentences += paragraph.split('. ')
        return [s.strip() + '.' for s in sentences if s.strip()]

    def construct_lookup_list(self, keyword):
        """Return the sentences of the current page that contain ``keyword``.

        Matching is case-insensitive. An empty list is returned when no page
        has been loaded yet.
        """
        if self.page is None:
            return []
        needle = keyword.lower()
        return [s for s in self._split_sentences(self.page) if needle in s.lower()]

    @staticmethod
    def get_page_obs(page):
        """Return the first five sentences of ``page`` joined by single spaces."""
        return ' '.join(WikiEnv._split_sentences(page)[:5])

    def search_step(self, entity):
        """Search Wikipedia for ``entity`` and update the observation.

        If Wikipedia answers with a list of search results, the observation
        lists up to five similar titles. If it answers with a disambiguation
        page ("may refer to:"), the search is retried with the entity wrapped
        in square brackets. Otherwise the page text is stored, the observation
        becomes its first sentences, and any lookup state is cleared.

        Exceptions raised by ``requests`` (including timeouts) propagate to
        the caller.
        """
        old_time = time.time()
        # Let ``requests`` build the query string so characters such as
        # '&', '#', '+' and '?' in the entity are percent-encoded instead of
        # corrupting or truncating the query; spaces are still sent as '+'.
        response_text = requests.get(
            self.SEARCH_URL,
            params={"search": entity},
            timeout=self.REQUEST_TIMEOUT,
        ).text
        self.search_time += time.time() - old_time
        self.num_searches += 1
        soup = BeautifulSoup(response_text, features="html.parser")
        result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
        if result_divs:  # mismatch
            self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
            self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}."
            return

        blocks = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
        if any("may refer to:" in p for p in blocks):
            self.search_step("[" + entity + "]")
            return

        pieces = []
        for p in blocks:
            # Skip fragments of one or two words.
            if len(p.split(" ")) > 2:
                pieces.append(clean_str(p))
                if not p.endswith("\n"):
                    pieces.append("\n")
        self.page = "".join(pieces)
        self.obs = self.get_page_obs(self.page)
        self.lookup_keyword = self.lookup_list = self.lookup_cnt = None

    def step(self, action):
        """Execute one action string and return ``(obs, reward, done, info)``.

        The reward returned by this class is always 0. Once an answer has been
        given with ``finish[...]``, further calls return the last observation
        with ``done=True`` and do not advance the step counter.
        """
        reward = 0
        done = False
        action = action.strip()
        if self.answer is not None:  # already finished
            done = True
            return self.obs, reward, done, self._get_info()

        if action.startswith("search[") and action.endswith("]"):
            entity = action[len("search["):-1]
            self.search_step(entity)
        elif action.startswith("lookup[") and action.endswith("]"):
            keyword = action[len("lookup["):-1]
            if self.lookup_keyword != keyword:  # reset lookup
                self.lookup_keyword = keyword
                self.lookup_list = self.construct_lookup_list(keyword)
                self.lookup_cnt = 0
            if self.lookup_cnt >= len(self.lookup_list):
                self.obs = "No more results.\n"
            else:
                self.obs = f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) " + self.lookup_list[self.lookup_cnt]
                self.lookup_cnt += 1
        elif action.startswith("finish[") and action.endswith("]"):
            answer = action[len("finish["):-1]
            self.answer = answer
            done = True
            self.obs = f"Episode finished, reward = {reward}\n"
        elif action.startswith("think[") and action.endswith("]"):
            self.obs = "Nice thought."
        else:
            self.obs = f"Invalid action: {action}"

        self.steps += 1
        return self.obs, reward, done, self._get_info()

    def get_time_info(self):
        """Return search timing statistics.

        ``call_speed`` is the average number of seconds per search request
        (0 when no search has been made), ``call_time`` the cumulative seconds,
        and ``num_calls`` the number of requests.
        """
        speed = self.search_time / self.num_searches if self.num_searches else 0
        return {
            "call_speed": speed,
            "call_time": self.search_time,
            "num_calls": self.num_searches,
        }