-
Notifications
You must be signed in to change notification settings - Fork 9
/
utilities.py
111 lines (98 loc) · 2.84 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import division
import cPickle as pickle
import csv
import numpy as np
import sys
import os
import cPickle as pickle
from nltk.corpus import stopwords
import json
import gzip
from tqdm import tqdm
from collections import Counter
from collections import defaultdict
""" Random utils.
"""
def batchify(data, i, bsz, max_sample):
    """Return batch number ``i`` (batch size ``bsz``) sliced out of ``data``.

    The slice begins at ``int(i * bsz)`` and its end is clipped to
    ``max_sample``, so the final batch may be shorter than ``bsz``.
    """
    lo = int(i * bsz)
    hi = min(lo + bsz, max_sample)
    return data[lo:hi]
def dict_to_list(data_dict):
    """Flatten ``data_dict`` into a list of ``[key, first, second]`` rows.

    Every entry of each value produces one row holding the key followed by
    the entry's first two elements. Iteration over the dict is wrapped in a
    ``tqdm`` progress bar labelled 'dict conversion'.
    """
    return [
        [key, entry[0], entry[1]]
        for key, entries in tqdm(data_dict.items(), desc='dict conversion')
        for entry in entries
    ]
def dictToFile(dict, path):
    """Serialize ``dict`` as JSON and write it gzip-compressed to ``path``.

    The JSON text is produced *before* the file is opened, so a value that
    cannot be serialized no longer truncates an existing file. It is then
    encoded to UTF-8 bytes because the gzip stream is binary; Python 3 rejects
    ``str`` writes to it, while on Python 2 the encode is a no-op for the
    ASCII-only output of ``json.dumps``.

    Note: the parameter name shadows the builtin ``dict``; it is kept so that
    keyword callers keep working.
    """
    print("Writing to {}".format(path))
    payload = json.dumps(dict).encode('utf-8')
    with gzip.open(path, 'wb') as f:
        f.write(payload)
def dictFromFileUnicode(path):
    '''
    Load a JSON document from the gzip-compressed file at ``path``.

    The decompressed bytes are decoded as UTF-8 before parsing, so keys and
    string values come back as unicode text (``unicode`` on Python 2,
    ``str`` on Python 3).
    '''
    print("Loading {}".format(path))
    with gzip.open(path, 'rb') as f:
        return json.loads(f.read().decode('utf-8'))
def load_pickle(fin):
    """Load and return the single pickled object stored in file ``fin``.

    The file is opened in binary mode: pickles are byte streams. Text mode
    corrupts them on platforms that translate newlines, and Python 3's
    ``pickle.load`` rejects text files outright.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.

    NOTE: this module defines ``load_pickle`` a second time further down;
    that later definition is the one bound at import time.
    """
    with open(fin, 'rb') as f:
        return pickle.load(f)
def select_gpu(gpu):
    """Set the ``CUDA_VISIBLE_DEVICES`` environment variable for this process.

    The variable is first set to ``"0"``; a non-negative ``gpu`` then
    overrides it with that index as a string.

    NOTE(review): a negative ``gpu`` leaves device "0" visible rather than
    hiding all GPUs -- if negative is meant to mean "CPU only", this probably
    wants an empty string. Confirm with callers.
    """
    var = "CUDA_VISIBLE_DEVICES"
    os.environ[var] = "0"
    if gpu >= 0:
        os.environ[var] = str(gpu)
def load_pickle(fin):
    """Load and return the single pickled object stored in file ``fin``.

    The file is opened in binary mode: pickles are byte streams. Text mode
    corrupts them on platforms that translate newlines, and Python 3's
    ``pickle.load`` rejects text files outright.

    SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(fin, 'rb') as f:
        return pickle.load(f)
def load_set(fin):
    """Read the tab-separated file ``fin`` and return its rows.

    Each row is a list of string fields, in file order.
    """
    with open(fin, 'r') as handle:
        return list(csv.reader(handle, delimiter='\t'))
def length_stats(lengths, name=''):
    """Print a banner, then the max, median, mean and min of ``lengths``.

    ``name`` labels the report header. One statistic is printed per line.
    """
    reducers = (
        ("Max", np.max),
        ("Median", np.median),
        ("Mean", np.mean),
        ("Min", np.min),
    )
    print("=====================================")
    print("Length Statistics for {}".format(name))
    for label, reduce_fn in reducers:
        print("{}={}".format(label, reduce_fn(lengths)))
def show_stats(name, x):
    """Print a single line with the max, mean and min of ``x``, prefixed by ``name``."""
    hi = np.max(x)
    avg = np.mean(x)
    lo = np.min(x)
    print("{} max={} mean={} min={}".format(name, hi, avg, lo))
def print_args(args, path=None):
    """Write every attribute of ``args`` to the file ``path``.

    Attributes are written one per line as `` key: str(value)``, sorted
    case-insensitively, between two separator lines. The command line that
    launched the process (``' '.join(sys.argv)``) is temporarily attached as
    ``args.command`` so it is recorded too. It is removed again before
    returning, even if writing fails.

    Args:
        args: object whose attributes are read with ``vars()``
            (e.g. an ``argparse.Namespace``).
        path: output file path. When falsy (``None`` or ``''``), nothing is
            written.
    """
    args.command = ' '.join(sys.argv)
    try:
        # A single falsy check: previously path='' skipped the open() but
        # still reached the write, raising UnboundLocalError.
        if not path:
            return
        items = vars(args)
        separator = '=============================================== \n'
        with open(path, 'w') as output_file:
            output_file.write(separator)
            for key in sorted(items, key=lambda s: s.lower()):
                output_file.write(" " + key + ": " + str(items[key]) + "\n")
            output_file.write(separator)
    finally:
        # Never leave the temporary attribute behind on the caller's object.
        del args.command
def mkdir_p(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    An empty (or otherwise falsy) ``path`` is a no-op. The directory already
    existing is not an error. Any other failure -- e.g. a permission error,
    or ``path`` existing as a regular file -- is raised as ``OSError``
    instead of being silently swallowed.
    """
    if not path:
        return
    try:
        os.makedirs(path)
    except OSError:
        # Written without ``exist_ok`` so it also runs on Python 2. The
        # directory may already exist, or may have been created concurrently.
        if not os.path.isdir(path):
            raise