From 195bec9e0a201d97dd7bbb874f70519067d75eca Mon Sep 17 00:00:00 2001 From: Jason Yik Date: Wed, 17 Jul 2024 13:21:19 -0600 Subject: [PATCH] update rnncell synops calculation --- neurobench/utils.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/neurobench/utils.py b/neurobench/utils.py index 82fb635..0b7f749 100644 --- a/neurobench/utils.py +++ b/neurobench/utils.py @@ -249,12 +249,13 @@ def single_layer_MACs(inputs, layer, total=False): # NOTE: these activation functions are currently NOT included in NeuroBench # if no explicit states are passed to recurrent layers, then h and c are initialized to zero (pytorch convention) layer_bin = make_binary_copy(layer, all_ones=total) - # transpose from batches, timesteps, features to features, batches - # print(layer_bin.weight_ih.shape) + # layer_weight_ih is [4*hidden_size, input_size] + # inputs[0].transpose(0, -1) is [input_size, batch_size] out_ih = torch.matmul( layer_bin.weight_ih, inputs[0].transpose(0, -1) ) # accounts for i,f,g,o out_hh = torch.zeros_like(out_ih) + # out shape is 4*h, batch, for hidden feature dim h biases = 0 bias_ih = 0 @@ -290,10 +291,16 @@ def single_layer_MACs(inputs, layer, total=False): out += biases # biases are added here for computation of c and h which depend on correct computation of ifgo out[out != 0] = 1 - # out is vector with i,f,g,o - ifgo = out.reshape(4, -1) # each row is one of i,f,g,o + + # out is vector with i,f,g,o, shape is 4*h, batch + hidden = out.shape[0] // 4 + ifgo = out.reshape(4, hidden, -1) # 4, h, B if in_states: - c_1 = ifgo[1, :] * inputs[1][1] + ifgo[0, :] * ifgo[2, :] + # inputs[1][1] shape is [B, h] + # element-wise multiply (vector products f*c + i*g) + c_1 = ( + ifgo[1, :] * inputs[1][1].transpose(0, -1) + ifgo[0, :] * ifgo[2, :] + ) else: c_1 = ifgo[0, :] * ifgo[2, :] @@ -323,9 +330,13 @@ def single_layer_MACs(inputs, layer, total=False): # Wir*x, Whr*h, Wiz*x, Whz*h, Win*x, Whn*h macs += rzn.sum() # multiplications of all weights and inputs/hidden states rzn += biases # add biases - rzn = rzn.reshape(3, -1) # each row is one of r,z,n - out_hh_n = out_hh.reshape(3, -1)[2, :] + bias_hh.reshape(3, -1)[2, :] + hidden = rzn.shape[0] // 3 + rzn = rzn.reshape(3, hidden, -1) # 3, h, B + out_hh = out_hh.reshape(3, hidden, -1) + bias_hh = bias_hh.reshape(3, hidden, -1) + + out_hh_n = out_hh[2, :] + bias_hh[2, :] r = rzn[0, :] # get r z = rzn[1, :] @@ -334,23 +345,21 @@ def single_layer_MACs(inputs, layer, total=False): n_hh_term_macs = ( r * out_hh_n - ) # elementwise_multiplication to find macs ofr*(Whn*h + bhn) specifically + ) # elementwise_multiplication to find macs of r*(Whn*h + bhn) specifically n_hh_term_macs[n_hh_term_macs != 0] = 1 macs += n_hh_term_macs.sum() # note hh part of n is already binarized, does not influence calculation of macs for n - n = ( - out_hh.reshape(3, -1)[2, :] - + bias_ih.reshape(3, -1)[2, :] - + n_hh_term_macs - ) + n = out_hh[2, :] + bias_ih[2, :] + n_hh_term_macs n[n != 0] = 1 z_a = 1 - z # only do this now because affects z_a z[z != 0] = 1 z_a[z_a != 0] = 1 t_1 = z_a * n - t_2 = z * inputs[1] + t_2 = z * inputs[1].transpose( + 0, -1 + ) # inputs are shape [B, h], all else is [h, B] t_1[t_1 != 0] = 1 t_2[t_2 != 0] = 1