model.py (forked from nachiket92/conv-social-pooling)
from __future__ import division
import torch
import torch.nn as nn
from utils import outputActivation


class highwayNet(nn.Module):

    ## Initialization
    def __init__(self, args):
        super(highwayNet, self).__init__()

        ## Unpack arguments
        self.args = args

        ## Use gpu flag
        self.use_cuda = args['use_cuda']

        # Flag for maneuver-based (True) vs uni-modal decoder (False)
        self.use_maneuvers = args['use_maneuvers']

        # Flag for train mode (True) vs test mode (False)
        self.train_flag = args['train_flag']

        ## Sizes of network layers
        self.encoder_size = args['encoder_size']
        self.decoder_size = args['decoder_size']
        self.in_length = args['in_length']
        self.out_length = args['out_length']
        self.grid_size = args['grid_size']
        self.soc_conv_depth = args['soc_conv_depth']
        self.conv_3x1_depth = args['conv_3x1_depth']
        self.dyn_embedding_size = args['dyn_embedding_size']
        self.input_embedding_size = args['input_embedding_size']
        self.num_lat_classes = args['num_lat_classes']
        self.num_lon_classes = args['num_lon_classes']
        self.soc_embedding_size = (((args['grid_size'][0] - 4) + 1) // 2) * self.conv_3x1_depth
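        # The size computed above is the pooled spatial extent of the social
        # grid times the final conv depth; e.g. with the values used in the
        # original repo (grid_size = (13, 3), conv_3x1_depth = 16) it is
        # ((13 - 4 + 1) // 2) * 16 = 5 * 16 = 80.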

        ## Define network weights

        # Input embedding layer
        self.ip_emb = torch.nn.Linear(2, self.input_embedding_size)

        # Encoder LSTM
        self.enc_lstm = torch.nn.LSTM(self.input_embedding_size, self.encoder_size, 1)

        # Vehicle dynamics embedding
        self.dyn_emb = torch.nn.Linear(self.encoder_size, self.dyn_embedding_size)

        # Convolutional social pooling layer and social embedding layer
        self.soc_conv = torch.nn.Conv2d(self.encoder_size, self.soc_conv_depth, 3)
        self.conv_3x1 = torch.nn.Conv2d(self.soc_conv_depth, self.conv_3x1_depth, (3, 1))
        self.soc_maxpool = torch.nn.MaxPool2d((2, 1), padding=(1, 0))
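        # Spatial shape trace for a 13 x 3 grid: the 3x3 conv maps (13, 3) to
        # (11, 1), the (3, 1) conv maps (11, 1) to (9, 1), and the (2, 1)
        # max-pool with padding (1, 0) maps (9, 1) to (5, 1), matching the
        # soc_embedding_size computed above.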

        # FC social pooling layer (for comparison):
        # self.soc_fc = torch.nn.Linear(self.soc_conv_depth * self.grid_size[0] * self.grid_size[1], (((args['grid_size'][0] - 4) + 1) // 2) * self.conv_3x1_depth)

        # Decoder LSTM
        if self.use_maneuvers:
            self.dec_lstm = torch.nn.LSTM(self.soc_embedding_size + self.dyn_embedding_size + self.num_lat_classes + self.num_lon_classes, self.decoder_size)
        else:
            self.dec_lstm = torch.nn.LSTM(self.soc_embedding_size + self.dyn_embedding_size, self.decoder_size)

        # Output layers:
        self.op = torch.nn.Linear(self.decoder_size, 5)
        self.op_lat = torch.nn.Linear(self.soc_embedding_size + self.dyn_embedding_size, self.num_lat_classes)
        self.op_lon = torch.nn.Linear(self.soc_embedding_size + self.dyn_embedding_size, self.num_lon_classes)

        # Activations:
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    ## Forward Pass
    def forward(self, hist, nbrs, masks, lat_enc, lon_enc):

        ## Forward pass hist:
        _, (hist_enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
        hist_enc = self.leaky_relu(self.dyn_emb(hist_enc.view(hist_enc.shape[1], hist_enc.shape[2])))

        ## Forward pass nbrs
        _, (nbrs_enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs)))
        nbrs_enc = nbrs_enc.view(nbrs_enc.shape[1], nbrs_enc.shape[2])

        ## Masked scatter
        soc_enc = torch.zeros_like(masks).float()
        soc_enc = soc_enc.masked_scatter_(masks, nbrs_enc)
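        # At this point soc_enc is a (batch, grid_size[1], grid_size[0],
        # encoder_size) tensor in which each occupied grid cell holds the LSTM
        # encoding of the corresponding neighbour and empty cells are zero.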
        soc_enc = soc_enc.permute(0, 3, 2, 1)

        ## Apply convolutional social pooling:
        soc_enc = self.soc_maxpool(self.leaky_relu(self.conv_3x1(self.leaky_relu(self.soc_conv(soc_enc)))))
        soc_enc = soc_enc.view(-1, self.soc_embedding_size)

        ## Apply fc soc pooling
        # soc_enc = soc_enc.contiguous()
        # soc_enc = soc_enc.view(-1, self.soc_conv_depth * self.grid_size[0] * self.grid_size[1])
        # soc_enc = self.leaky_relu(self.soc_fc(soc_enc))

        ## Concatenate encodings:
        enc = torch.cat((soc_enc, hist_enc), 1)

        if self.use_maneuvers:
            ## Maneuver recognition:
            lat_pred = self.softmax(self.op_lat(enc))
            lon_pred = self.softmax(self.op_lon(enc))

            if self.train_flag:
                ## Concatenate maneuver encoding of the true maneuver
                enc = torch.cat((enc, lat_enc, lon_enc), 1)
                fut_pred = self.decode(enc)
                return fut_pred, lat_pred, lon_pred
            else:
                fut_pred = []
                ## Predict trajectory distributions for each maneuver class
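                # fut_pred[k * num_lat_classes + l] holds the prediction
                # conditioned on longitudinal class k and lateral class l,
                # each supplied to the decoder as a one-hot encoding.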
                for k in range(self.num_lon_classes):
                    for l in range(self.num_lat_classes):
                        lat_enc_tmp = torch.zeros_like(lat_enc)
                        lon_enc_tmp = torch.zeros_like(lon_enc)
                        lat_enc_tmp[:, l] = 1
                        lon_enc_tmp[:, k] = 1
                        enc_tmp = torch.cat((enc, lat_enc_tmp, lon_enc_tmp), 1)
                        fut_pred.append(self.decode(enc_tmp))
                return fut_pred, lat_pred, lon_pred
        else:
            fut_pred = self.decode(enc)
            return fut_pred

    def decode(self, enc):
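        # The context vector is repeated out_length times and fed to the
        # decoder LSTM as its input at every prediction time step.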
        enc = enc.repeat(self.out_length, 1, 1)
        h_dec, _ = self.dec_lstm(enc)
        h_dec = h_dec.permute(1, 0, 2)
        fut_pred = self.op(h_dec)
        fut_pred = fut_pred.permute(1, 0, 2)
        fut_pred = outputActivation(fut_pred)
        return fut_pred
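

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file). The argument values
    # below are assumptions mirroring the defaults in the original repo's
    # train.py; running it requires the repo's utils.outputActivation.
    args = {'use_cuda': False, 'use_maneuvers': True, 'train_flag': True,
            'encoder_size': 64, 'decoder_size': 128,
            'in_length': 16, 'out_length': 25, 'grid_size': (13, 3),
            'soc_conv_depth': 64, 'conv_3x1_depth': 16,
            'dyn_embedding_size': 32, 'input_embedding_size': 32,
            'num_lat_classes': 3, 'num_lon_classes': 2}
    net = highwayNet(args)

    batch, n_nbrs = 4, 7
    hist = torch.randn(args['in_length'], batch, 2)
    nbrs = torch.randn(args['in_length'], n_nbrs, 2)

    # Occupancy mask: mark n_nbrs grid cells as occupied; masked_scatter_
    # consumes one neighbour encoding per occupied cell.
    masks = torch.zeros(batch, args['grid_size'][1], args['grid_size'][0],
                        args['encoder_size'], dtype=torch.bool)
    masks[0, 0, :n_nbrs, :] = True

    # One-hot encodings standing in for the true maneuver classes.
    lat_enc = torch.zeros(batch, args['num_lat_classes'])
    lat_enc[:, 0] = 1
    lon_enc = torch.zeros(batch, args['num_lon_classes'])
    lon_enc[:, 0] = 1

    fut_pred, lat_pred, lon_pred = net(hist, nbrs, masks, lat_enc, lon_enc)
    print(fut_pred.shape)   # expected: (out_length, batch, 5)
    print(lat_pred.shape, lon_pred.shape)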