-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathroberta_large_mnli.py
executable file
·53 lines (46 loc) · 1.82 KB
/
roberta_large_mnli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import pipeline, AutoTokenizer
from tqdm import tqdm
import json
import argparse
# Parse the target dataset name; it is used to locate the
# <dataset>/<dataset>_*.json input files and to name the output file.
parser = argparse.ArgumentParser(description='main', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', default='MAGCS')
args = parser.parse_args()
dataset = args.dataset
# requires transformers (huggingface) == 4.0.0
device = 0  # CUDA device index passed to the pipeline; -1 would mean CPU
# 'sentiment-analysis' is the alias for the generic text-classification task;
# return_all_scores=True makes the pipeline return the score of every MNLI
# label for each input, not just the top one.
classifier = pipeline(task='sentiment-analysis', model='roberta-large-mnli', device=device, return_all_scores=True)
tokenizer = AutoTokenizer.from_pretrained('roberta-large-mnli')
# Map each label ID to its first surface name, read from the dataset's label
# file (one JSON object per line: {"label": ..., "name": [...], ...}).
# Fix: open with explicit UTF-8 — JSON is UTF-8 by spec, and the platform
# default encoding (e.g. cp1252 on Windows) would break non-ASCII names.
label2name = {}
with open(f'{dataset}/{dataset}_label.json', encoding='utf-8') as fin:
	for line in tqdm(fin):
		data = json.loads(line)
		label = data['label']
		# 'name' is a list of surface forms; keep only the first one.
		name = data['name'][0]
		label2name[label] = name
# Truncation budgets (in subword tokens) for the paper text and each label
# name, keeping the combined NLI input within the model's 512-token limit.
max_paper_len = 480
max_label_len = 20
# Fixes vs. original: explicit UTF-8 on all three files; `input` renamed to
# avoid shadowing the builtin; removed the unused `top5` local (dead code).
with open(f'{dataset}/{dataset}_paper.json', encoding='utf-8') as fin1, \
	open(f'{dataset}/{dataset}_candidates.json', encoding='utf-8') as fin2, \
	open(f'{dataset}/{dataset}_predictions_mnli.json', 'w', encoding='utf-8') as fout:
	for line1, line2 in tqdm(zip(fin1, fin2)):
		data1 = json.loads(line1)
		data2 = json.loads(line2)
		# The two files must be aligned line-by-line on paper ID.
		assert data1['paper'] == data2['paper']
		text = (data1['title'] + ' ' + data1['abstract']).strip()
		# Tokenize-then-decode truncates the text to max_paper_len subwords;
		# [1:-1] strips the <s>/</s> special tokens the tokenizer adds.
		tokens = tokenizer(text, truncation=True, max_length=max_paper_len)
		text = tokenizer.decode(tokens["input_ids"][1:-1])
		score = {}
		for label in data2['matched_label']:
			name = label2name[label]
			name = ' '.join(name.split())  # collapse internal whitespace
			tokens = tokenizer(name, truncation=True, max_length=max_label_len)
			name = tokenizer.decode(tokens["input_ids"][1:-1])
			# Premise </s></s> hypothesis pair for the MNLI model.
			nli_input = f'{text} </s></s> this document is about {name}.'
			output = classifier(nli_input)
			# output[0] holds one score dict per MNLI label; the last entry
			# is presumably ENTAILMENT for roberta-large-mnli — TODO confirm
			# against the model's id2label config.
			score[label] = output[0][-1]['score']
		# Rank candidate labels by entailment score, best first.
		score_sorted = sorted(score.items(), key=lambda x: x[1], reverse=True)
		out = {}
		out['paper'] = data1['paper']
		out['predictions'] = score_sorted
		fout.write(json.dumps(out) + '\n')