main.py
from __future__ import print_function
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from ST_LSTM import SpatioTemporal_LSTM

use_cuda = torch.cuda.is_available()


class Encoder(nn.Module):
    """Stacked spatiotemporal LSTM encoder.

    H, C, and M hold the per-layer hidden, cell, and spatiotemporal memory
    states; h, c, and m are learned initial states used at the first
    timestep.
    """
    def __init__(self, n_layers, hidden_sizes, input_sizes, shape):
        super(Encoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_sizes = hidden_sizes
        self.input_sizes = input_sizes
        # Running per-layer states, keyed by layer index.
        self.M = dict()
        self.C = dict()
        self.H = dict()
        # Learned initial states. ParameterList (rather than a plain dict)
        # registers the parameters with the module so they are visible to
        # self.parameters() and the optimizer.
        self.m = Parameter(torch.Tensor(*shape))
        self.h = nn.ParameterList(
            [Parameter(torch.Tensor(*shape)) for _ in range(self.n_layers)])
        self.c = nn.ParameterList(
            [Parameter(torch.Tensor(*shape)) for _ in range(self.n_layers)])
        self.cells = nn.ModuleList([])
        for i in range(self.n_layers):
            cell = SpatioTemporal_LSTM(self.hidden_sizes[i], self.input_sizes[i])
            self.cells.append(cell)
        self._reset_parameters()

    def _reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_sizes[0])
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input_, first_timestep=False):
        for j, cell in enumerate(self.cells):
            if first_timestep:
                if j == 0:
                    # Bottom layer: consume the input frame; the learned
                    # initial states seed h, c, and the memory m.
                    self.H[j], self.C[j], self.M[j] = cell(
                        input_, (self.h[j], self.c[j], self.m))
                else:
                    # Upper layers: hidden state comes from the layer below,
                    # and M is passed up through the stack.
                    self.H[j], self.C[j], self.M[j] = cell(
                        self.H[j - 1], (self.h[j], self.c[j], self.M[j - 1]))
                continue
            if j == 0:
                # Zigzag: the bottom layer receives M from the top layer of
                # the previous timestep.
                self.H[j], self.C[j], self.M[j] = cell(
                    input_, (self.H[j], self.C[j], self.M[self.n_layers - 1]))
            else:
                self.H[j], self.C[j], self.M[j] = cell(
                    self.H[j - 1], (self.H[j], self.C[j], self.M[j - 1]))
        return self.H, self.C, self.M

    def initHidden(self):
        result = torch.zeros(1, 1, self.hidden_sizes[0])  #################SHAPE
        if use_cuda:
            return result.cuda()
        return result
class Decoder(nn.Module):
    """Stacked spatiotemporal LSTM decoder.

    Uses the spatiotemporal memory M in zigzag fashion, as suggested in the
    Spatiotemporal LSTM (PredRNN) paper: the bottom layer receives M from
    the top layer of the previous timestep.
    """
    def __init__(self, n_layers, hidden_sizes, input_sizes):
        super(Decoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_sizes = hidden_sizes
        self.input_sizes = input_sizes
        self.cells = nn.ModuleList([])
        for i in range(self.n_layers):
            cell = SpatioTemporal_LSTM(self.hidden_sizes[i], self.input_sizes[i])
            self.cells.append(cell)

    def forward(self, input_, C, H, M):
        for j, cell in enumerate(self.cells):
            if j == 0:
                # Bottom layer: zigzag M from the top layer.
                H[j], C[j], M[j] = cell(input_, (H[j], C[j], M[self.n_layers - 1]))
            else:
                # Upper layers: hidden state from the layer below, M passed
                # up through the stack.
                H[j], C[j], M[j] = cell(H[j - 1], (H[j], C[j], M[j - 1]))
        # The prediction is the hidden state of the top layer.
        output = H[self.n_layers - 1]
        return output

    def initHidden(self):
        result = torch.zeros(1, 1, self.hidden_sizes[0])
        if use_cuda:
            return result.cuda()
        return result
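
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch of how the encoder/decoder pair is wired over
# a frame sequence, not part of the original file. It assumes that
# SpatioTemporal_LSTM(hidden_size, input_size) takes an input tensor plus a
# (h, c, m) state tuple and returns (H, C, M) tensors shaped like the states;
# the layer counts, sizes, and state shape below are illustrative
# placeholders only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_layers = 4
    hidden_sizes = [64] * n_layers    # assumed per-layer hidden sizes
    input_sizes = [64] * n_layers     # assumed per-layer input sizes
    shape = (1, 64)                   # assumed shape of each state tensor

    encoder = Encoder(n_layers, hidden_sizes, input_sizes, shape)
    decoder = Decoder(n_layers, hidden_sizes, input_sizes)

    # Encode a dummy sequence; the learned initial states are used only at
    # the first timestep, after which M zigzags between timesteps.
    frames = [torch.randn(*shape) for _ in range(10)]
    for t, frame in enumerate(frames):
        H, C, M = encoder(frame, first_timestep=(t == 0))

    # Decode one step from the final encoder states.
    prediction = decoder(frames[-1], C, H, M)
    print(prediction.shape)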