-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBayesianNode.py
295 lines (231 loc) · 8.16 KB
/
BayesianNode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
from PyGameHelpers import *
from Misc import errorPrint as er, debugPrint as db, normalPrint as pr
from Misc import *
class BayesianNode:
    """A node (random variable) of a Bayesian network, drawable with pygame.

    A node holds its possible values, its parent nodes, the Connection objects
    that link it to other nodes and, once finalized, a conditional probability
    table (CPT). The CPT layout produced by finalize() is:

    * node without parents: ``CPT[i]`` is a single ``(conditions, probability)``
      tuple for ``values[i]``, where ``conditions == [[None, None]]``;
    * node with parents: ``CPT[i]`` is a list of ``(conditions, probability)``
      tuples, one per combination of parent values, where ``conditions`` is a
      list of ``[parentNode, parentValue]`` pairs.

    Most mutators refuse to run (reporting through ``er``) once the node has
    been finalized; call definalize() first to change the structure again.
    """

    # Surfaces/rects shared by all nodes, populated by loadImages().
    image = None
    imageRect = None
    imageHighlighted = None
    imageHighlightedRect = None
    imageConnection = None
    imageConnectionRect = None

    @staticmethod
    def loadImages(setSize=None):
        """Load the node images shared by every instance.

        ``setSize`` is accepted for compatibility but currently unused.
        """
        BayesianNode.image, BayesianNode.imageRect = loadImage("circle.png")
        BayesianNode.imageHighlighted, BayesianNode.imageHighlightedRect = loadImage("circle_highlighted.png")
        #BayesianNode.imageConnection, BayesianNode.imageConnectionRect = loadImage("arrow.png")

    def __init__(self, network=None):
        """Create an empty, unfinalized, observable node named "DefaultNode"."""
        self.network = network
        self.parents = []
        self.connections = []
        self.values = []
        self.CPT = []
        self.finalized = False
        self.position = [0, 0]
        self.size = 0
        self.observable = True
        self.name = "DefaultNode"
        # Stays None until the first value is added (see addValue).
        self.currentValue = None

    def setName(self, name):
        self.name = name

    def getName(self):
        return self.name

    def setNetwork(self, network):
        """Attach the node to ``network``; refused once finalized."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform setNetwork.")
            return
        self.network = network

    def addValue(self, value):
        """Append a possible value; the first value added becomes the current one."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform addValue.")
            return
        if self.currentValue is None:
            self.currentValue = value
        self.values.append(value)

    def getValues(self):
        return self.values

    def getCurrentValue(self):
        return self.currentValue

    def setCurrentValue(self, value):
        """Set the node's current value.

        On a finalized node the CPT is also rewritten so that every entry of
        the row for ``value`` has probability 1 and every other entry 0. This
        overwrites the user-supplied probabilities.
        """
        if value not in self.values:
            er("Value does not exist in node")
            return
        self.currentValue = value
        if self.finalized:
            selectedRow = self.values.index(value)
            for rowIndex, row in enumerate(self.CPT):
                probability = 1 if rowIndex == selectedRow else 0
                # CPT entries are immutable tuples, so rebuild them instead of
                # assigning into them (which raised TypeError/IndexError).
                if isinstance(row, tuple):
                    self.CPT[rowIndex] = (row[0], probability)
                else:
                    row[:] = [(conditions, probability) for conditions, _ in row]

    def addParent(self, parent):
        """Register ``parent`` and link both nodes through one shared Connection."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform addParent.")
            return
        self.parents.append(parent)
        connection = Connection(self, parent)
        self.addConnection(connection)
        parent.addConnection(connection)

    def addParents(self, parents):
        """Add every node in ``parents`` via addParent()."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform addParents.")
            return
        # Iterate the argument (the original looped over self.parents and so
        # never added anything).
        for parent in parents:
            self.addParent(parent)

    def removeParent(self, parent):
        """Remove ``parent`` and the connection to it; refused once finalized."""
        if self.finalized:
            # Checked here too so parents and connections cannot get out of sync.
            er("[Node]: Node is already finalized. Can not perform removeParent.")
            return
        if parent in self.parents:
            self.parents.remove(parent)
            self.removeConnectionTo(parent)

    def getParents(self):
        return self.parents

    def finalize(self):
        """Freeze the node's structure and build its CPT from console input.

        Root nodes: an observable root gets a uniform distribution, while an
        unobservable root prompts for one probability per value. Nodes with
        parents prompt for every (value, parent-value combination) cell.
        """
        if self.finalized:
            er("[Node]: Node is already finalized. Please call definalize if you want to refinalize.")
            return
        if self.network is None:
            er("[Node]: Can't finalize a node which is not a part of a network, please call \"setNetwork()\"")
            return
        self.finalized = True
        if not self.parents:
            if not self.observable:
                self.CPT = []
                for value in self.values:
                    pr("[Node]: Please specify probability for " + self.getName() + " = " + str(value))
                    ans = parseInputToNumber(input("Answer: "))
                    self.CPT.append(([[None, None]], ans))
            else:
                # The division sits inside the comprehension, so an empty
                # value list produces an empty CPT instead of dividing by zero.
                self.CPT = [([[None, None]], 1 / len(self.values)) for _ in self.values]
            return
        # One column per combination of parent values.
        valueLists = [[parent, parent.getValues()] for parent in self.parents]
        totalLength = 1
        for _, valueSet in valueLists:
            totalLength *= len(valueSet)
        pr("[Node]: Finalizing node with a CPT of total size " + str(totalLength*len(self.values)))
        self.CPT = [[0] * totalLength for _ in self.values]
        # NOTE(review): the messages below promise automatic normalization, but
        # no normalization is visible in this class — confirm where it happens.
        if self.observable:
            pr("[Node]: This node is observable. You should therefore make sure only one variable per row/column is set.\n[Node]: The CPT will be normalized automatically.")
        else:
            pr("[Node]: This node is not observable. All column/row-combinations can contain information.\n[Node]: The CPT will be normalized automatically")
        for i, value in enumerate(self.values):
            self.getListPossibleValues(self.CPT, i, 0, [self, value], [], list(valueLists))

    def getListPossibleValues(self, cpt, i, j, rootTuple, currentSetValues, remainingList):
        """Recursively prompt for every parent-value combination of CPT row ``i``.

        Parameters:
            cpt: the CPT being filled; ``cpt[i][j]`` receives
                ``(conditions, probability)``.
            i: row index (index of ``rootTuple[1]`` in the node's values).
            j: first column index this call may fill.
            rootTuple: ``[node, value]`` whose probability is being asked for.
            currentSetValues: ``[parent, value]`` pairs fixed by outer calls;
                restored before returning.
            remainingList: ``[parent, values]`` pairs still to enumerate;
                consumed (popped) by this call.

        Returns:
            The next free column index, so sibling branches do not overwrite
            each other's cells (previously only +1 was added per branch).
        """
        if not remainingList:
            pr("[Node]: Please specify probability of " + str(rootTuple[0].getName()).upper() + " = " + str(rootTuple[1]).upper() + " given: ")
            for parent, value in currentSetValues:
                pr(str(parent.getName()) + " = " + str(value))
            ans = parseInputToNumber(input("Answer: "))
            cpt[i][j] = (list(currentSetValues), ans)
            return j + 1
        parent, values = remainingList.pop()
        for value in values:
            currentSetValues.append([parent, value])
            j = self.getListPossibleValues(cpt, i, j, rootTuple, currentSetValues, list(remainingList))
            currentSetValues.pop()
        return j

    def definalize(self):
        self.finalized = False

    def setObservable(self, boolean):
        """Set whether the node is observable.

        Making a node observable is refused when any parent is unobservable.
        Making it unobservable is always allowed (unless finalized).
        """
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform setObservable.")
            return
        if boolean and any(not parent.observable for parent in self.parents):
            er("[Node]: Cannot make node observable if parent is unobservable "+
               "(this does not make any logical sense)")
            return
        self.observable = boolean

    @staticmethod
    def _rowEntries(row):
        """Return the (conditions, probability) entries of one CPT row as a list.

        Root nodes store a row as a single tuple, while nodes with parents
        store a list of tuples.
        """
        return [row] if isinstance(row, tuple) else row

    def getProbabilityOfValue(self, value):
        """Return the probability that this node takes ``value``.

        Observable nodes return 1 if ``value`` is the current value, else 0.
        Unobservable nodes sum, over the row's CPT entries, the entry
        probability times the probability of each conditioning parent value.
        Unknown values are reported and yield 0.
        """
        if value not in self.values:
            er("[Node]: Value is not in self.values")
            return 0
        if self.observable:
            return 1 if value == self.currentValue else 0
        row = self.values.index(value)
        probability = 0
        for conditions, nodeProb in self._rowEntries(self.CPT[row]):
            for parent, parentValue in conditions:
                # Root entries use [None, None] as a "no condition" marker.
                if parent is not None:
                    nodeProb *= parent.getProbabilityOfValue(parentValue)
            probability += nodeProb
        return probability

    def getProbabilityOfTuple(self, value, parent=None):
        """Return the probability of the first CPT entry conditioned on ``parent = value``.

        ``parent`` may be a BayesianNode or a parent's name. Rows are scanned in
        order, so the match comes from the earliest row containing that
        condition. Returns None when nothing matches (including when
        ``parent`` is omitted).
        """
        for row in self.CPT:
            for conditions, probability in self._rowEntries(row):
                for conditionParent, conditionValue in conditions:
                    if conditionParent is None or conditionValue != value:
                        continue
                    if isinstance(parent, str) and conditionParent.getName() == parent:
                        return probability
                    if isinstance(parent, BayesianNode) and conditionParent == parent:
                        return probability
        return None

    def setPosition(self, position):
        self.position = position

    def getPosition(self):
        return self.position

    def setSize(self, size):
        self.size = size

    def getRect(self):
        """Return a size-by-size pygame.Rect whose top-left is position minus size."""
        # NOTE(review): if size is meant as a radius, width/height may need to
        # be 2 * size — confirm against the drawing code.
        return pygame.Rect(self.position[0] - self.size, self.position[1] - self.size, self.size, self.size)

    def addConnection(self, connection):
        """Record ``connection``; refused once finalized."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform addConnection.")
            return
        self.connections.append(connection)

    def removeConnection(self, connection):
        """Forget ``connection``; when called on the child, also remove it from the parent."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform removeConnection.")
            return
        if connection in self.connections:
            parent = connection.getParent()
            if parent != self:
                parent.removeConnection(connection)
            self.connections.remove(connection)

    def removeConnectionTo(self, parent):
        """Remove every connection whose parent is ``parent``, on both nodes."""
        if self.finalized:
            er("[Node]: Node is already finalized. Can not perform removeConnectionTo.")
            return
        # Iterate a snapshot: removing from the list being iterated skips items.
        for connection in list(self.connections):
            if parent == connection.getParent():
                self.connections.remove(connection)
                parent.removeConnection(connection)

    def getConnections(self):
        return self.connections
class Connection:
    """Directed edge from a parent node to a child node, with a bounding box.

    The bounding box (``size`` and ``rect``) is computed once, at construction,
    from the two nodes' positions at that moment.
    """

    def __init__(self, child, parent):
        self.child = child
        self.parent = parent
        # Endpoints reference the nodes' position lists (parent first).
        self.endPoints = [self.parent.getPosition(), self.child.getPosition()]
        minX, minY, width, height = Connection._boundingBox(self.endPoints[0], self.endPoints[1])
        # Width/height are per-axis distances between the endpoints (the
        # original mixed the x and y of the same point).
        self.size = [width, height]
        self.rect = pygame.Rect(minX, minY, width, height)

    @staticmethod
    def _boundingBox(start, end):
        """Return (minX, minY, width, height) of the box spanned by two [x, y] points."""
        return (min(start[0], end[0]),
                min(start[1], end[1]),
                abs(start[0] - end[0]),
                abs(start[1] - end[1]))

    def getRect(self):
        return self.rect

    def getParent(self):
        return self.parent

    def getChild(self):
        return self.child

    def getNodes(self):
        """Return ``[parent, child]``."""
        return [self.parent, self.child]