-
Notifications
You must be signed in to change notification settings - Fork 190
/
calculate_coverages.py
90 lines (76 loc) · 2.89 KB
/
calculate_coverages.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from __future__ import print_function
import pickle
import json
import csv
import sys
from io import open
# Allow us to import the torchmoji directory
from os.path import dirname, abspath
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage
# Python 2/3 compatibility: make ``unicode`` name the text type on both.
try:
    unicode  # Python 2
except NameError:
    unicode = str  # Python 3

# ``sys.version_info`` is a structured tuple; ``sys.version`` is a free-form
# string and indexing its first character is a fragile way to get the major
# version.
IS_PYTHON2 = sys.version_info[0] == 2
# Tab-separated report written once all datasets have been processed.
OUTPUT_PATH = 'coverage.csv'

# Benchmark datasets to measure. Each is stored as ../data/<name>/raw.pickle,
# resolved relative to the current working directory.
_DATASET_NAMES = (
    'Olympic',
    'PsychExp',
    'SCv1',
    'SCv2-GEN',
    'SE0714',
    # 'SE1604',  # Excluded due to Twitter's ToS
    'SS-Twitter',
    'SS-Youtube',
)
DATASET_PATHS = ['../data/{}/raw.pickle'.format(name) for name in _DATASET_NAMES]
# Pretrained vocabulary loaded from JSON (presumably a word -> index mapping;
# it is handed to SentenceTokenizer for the 'Last' and 'Full' columns).
# Decode explicitly as UTF-8, the JSON interchange encoding, so the result
# does not depend on the platform's default locale encoding.
with open('../model/vocabulary.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)
def _to_text(x):
    """Return *x* as a text string, decoding UTF-8 byte strings.

    On Python 3, ``unicode(b'...')`` (i.e. ``str``) does not decode -- it
    returns the repr ``"b'...'"`` without raising -- so byte strings must be
    decoded explicitly rather than via a UnicodeDecodeError fallback.
    """
    if isinstance(x, bytes):
        return x.decode('utf-8')
    return unicode(x)


def _test_coverage(data, vocabulary, extend_with):
    """Return the coverage of the test split of one dataset.

    *data* is the loaded pickle dict (keys 'texts', 'info', 'train_ind',
    'val_ind', 'test_ind'). The texts are split with the stored indices and
    tokenized to fixed length 30 using *vocabulary*; *extend_with* is
    forwarded to ``split_train_val_test``.
    """
    st = SentenceTokenizer(vocabulary, 30)
    splits, _, _ = st.split_train_val_test(data['texts'], data['info'],
                                           [data['train_ind'],
                                            data['val_ind'],
                                            data['test_ind']],
                                           extend_with=extend_with)
    # Index 2 is the test split (third entry of the index list above).
    return coverage(splits[2])


results = []
for p in DATASET_PATHS:
    print('Calculating coverage for {}'.format(p))
    # NOTE(review): pickle.load can execute arbitrary code; only run this on
    # dataset files obtained from a trusted source.
    with open(p, 'rb') as f:
        if IS_PYTHON2:
            s = pickle.load(f)
        else:
            s = pickle.load(f, fix_imports=True)

    # Decode data: normalise every text to the unicode/str text type.
    s['texts'] = [_to_text(x) for x in s['texts']]

    coverage_result = [p]
    # Own: empty starting vocabulary, extend_with=10000.
    coverage_result.append(_test_coverage(s, {}, extend_with=10000))
    # Last: pretrained vocabulary, no extension.
    # Full: pretrained vocabulary, extend_with=10000.
    # Each run gets a fresh copy of ``vocab``: if the tokenizer extends the
    # vocabulary it was given in place (NOTE(review): confirm in
    # SentenceTokenizer.split_train_val_test), sharing the dict would leak
    # one dataset's added words into later datasets' results.
    coverage_result.append(_test_coverage(s, dict(vocab), extend_with=0))
    coverage_result.append(_test_coverage(s, dict(vocab), extend_with=10000))
    results.append(coverage_result)
# csv.writer emits native ``str`` rows: byte strings on Python 2 (so the file
# must be binary there) but text on Python 3, where the csv module requires a
# text-mode file opened with newline='' -- writing to a 'wb' file on
# Python 3 raises TypeError.
if IS_PYTHON2:
    open_kwargs = {'mode': 'wb'}
else:
    open_kwargs = {'mode': 'w', 'newline': '', 'encoding': 'utf-8'}
with open(OUTPUT_PATH, **open_kwargs) as csvfile:
    writer = csv.writer(csvfile, delimiter='\t', lineterminator='\n')
    writer.writerow(['Dataset', 'Own', 'Last', 'Full'])
    for i, row in enumerate(results):
        # Best effort per row: report the failing row and its cause, then
        # continue; KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            writer.writerow(row)
        except Exception as err:
            print("Exception at row {}: {!r}".format(i, err))
print('Saved to {}'.format(OUTPUT_PATH))