query_test_interaction.py
import xml.etree.ElementTree as ET
import collections
import elasticsearch_copy
import os
import time
# Topics 1~30 are split into five groups; each group gets its own expansion
# keywords (from cache/cache{i}/keyword.txt) and its own query boost.
topic_division = [[28, 29, 25, 22, 6, 7], [26, 11, 1, 18, 21, 4], [19, 24, 27, 30, 12, 23], [13, 14, 3, 16, 8, 9], [15, 20, 5, 10, 17, 2]]
group_boost = [5, 5, 2, 2, 2]
cache_root = "cache"
result_root = "qresults"
data_root = "clinicaltrials"
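# Assumed on-disk layout, inferred from how the paths above are used (illustrative):
#   cache/cache1/keyword.txt ... cache/cache5/keyword.txt   one expansion keyword per line
#   clinicaltrials/topics2017.xml                            the TREC 2017 PM query topics
#   qresults/                                                output directory for run files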
def query_word_generate():
    """Generate the expansion-query string for each topic group from its keyword.txt."""
    query_word = []
    for i in range(1, 6):
        s = ""
        with open(cache_root + "/cache{}/keyword.txt".format(i), "r") as keyword:
            for line in keyword:
                w = line.strip()
                s += w
                s += " "
        query_word.append(s)
    return query_word


query_word = query_word_generate()
def extract_query_xml():
    """
    The query topics are provided in an XML file; this function extracts the query terms from it.
    The extracted terms are stored in an ordered dictionary, keyed by field name, which is then
    passed to es_query to query Elasticsearch with those terms.
    """
    # Parse the query XML file
    tree = ET.parse(os.path.join(data_root, "topics2017.xml"))
    root = tree.getroot()
    # Create an ordered dictionary to store the query terms
    extracted_data = collections.OrderedDict()
    # There are 30 query topics, each with multiple fields. Inside the try-except block we
    # extract the fields of the chosen topic; illegal input is reported and the user is asked again.
    topics = root.findall('topic')
    while True:
        try:
            choice = input("Enter the topic number you want to search 1~30, Enter 'q' to quit: ")
            if choice == "q":
                break
            tnum = int(choice)
            if not 1 <= tnum <= 30:
                raise ValueError(choice)
            item = topics[tnum - 1]
            group_id = 0
            for i in range(5):
                if tnum in topic_division[i]:
                    group_id = i
            disease = item.find('disease').text
            print("Let's see some clinical trials about {}".format(disease))
            time.sleep(0.3)
            gene = item.find('gene').text
            demographic = item.find('demographic').text
            other = item.find('other').text
            extracted_data['tnum'] = tnum
            extracted_data['disease'] = disease
            extracted_data['gene'] = gene
            # demographic is of the form "<age>-year-old <sex>"
            extracted_data['age'] = int(demographic.split('-')[0])
            extracted_data['sex'] = demographic.split(' ')[1]
            extracted_data['other'] = other
            es_query(group_id, extracted_data)
        except (ValueError, IndexError):
            print('Please enter a legal topic number 1~30')
    # Reset the dictionary on exit
    extracted_data['tnum'] = None
    extracted_data['disease'] = None
    extracted_data['gene'] = None
    extracted_data['age'] = None
    extracted_data['sex'] = None
    extracted_data['other'] = None
    return
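# For reference, each <topic> element in topics2017.xml is assumed to look roughly like the
# snippet below (field names taken from the find() calls above; the values are illustrative):
#   <topic number="1">
#     <disease>Liposarcoma</disease>
#     <gene>CDK4 Amplification</gene>
#     <demographic>38-year-old male</demographic>
#     <other>GERD</other>
#   </topic>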
def es_query(group_id, extracted_data):
    """
    Query Elasticsearch and report the retrieved results.
    Receives the dictionary of query terms extracted by extract_query_xml. The retrieved
    results are printed to the console in decreasing order of score; writing them to an
    output file in the standard trec_eval format is left commented out below.
    """
    try:
        # Store the disease name from the received dictionary in main_query
        main_query = extracted_data['disease']
        age_query = int(extracted_data['age'])
        sex_query = extracted_data['sex']
        gene_query = extracted_data['gene']
        # The 'other' field may literally contain the string "None"
        if extracted_data['other'] != 'None':
            aux_query = extracted_data['other']
        else:
            aux_query = None
        # print("main_query : {}, age_query : {}, sex_query : {}, aux_query : {}, gene_query : {}".format(main_query, age_query, sex_query, aux_query, gene_query))
        # For a simple query without any customization, uncomment the following line
        # res = es.search(index='ct', q=main_query, size=1000)['hits']['hits']
        # The live implementation below uses a customized query with multi-match clauses and a
        # post-filter, which proved best for the current retrieval process; comment it out if
        # you plan to use the simple query above. Retrieval is limited to 1000 results, ranked
        # 1 to 1000 in decreasing order of score, and each score is normalized by the maximum
        # (first) score.
        # ------------------------------------------------------------------
        # Earlier query drafts, kept for reference. Note that the repeated
        # "must"/"filter" keys in these Python dicts silently overwrite each
        # other; the live query below uses lists instead.
        # ------------------------------------------------------------------
        # res = es.search(index='ct', body={
        #     "query": {
        #         "bool": {
        #             "must": {
        #                 "multi_match": {
        #                     "query": main_query,
        #                     "type": "phrase_prefix",
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"],
        #                     "fuzziness": "AUTO"
        #                 }
        #             },
        #             # "": {"match": {"gender": sex_query}},
        #             "must": {"range": {"maximum_age": {"gte": age_query}}},
        #             "must": {"range": {"minimum_age": {"lte": age_query}}},
        #             "should": [
        #                 {"term": {"eligibility": main_query}},
        #                 {"term": {"brief_summary": main_query}},
        #                 {"term": {"detailed_description": main_query}},
        #                 {"term": {"keyword": main_query}},
        #                 {"term": {"brief_title": main_query}},
        #                 {"term": {"mesh_term": main_query}}
        #             ],
        #         },
        #         # "post_filter": {
        #         #     "term": {"gender": "All"}
        #         # }
        #     }
        # }, size=10)['hits']['hits']
        # match gene and disease
        # res = es.search(index='ct', body={
        #     "query": {
        #         "bool": {
        #             "must": {
        #                 "multi_match": {
        #                     "query": main_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #             "must": {
        #                 "multi_match": {
        #                     "query": gene_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #         }
        #     }
        # }, size=1000)['hits']['hits']
        # match gene and disease, filter gender
        # res = es.search(index='ct', body={
        #     "query": {
        #         "bool": {
        #             "must": {
        #                 "multi_match": {
        #                     "query": main_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #             "must": {
        #                 "multi_match": {
        #                     "query": gene_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #         }
        #     },
        #     "post_filter": {
        #         "term": {"gender": "all"}
        #     }
        # }, size=1000)['hits']['hits']
        # filter with age
        # res = es.search(index='ct', body={
        #     "query": {
        #         "bool": {
        #             "must": {
        #                 "multi_match": {
        #                     "query": main_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #             "must": {
        #                 "multi_match": {
        #                     "query": gene_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #             "filter": {
        #                 "range": {"maximum_age": {"gte": age_query}},
        #                 "range": {"minimum_age": {"lte": age_query}}
        #             }
        #         }
        #     },
        #     "post_filter": {
        #         "term": {"gender": "all"}
        #     }
        # }, size=1000)['hits']['hits']
        # boost title/keyword/mesh_term fields and prefer "cancer" matches
        # res = es.search(index='ct', body={
        #     "query": {
        #         "bool": {
        #             "must": {
        #                 "multi_match": {
        #                     "query": main_query,
        #                     "fields": ["brief_title^3", "brief_summary", "detailed_description", "eligibility", "keyword^3", "mesh_term^3"],
        #                 }
        #             },
        #             "must": {
        #                 "multi_match": {
        #                     "query": gene_query,
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"],
        #                 }
        #             },
        #             "should": {
        #                 "multi_match": {
        #                     "query": "cancer",
        #                     "fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword", "mesh_term"]
        #                 }
        #             },
        #             "filter": {
        #                 "range": {"maximum_age": {"gte": age_query}},
        #                 "range": {"minimum_age": {"lte": age_query}}
        #             }
        #         }
        #     },
        #     "post_filter": {
        #         "term": {"gender": "all"}
        #     }
        # }, size=1000)['hits']['hits']
# print("query_word ", query_word[group_id])
res = es.search(index='ct', body={
"query": {
"bool": {
"must": {
"multi_match": {
"query": main_query,
"fields": ["brief_title * 3", "brief_summary", "detailed_description", "eligibility",
"keyword * 3",
"mesh_term * 3"],
"boost" : group_boost[group_id]
}
},
"must": {
"multi_match": {
"query": gene_query,
"fields": ["brief_title", "brief_summary", "detailed_description", "eligibility", "keyword",
"mesh_term"],
"boost" : group_boost[group_id]
}
},
"should":{
"multi_match": {
"query" : query_word[group_id],
"fields" : ["brief_summary", "detailed_description"],
"boost" : 1
}
},
"filter": {
"range": {"maximum_age": {"gte": age_query}},
"range": {"minimum_age": {"lte": age_query}}
}
}
},
"post_filter":
{"term": {"gender": "all"},
},
}, size=1000)['hits']['hits']
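        # Each hit in res is assumed to carry a score and the trial id in its source document,
        # e.g. {'_score': 12.3, '_source': {'nct_id': 'NCT00000102', ...}} (values illustrative).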
        if not res:
            print("No results retrieved for this topic.")
            return
        max_score = res[0]['_score']
        rank_ctr = 1
        # with open('/home/sofiahuang/code/TREC_pm/TREC-2017-PM-CDS-Track/qresults/mini_output.txt', 'w') as f:
        #     f.write(json.dumps(res, indent=2, ensure_ascii=False))
        #     # print(json.dumps(res, indent=2, ensure_ascii=False))
        #     input()
        # Print the retrieved results; writing them to a file in the standard
        # trec_eval format is left commented out below
        for i in res:
            print('nct_id:{}\trelevance ranking:{}\trelevance score:{}\n'
                  .format(i['_source']['nct_id'], rank_ctr, round(i['_score'] / max_score, 4)))
            rank_ctr += 1
        # with open('qresults/results_chen.txt', 'a') as op_file:
        #     for i in res:
        #         op_file.write(
        #             '{}\tQ0\t{}\t{}\t{}\t2_ec_complex\n'.format(extracted_data['tnum'], i['_source']['nct_id'],
        #                                                         rank_ctr, round(i['_score'] / max_score, 4)))
        #         rank_ctr += 1
        # print("finish_writing")
    except Exception as e:
        print("\nUnable to query/write!")
        print('Error Message:', e, '\n')
    return
if __name__ == '__main__':
    # Create a connection to Elasticsearch listening on localhost:9200. This uses the
    # Python Elasticsearch API (the official low-level client), imported above as the
    # local module elasticsearch_copy.
    try:
        es = elasticsearch_copy.Elasticsearch([{'host': 'localhost', 'port': 9200}])
    except Exception as e:
        print('\nCannot connect to Elasticsearch!')
        print('Error Message:', e, '\n')
    # Call the function to start extracting the queries
    extract_query_xml()
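# Example interactive session (output illustrative):
#   $ python query_test_interaction.py
#   Enter the topic number you want to search 1~30, Enter 'q' to quit: 1
#   Let's see some clinical trials about Liposarcoma
#   nct_id:NCT00000102	relevance ranking:1	relevance score:1.0
#   ...
#   Enter the topic number you want to search 1~30, Enter 'q' to quit: q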