Skip to content

Commit

Permalink
Action_tcn, move num_frames to the model class
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Sep 25, 2024
1 parent 1e8ad35 commit fcbe1ff
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/ai85net-actiontcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
self.num_classes = num_classes
self.cnn_out_shape = (1, 1)
self.cnn_out_channel = 32
self.num_frames = 15
num_filters = 64
len_frame_vector = self.cnn_out_shape[0]*self.cnn_out_shape[1]*self.cnn_out_channel

Expand Down Expand Up @@ -113,15 +114,14 @@ def create_cnn(self, x):
def forward(self, x):
"""Forward prop"""
batch_size = x.shape[0]
num_frames = x.shape[1]
cnnoutputs = torch.zeros_like(x)
cnnoutputs = cnnoutputs[:, :, :self.cnn_out_channel, :self.cnn_out_shape[0],
:self.cnn_out_shape[1]]

for i in range(15):
for i in range(self.num_frames):
prep_out = self.create_prep(x[:, i])
cnnoutputs = assign_cnnoutputs(cnnoutputs, i, self.create_cnn(prep_out))
tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, num_frames, -1) \
tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, self.num_frames, -1) \
.permute(0, 2, 1)
tcn_output = self.tcn0(tcn_input)
tcn_output = self.tcn1(tcn_output)
Expand Down

0 comments on commit fcbe1ff

Please sign in to comment.