# utils.py
import logging
import sys

import numpy as np
from gensim.models import KeyedVectors


def get_logger(name, level=logging.INFO, stream=sys.stdout,
               formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
    """Return a logger named `name` that writes to `stream` at `level`."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    stream_handler = logging.StreamHandler(stream)
    stream_handler.setLevel(level)
    stream_handler.setFormatter(logging.Formatter(formatter))
    logger.addHandler(stream_handler)
    return logger
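
# Illustrative usage (a minimal sketch; the logger name "train" is arbitrary):
#
#     logger = get_logger("train")
#     logger.info("starting training ...")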


def print_FLAGS(FLAGS, logger):
    """Log every flag value and return the flags as a dict.

    Expects a TensorFlow 1.x `tf.flags.FLAGS` object; note that this reads
    the private `FLAGS.__flags` attribute, which later TF versions changed.
    """
    flags_dict = {}
    logger.info("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        logger.info("{} = {}".format(attr, value))
        flags_dict[attr] = value
    logger.info("\n")
    return flags_dict


def get_max_length(word_sentences):
    """Return the token length of the longest sentence (0 if empty)."""
    return max((len(sentence) for sentence in word_sentences), default=0)


def padSequence(dataset, max_length, beginZero=True):
    """Pad or truncate each sequence in `dataset` to exactly `max_length`.

    Zeros are prepended when `beginZero` is True, appended otherwise.
    Returns the padded array and the per-row sequence lengths (capped at
    `max_length` so they stay consistent with the padded rows).
    """
    dataset_p = []
    actual_sequence_length = []
    for x in dataset:
        row_length = len(x)
        actual_sequence_length.append(min(row_length, max_length))
        if row_length <= max_length:
            if beginZero:
                dataset_p.append(np.pad(x, pad_width=(max_length - row_length, 0),
                                        mode='constant', constant_values=0))
            else:
                dataset_p.append(np.pad(x, pad_width=(0, max_length - row_length),
                                        mode='constant', constant_values=0))
        else:
            # Sequence longer than max_length: truncate.
            dataset_p.append(x[0:max_length])
    return np.array(dataset_p), actual_sequence_length
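
# Illustrative behaviour (a sketch with toy inputs, not from the original repo):
#
#     rows = [np.array([1, 2]), np.array([3, 4, 5, 6])]
#     padded, lengths = padSequence(rows, max_length=3)
#     # padded  -> [[0, 1, 2], [3, 4, 5]]   (pre-padded, second row truncated)
#     # lengths -> [2, 3]
#     padded, lengths = padSequence(rows, 3, beginZero=False)
#     # padded  -> [[1, 2, 0], [3, 4, 5]]   (post-padded)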


def load_word_embedding_dict(embedding, embedding_path, logger):
    """
    Load word embeddings from file.
    :param embedding: embedding type, 'word2vec' or 'glove'
    :param embedding_path: path to the embedding file
    :param logger: logger for progress messages
    :return: embedding dict (or KeyedVectors), embedding dimension, caseless flag
    """
    if embedding == 'word2vec':
        logger.info("Loading word2vec ...")
        # gensim >= 1.0 exposes this loader on KeyedVectors rather than Word2Vec.
        word2vec = KeyedVectors.load_word2vec_format(embedding_path, binary=True)
        embedd_dim = word2vec.vector_size
        return word2vec, embedd_dim, False
    elif embedding == 'glove':
        logger.info("Loading GloVe ...")
        embedd_dim = -1
        embedd_dict = dict()
        with open(embedding_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if len(line) == 0:
                    continue
                tokens = line.split()
                if embedd_dim < 0:
                    # The zeroth token on each line is the word itself.
                    embedd_dim = len(tokens) - 1
                else:
                    assert embedd_dim + 1 == len(tokens)
                embedd_dict[tokens[0]] = np.array(tokens[1:], dtype=np.float64).reshape(1, embedd_dim)
        return embedd_dict, embedd_dim, True
    else:
        raise ValueError("embedding should be chosen from [word2vec, glove]")
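

# A minimal smoke test (illustrative sketch only; the embedding loaders are
# skipped here because they require local embedding files).
if __name__ == "__main__":
    logger = get_logger("utils_demo")
    sentences = [["a", "b"], ["c", "d", "e"]]
    logger.info("max length = %d", get_max_length(sentences))
    rows = [np.array([1, 2]), np.array([3, 4, 5, 6])]
    padded, lengths = padSequence(rows, get_max_length(sentences))
    logger.info("padded rows:\n%s", padded)
    logger.info("actual lengths: %s", lengths)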