-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathClassifierKeras.py
35 lines (30 loc) · 991 Bytes
/
ClassifierKeras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
class ClassifierKeras:
    """A symbol classifier built on a Keras CNN.

    Wraps a trained network that emits one activation per entry in
    ``symbols`` and maps those activations back to symbol labels.
    """

    # Shape of a single network input (height, width, channels); the batch
    # axis is prepended at prediction time.
    input_shape = (45, 45, 1)
    # If the best activation is below this value the symbol is reported as "?".
    threshold = 0.01

    def __init__(self, symbols, network, input_shape=(45, 45, 1),
                 threshold=0.01):
        """Create a classifier.

        Args:
            symbols: Sequence of labels; ``symbols[i]`` is the label for the
                network's i-th output activation.
            network: Object with a Keras-style ``predict(batch)`` method.
            input_shape: Shape of one input image without the batch axis.
                Defaults to the 45x45 single-channel layout.
            threshold: Minimum best activation required to report a symbol
                instead of ``"?"``.
        """
        self.symbols = symbols
        self.network = network
        self.input_shape = tuple(input_shape)
        self.threshold = threshold

    def _activation(self, char):
        """Run the network on ``char.image`` and return a flat activation vector."""
        # Divide by 255.0 (not 255) so integer images are never floor-divided
        # on Python 2. Presumably pixels are 8-bit (0-255) -- TODO confirm.
        pixels = np.asarray(char.image) / 255.0
        # predict() takes a batch, so add a leading batch axis of size 1.
        batch = np.reshape(pixels, (1,) + self.input_shape)
        activation = self.network.predict(batch)
        # One activation per symbol; reshape also fails loudly on a size mismatch.
        return np.reshape(activation, (len(self.symbols),))

    def classify(self, char):
        """Classify ``char`` and store the predicted label in ``char.symbol``.

        Args:
            char: Object with an ``image`` attribute reshapeable to
                ``input_shape`` and a writable ``symbol`` attribute.
        """
        char.symbol = self.getSymbol(self._activation(char))

    def getSymbol(self, activation):
        """Map an activation vector to the label with the highest activation.

        Ties resolve to the lowest index.

        Returns:
            ``self.symbols[i]`` for the arg-max index ``i``. Returns ``"?"``
            when ``activation`` is empty or its maximum is below
            ``self.threshold``.
        """
        if len(activation) == 0:
            return "?"
        best = int(np.argmax(activation))  # first occurrence on ties
        if activation[best] < self.threshold:
            return "?"
        return self.symbols[best]

    def getProbSortedSymbols(self, char, reverse=False):
        """Return every label ordered by the network's activation for ``char``.

        Args:
            char: Object with an ``image`` attribute (see ``classify``).
            reverse: If False (default), order ascending, least likely label
                first. If True, order descending, most likely label first.

        Returns:
            A list of labels. Labels with equal activations keep their index order.
        """
        activation = self._activation(char)
        keys = -activation if reverse else activation
        # A stable sort keeps equal activations in index order.
        order = np.argsort(keys, kind="stable")
        return [self.symbols[int(i)] for i in order]