diff --git a/neurobench/utils.py b/neurobench/utils.py index 82fb635..0b7f749 100644 --- a/neurobench/utils.py +++ b/neurobench/utils.py @@ -249,12 +249,13 @@ def single_layer_MACs(inputs, layer, total=False): # NOTE: these activation functions are currently NOT included in NeuroBench # if no explicit states are passed to recurrent layers, then h and c are initialized to zero (pytorch convention) layer_bin = make_binary_copy(layer, all_ones=total) - # transpose from batches, timesteps, features to features, batches - # print(layer_bin.weight_ih.shape) + # layer_weight_ih is [4*hidden_size, input_size] + # inputs[0].transpose(0, -1) is [input_size, batch_size] out_ih = torch.matmul( layer_bin.weight_ih, inputs[0].transpose(0, -1) ) # accounts for i,f,g,o out_hh = torch.zeros_like(out_ih) + # out shape is 4*h, batch, for hidden feature dim h biases = 0 bias_ih = 0 @@ -290,10 +291,16 @@ def single_layer_MACs(inputs, layer, total=False): out += biases # biases are added here for computation of c and h which depend on correct computation of ifgo out[out != 0] = 1 - # out is vector with i,f,g,o - ifgo = out.reshape(4, -1) # each row is one of i,f,g,o + + # out is vector with i,f,g,o, shape is 4*h, batch + hidden = out.shape[0] // 4 + ifgo = out.reshape(4, hidden, -1) # 4, h, B if in_states: - c_1 = ifgo[1, :] * inputs[1][1] + ifgo[0, :] * ifgo[2, :] + # inputs[1][1] shape is [B, h] + # element-wise multiply (vector products f*c + i*g) + c_1 = ( + ifgo[1, :] * inputs[1][1].transpose(0, -1) + ifgo[0, :] * ifgo[2, :] + ) else: c_1 = ifgo[0, :] * ifgo[2, :] @@ -323,9 +330,13 @@ def single_layer_MACs(inputs, layer, total=False): # Wir*x, Whr*h, Wiz*x, Whz*h, Win*x, Whn*h macs += rzn.sum() # multiplications of all weights and inputs/hidden states rzn += biases # add biases - rzn = rzn.reshape(3, -1) # each row is one of r,z,n - out_hh_n = out_hh.reshape(3, -1)[2, :] + bias_hh.reshape(3, -1)[2, :] + hidden = rzn.shape[0] // 3 + rzn = rzn.reshape(3, hidden, -1) # 3, h, B + 
out_hh = out_hh.reshape(3, hidden, -1) + bias_hh = bias_hh.reshape(3, hidden, -1) + + out_hh_n = out_hh[2, :] + bias_hh[2, :] r = rzn[0, :] # get r z = rzn[1, :] @@ -334,23 +345,21 @@ def single_layer_MACs(inputs, layer, total=False): n_hh_term_macs = ( r * out_hh_n - ) # elementwise_multiplication to find macs ofr*(Whn*h + bhn) specifically + ) # elementwise_multiplication to find macs of r*(Whn*h + bhn) specifically n_hh_term_macs[n_hh_term_macs != 0] = 1 macs += n_hh_term_macs.sum() # note hh part of n is already binarized, does not influence calculation of macs for n - n = ( - out_hh.reshape(3, -1)[2, :] - + bias_ih.reshape(3, -1)[2, :] - + n_hh_term_macs - ) + n = out_hh[2, :] + bias_ih[2, :] + n_hh_term_macs n[n != 0] = 1 z_a = 1 - z # only do this now because affects z_a z[z != 0] = 1 z_a[z_a != 0] = 1 t_1 = z_a * n - t_2 = z * inputs[1] + t_2 = z * inputs[1].transpose( + 0, -1 + ) # inputs are shape [B, h], all else is [h, B] t_1[t_1 != 0] = 1 t_2[t_2 != 0] = 1 diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 9964694..b0c2a21 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -25,9 +25,7 @@ class TestStaticMetrics(unittest.TestCase): - def setUp(self): - self.dummy_net = nn.Module() self.net = models.net @@ -339,11 +337,11 @@ def test_synaptic_ops(self): self.assertEqual(syn_ops["Effective_ACs"], 1000) # test lstm network - + batch_size = 2 inp = [ - torch.ones(1, 25), - (torch.ones(1, 5), torch.ones(1, 5)), - ] # input, (hidden, cell) + torch.ones(batch_size, 25), + (torch.ones(batch_size, 5), torch.ones(batch_size, 5)), + ] # input (batch_size, inp_size), (hidden, cell) inp[0][0, 0] = 4 # avoid getting classified as snn model = TorchModel(self.net_lstm) @@ -358,8 +356,11 @@ def test_synaptic_ops(self): self.assertEqual(syn_ops["Effective_ACs"], 0) # test RNN network - - inp = [torch.ones(1, 25), torch.ones(1, 5)] # input, (hidden, cell) + batch_size = 2 + inp = [ + torch.ones(batch_size, 25), + torch.ones(batch_size, 5), + ] 
# input, hidden (RNN has no cell state) inp[0][0, 0] = 4 # avoid getting classified as snn model = TorchModel(self.net_RNN) @@ -374,7 +375,11 @@ def test_synaptic_ops(self): self.assertEqual(syn_ops["Effective_ACs"], 0) # test GRU network - inp = [torch.ones(1, 25), torch.ones(1, 5)] # input, (hidden, cell) + batch_size = 2 + inp = [ + torch.ones(batch_size, 25), + torch.ones(batch_size, 5), + ] # input, hidden (GRU has no cell state) inp[0][0, 0] = 4 # avoid getting classified as snn model = TorchModel(self.net_GRU)