forked from tech-srl/lstar_extraction
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRNNClassifier.py
139 lines (116 loc) · 5.64 KB
/
RNNClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from LSTM import LSTMNetwork
from GRU import GRUNetwork
from LinearTransform import LinearTransform
import dynet as dy
from time import perf_counter
import random
import matplotlib.pyplot as plt
from math import ceil
class RNNClassifier:
    """Binary (True/False) word classifier built on a DyNet recurrent network.

    Each character of a word is embedded through a lookup table, the embeddings
    are fed one by one through ``RNNClass`` (``LSTMNetwork`` by default), and the
    output of the final state is mapped by a linear layer plus softmax onto the
    classes ``int2class == [True, False]``. A word is accepted when the
    probability of ``True`` exceeds 0.5.
    """

    def __init__(self, alphabet, num_layers=2, input_dim=3, hidden_dim=5, RNNClass=LSTMNetwork):
        """Create the parameter collection and all network components.

        :param alphabet: iterable of the characters words may contain; its order
            fixes each character's row in the embedding lookup table.
        :param num_layers: number of RNN layers, forwarded to ``RNNClass``.
        :param input_dim: size of each character embedding.
        :param hidden_dim: RNN hidden size; also the input size passed to the
            linear output layer.
        :param RNNClass: network class, constructed with the keyword arguments
            ``num_layers``, ``input_dim``, ``hidden_dim`` and ``pc``.
        """
        self.alphabet = list(alphabet)
        self.int2char = self.alphabet
        self.char2int = {c: i for i, c in enumerate(self.int2char)}
        self.int2class = [True, False]  # binary classifier for now
        self.class2int = {c: i for i, c in enumerate(self.int2class)}
        self.vocab_size = len(self.alphabet)
        self.pc = dy.ParameterCollection()
        self.lookup = self.pc.add_lookup_parameters((self.vocab_size, input_dim))
        self.linear_transform = LinearTransform(hidden_dim, len(self.class2int), self.pc)
        self.rnn = RNNClass(num_layers=num_layers, input_dim=input_dim, hidden_dim=hidden_dim, pc=self.pc)
        self.store_expressions()
        # per-epoch average losses, accumulated across every train_group call
        self.all_losses = []
        # sentinel values returned by train_group
        self.finish_signal = "Finished"
        self.keep_going = "Keep Going"

    def renew(self):
        """Start a fresh DyNet computation graph and re-store the sub-networks' expressions."""
        dy.renew_cg()
        self.store_expressions()

    def store_expressions(self):
        """Ask the RNN and the linear layer to store their expressions.

        NOTE(review): presumably this binds their parameters into the current
        computation graph (it is re-run after every ``dy.renew_cg()``) — confirm
        in the LSTM/GRU/LinearTransform modules.
        """
        self.rnn.store_expressions()
        self.linear_transform.store_expressions()

    def _char_to_input_vector(self, char):
        """Return the embedding row of ``char`` (KeyError if it is not in the alphabet)."""
        return self.lookup[self.char2int[char]]

    def _next_state(self, state, char):
        """Advance ``state`` by one step, feeding the embedding of ``char``."""
        return self.rnn.next_state(state, self._char_to_input_vector(char))

    def _final_state(self, word):
        """Run the RNN from its initial state over every character of ``word``."""
        state = self.rnn.initial_state
        for char in word:
            state = self._next_state(state, char)
        return state

    def _state_scores(self, state):
        """Unnormalised class scores (linear layer applied to the state's output)."""
        return self.linear_transform.apply(state.output())

    def _state_probability_distribution(self, state):
        """Softmax over the class scores, indexed by ``self.class2int``."""
        return dy.softmax(self._state_scores(state))

    def get_first_RState(self):
        """Return ``(vector form of the RNN's initial state, whether it is accepting)``."""
        return self.rnn.initial_state.as_vec(), self._classify_state(self.rnn.initial_state)

    def get_next_RState(self, vec, char):
        """Step from the state encoded by ``vec`` on ``char``.

        :param vec: a state vector as produced by ``as_vec()`` (e.g. from
            ``get_first_RState``); it is rebuilt into a state via
            ``self.rnn.state_class``.
        :param char: the next input character.
        :returns: ``(vector form of the next state, whether it is accepting)``,
            or ``None`` (after printing a message) if ``char`` is not in the alphabet.
        """
        # verification, could get rid of
        if char not in self.alphabet:
            print("char for next vector not from input alphabet")
            return None
        state = self.rnn.state_class(full_vec=vec, hidden_dim=self.rnn.hidden_dim)
        state = self._next_state(state, char)
        return state.as_vec(), self._classify_state(state)

    def _word_is_over_input_alphabet(self, word):
        """True iff every character of ``word`` belongs to the alphabet (True for an empty word)."""
        return all(c in self.alphabet for c in word)

    def _state_accept_probability(self, s):
        """DyNet expression for the probability that state ``s`` is accepting."""
        probabilities = self._state_probability_distribution(s)
        return probabilities[self.class2int[True]]

    def _classify_state(self, s):
        """True when the accept probability of state ``s`` exceeds 0.5."""
        return self._state_accept_probability(s).value() > 0.5

    def _probability_word_in_language(self, word):
        """DyNet expression for the probability that ``word`` is accepted.

        Returns the plain bool ``False`` (after printing a message) when ``word``
        contains characters outside the alphabet; callers must check for that.
        """
        # verification, could get rid of
        if not self._word_is_over_input_alphabet(word):
            print("word is not over input alphabet")
            return False
        return self._state_accept_probability(self._final_state(word))

    def classify_word(self, word):
        """Return True iff the classifier accepts ``word``.

        Words containing characters outside the alphabet are rejected (False)
        rather than crashing.
        """
        probability = self._probability_word_in_language(word)
        if probability is False:
            # out-of-alphabet word: already reported, it cannot be in the language
            return False
        return probability.value() > 0.5

    def loss_on_word(self, word, label):
        """Negative log-likelihood of ``label`` for ``word``, as a DyNet expression.

        Computed with ``dy.pickneglogsoftmax`` on the raw class scores. This is
        mathematically equal to ``-log(softmax(scores)[label])`` but stays finite
        when that probability underflows to 0.

        :param word: word over the classifier's alphabet.
        :param label: expected classification; compared with ``== True`` exactly
            as before, so anything not equal to True counts as False.
        :raises ValueError: if ``word`` contains characters outside the alphabet.
        """
        if not self._word_is_over_input_alphabet(word):
            raise ValueError("word is not over input alphabet: %r" % (word,))
        target = self.class2int[True] if label == True else self.class2int[False]
        return dy.pickneglogsoftmax(self._state_scores(self._final_state(word)), target)

    def train_batch(self, word_dict, trainer):
        """Run one update step on a batch.

        Starts a fresh computation graph, sums the per-word losses, backpropagates,
        and applies ``trainer.update()``.

        :param word_dict: non-empty mapping ``word -> label``.
        :param trainer: a DyNet trainer over ``self.pc``.
        :returns: mean loss per word in the batch (measured before the update).
        """
        self.renew()
        losses = [self.loss_on_word(w, word_dict[w]) for w in word_dict]
        loss = dy.esum(losses)
        loss_value = loss.value() / len(word_dict)
        loss.backward()
        trainer.update()
        return loss_value

    def show_all_losses(self):
        """Scatter-plot every per-epoch loss recorded since construction."""
        plt.scatter(range(len(self.all_losses)), self.all_losses, label="classification loss since initiation")
        plt.legend()
        plt.show()

    def train_group(self, word_dict, iterations, trainer_class=dy.AdamTrainer, learning_rate=None, loss_every=100,
                    batch_size=20, show=True, print_time=True, stop_threshold=0):
        """Train on ``word_dict`` for up to ``iterations`` epochs of shuffled mini-batches.

        :param word_dict: mapping ``word -> label``.
        :param iterations: maximum number of epochs.
        :param trainer_class: DyNet trainer class, constructed with ``self.pc``.
        :param learning_rate: if not None, assigned to ``trainer.learning_rate``.
        :param loss_every: print the average loss every this many epochs
            (0 or None disables printing).
        :param batch_size: words per batch; None means a single batch of everything.
        :param show: plot this call's losses and then all losses at the end.
        :param print_time: print the wall-clock training time at the end.
        :param stop_threshold: stop early once an epoch's average loss drops below it.
        :returns: ``self.finish_signal`` if the last epoch's loss is below
            ``stop_threshold``, else ``self.keep_going``. Returns None without
            training when ``iterations <= 0`` or ``word_dict`` is empty.
        """
        if iterations <= 0 or not word_dict:
            return
        start = perf_counter()
        trainer = trainer_class(self.pc)
        if learning_rate is not None:
            trainer.learning_rate = learning_rate
        if batch_size is None:
            batch_size = len(word_dict)  # leave None to define one huge batch
        words = list(word_dict.keys())
        num_batches = int(ceil(len(words) / batch_size))
        loss_values = []
        for i in range(iterations):
            random.shuffle(words)
            epoch_loss_total = 0.0
            for j in range(num_batches):
                batch = words[j * batch_size:(j + 1) * batch_size]
                batch_mean = self.train_batch({w: word_dict[w] for w in batch}, trainer)
                # weight by batch size so a short last batch does not skew the epoch average
                epoch_loss_total += batch_mean * len(batch)
            # parameters change during the epoch, so this is only an indicative average
            loss_values.append(epoch_loss_total / len(words))
            if loss_values[-1] < stop_threshold:
                break
            if loss_every and (i + 1) % loss_every == 0:
                print("current average loss is: ", loss_values[-1])
        self.all_losses += loss_values
        if print_time:
            print("total time:", perf_counter() - start)
        if show:
            plt.scatter(range(len(loss_values)), loss_values, label="classification loss for these epochs")
            plt.legend()
            plt.show()
            self.show_all_losses()
        return self.finish_signal if loss_values[-1] < stop_threshold else self.keep_going