-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
77 lines (57 loc) · 2.54 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from quantumbrain.graph import graph
from quantumbrain.serialize import serialize, unserialize
from quantumbrain import debug
class Model:
    """Runs a forward pass over the layers registered in the module-level ``graph``.

    The model only stores its entry layer (``inputs``) and its exit layer
    (``outputs``); the connectivity between layers comes from each layer's
    ``previous`` / ``next`` lists and the shared ``graph`` object.
    """

    def __init__(self, inputs, outputs, name=None):
        """Create a model.

        Args:
            inputs: the layer the traversal starts from; it receives the raw
                input passed to ``call`` when it has no ``previous`` layers.
            outputs: the layer whose ``out`` attribute ``call`` returns.
            name: display name used by ``summary``; defaults to ``"model"``.
        """
        self.inputs = inputs
        self.outputs = outputs
        # Copied onto every layer in ``graph`` at the start of each ``call``.
        self.trainable = False
        self.name = "model" if name is None else name

    def __call__(self, *args, **kwargs):
        """Run ``call`` on the first positional argument.

        Extra positional arguments and all keyword arguments are ignored.

        Raises:
            TypeError: if no positional argument is given.
        """
        if not args:
            raise TypeError(
                "Model.__call__() requires the input as its first positional argument"
            )
        return self.call(args[0])

    def call(self, x):
        """Forward ``x`` through the graph and return ``self.outputs.out``.

        Every layer's ``forward_degree`` counter is reset from
        ``forward_degree_origin``. Layers are then executed starting from
        ``self.inputs``; a downstream layer becomes ready once its counter
        has been decremented to zero by its upstream layers. This is a
        Kahn-style topological walk, using a stack for the ready set.
        """
        for layer in graph.layers.values():
            layer.forward_degree = layer.forward_degree_origin
            layer.trainable = self.trainable

        ready = [self.inputs]
        while ready:
            layer = ready.pop()
            self.__run_forward(layer, x)
            for next_layer in layer.next:
                next_layer.forward_degree -= 1
                # Schedule only after every upstream layer has produced output.
                if next_layer.forward_degree == 0:
                    ready.append(next_layer)
        return self.outputs.out

    def __run_forward(self, layer, x):
        """Execute one layer with the inputs appropriate to its fan-in.

        - no previous layers: the model input ``x`` is passed as-is;
        - one previous layer: that layer's ``out`` is passed;
        - several previous layers: a list of their ``out`` values is passed.

        When ``debug.debug_mode`` is set, a dump is emitted after non-input
        layers only (input layers are not dumped).
        """
        previous = layer.previous
        if len(previous) == 0:
            layer.run_forward(x)
            return

        if len(previous) == 1:
            layer.run_forward(previous[0].out)
        else:
            layer.run_forward([item.out for item in previous])

        if debug.debug_mode:
            debug.dump("{}.forward()".format(layer.name))

    def summary(self):
        """Print a table of every layer in ``graph`` with its ``shape``.

        Rows are separated by a dashed line. An empty graph prints only the
        header and footer instead of raising ``IndexError``.
        """
        print("Model: \"{}\"".format(self.name))
        print("----------------------------------------------------------")
        print("{:30}\t\t{:30}".format("Layer(type)", "Output Shape"))
        print("==========================================================")
        for index, layer in enumerate(graph.layers.values()):
            if index > 0:
                print("---------------------------------------------------------")
            layer_name_col = "{}({})".format(layer.name, layer.__class__.__name__)
            print("{:30}\t\t{}".format(layer_name_col, str(layer.shape)))
        print("==========================================================")

    def save(self, path):
        """Serialize the shared ``graph.params`` and ``graph.grads`` to ``path``."""
        serialize(path, graph.params, graph.grads)

    def restore(self, path):
        """Load ``(params, grads)`` from ``path`` into the shared ``graph``."""
        graph.params, graph.grads = unserialize(path)