-
Notifications
You must be signed in to change notification settings - Fork 0
/
banditron.py
121 lines (103 loc) · 4.49 KB
/
banditron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from pymongo import MongoClient
import random
import datetime
REUTERS_CATEGORY_MAPPING = ['CCAT', 'ECAT', 'MCAT', 'GCAT']
def get_category_index(label):
for i in range(0,len(REUTERS_CATEGORY_MAPPING)):
if label == REUTERS_CATEGORY_MAPPING[i]:
return i
return -1
class Banditron:
def __init__(self):
self.gamma = 0.05
self.mongo = MongoClient('localhost',27017)['aml']['features']
self.dict = self.mongo.find({'_id':'1'})[0]['featureList']
self.weights = self.init_weights()
self.error_rate = 0.0
self.correct_classified = 0.0
self.incorrect_classified = 0.0
self.number_of_rounds = 0.0
def init_weights(self):
weights = []
for i in range(0,len(REUTERS_CATEGORY_MAPPING)):
weights.append([0.0] * len(self.dict))
return weights
def update_weights(self, update_matrix):
for i in range(0,len(self.weights)):
for j in range(0,len(self.weights[i])):
self.weights[i][j] += update_matrix[i][j]
def get_update_matrix(self, feature_vectors, calculated_label, predicted_label, true_label, probabilities):
update_matrix = self.init_weights()
for i in range(0,len(update_matrix)):
left = 0.0
right = 0.0
if true_label == predicted_label and predicted_label == i:
left = 1/probabilities[i]
if calculated_label == i:
right = 1.0
for j in range(0,len(feature_vectors)):
update_matrix[i][int(feature_vectors[j][0])] = feature_vectors[j][1] * (left - right)
return update_matrix
def run(self, doc_id, feature_vectors, true_label):
self.number_of_rounds += 1.0
calculated_label = self.predict_label(feature_vectors)
probabilities = self.calc_probabilities(calculated_label)
predicted_label = self.random_sample(probabilities)
if true_label == predicted_label:
self.correct_classified += 1.0
else:
self.incorrect_classified += 1.0
self.error_rate = self.incorrect_classified/self.number_of_rounds
update_matrix = self.get_update_matrix(feature_vectors, calculated_label, predicted_label, true_label, probabilities)
self.update_weights(update_matrix)
def predict_label(self, feature_vectors):
max = 0.0
label = 0
for i in range(0,len(self.weights)):
total = 0.0
for eachVector in feature_vectors:
total += eachVector[1]*self.weights[i][int(eachVector[0])]
if total >= max:
max = total
label = i
return label
def calc_probabilities(self, calculated_label):
probabilities = [0] * len(self.weights)
for i in range(0,len(probabilities)):
probabilities[i] = self.gamma/len(self.weights)
if i == calculated_label:
probabilities[i] += (1 - self.gamma)
return probabilities
def random_sample(self, probabilities):
number = random.random() * sum(probabilities)
for i in range(0,len(probabilities)):
if number < probabilities[i]:
return i
break
number -= probabilities[i]
return len(probabilities)-1
def main():
banditron = Banditron()
doc_ids = banditron.mongo.find({'_id': 2})[0]['dataset_list_docIds']
count = 0
error_list = list()
rounds = list()
for t in range(0,len(doc_ids)):
doc_id = doc_ids[t]
feature_vectors = banditron.mongo.find({'docId':str(doc_id)})[0]['featureList']
true_label = get_category_index(banditron.mongo.find({'docId':str(doc_id)})[0]['true_label'])
banditron.run(doc_id, feature_vectors, true_label)
if ((t+1)%1000) == 0:
print "%s rounds completed with error rate %s" %(str(t+1),str(banditron.error_rate))
rounds.append(banditron.number_of_rounds)
error_list.append(banditron.error_rate)
count += 1
if count > 100000:
break
mongo_plot = MongoClient('localhost',27017)['aml']['plots']
mongo_plot.update({'_id':'reuters_banditron'},{'$set':{'timeStamp':datetime.datetime.now(),'rounds':rounds,'error_rate':error_list}},True)
print "Correctly classified: %s" %str(banditron.correct_classified)
print "Incorrectly classified: %s" %str(banditron.incorrect_classified)
print "Error Rate: %s" %str(banditron.error_rate)
if __name__=="__main__":
main()