From fcbe1ff4662a5eeb1c04eb1132a975e763dcaca4 Mon Sep 17 00:00:00 2001 From: Oguzhan Buyuksolak Date: Wed, 25 Sep 2024 15:55:59 +0300 Subject: [PATCH] Action_tcn, move num_frames to the model class --- models/ai85net-actiontcn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/ai85net-actiontcn.py b/models/ai85net-actiontcn.py index 5ec5225e5..65947e5ee 100644 --- a/models/ai85net-actiontcn.py +++ b/models/ai85net-actiontcn.py @@ -33,6 +33,7 @@ def __init__( self.num_classes = num_classes self.cnn_out_shape = (1, 1) self.cnn_out_channel = 32 + self.num_frames = 15 num_filters = 64 len_frame_vector = self.cnn_out_shape[0]*self.cnn_out_shape[1]*self.cnn_out_channel @@ -113,15 +114,14 @@ def create_cnn(self, x): def forward(self, x): """Forward prop""" batch_size = x.shape[0] - num_frames = x.shape[1] cnnoutputs = torch.zeros_like(x) cnnoutputs = cnnoutputs[:, :, :self.cnn_out_channel, :self.cnn_out_shape[0], :self.cnn_out_shape[1]] - for i in range(15): + for i in range(self.num_frames): prep_out = self.create_prep(x[:, i]) cnnoutputs = assign_cnnoutputs(cnnoutputs, i, self.create_cnn(prep_out)) - tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, num_frames, -1) \ + tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, self.num_frames, -1) \ .permute(0, 2, 1) tcn_output = self.tcn0(tcn_input) tcn_output = self.tcn1(tcn_output)