Skip to content

Commit

Permalink
update rnncell synops calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonlyik committed Jul 17, 2024
1 parent f835982 commit 195bec9
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions neurobench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,13 @@ def single_layer_MACs(inputs, layer, total=False):
# NOTE: these activation functions are currently NOT included in NeuroBench
# if no explicit states are passed to recurrent layers, then h and c are initialized to zero (pytorch convention)
layer_bin = make_binary_copy(layer, all_ones=total)
# transpose from batches, timesteps, features to features, batches
# print(layer_bin.weight_ih.shape)
# layer_weight_ih is [4*hidden_size, input_size]
# inputs[0].transpose(0, -1) is [input_size, batch_size]
out_ih = torch.matmul(
layer_bin.weight_ih, inputs[0].transpose(0, -1)
) # accounts for i,f,g,o
out_hh = torch.zeros_like(out_ih)
# out shape is 4*h, batch, for hidden feature dim h

biases = 0
bias_ih = 0
Expand Down Expand Up @@ -290,10 +291,16 @@ def single_layer_MACs(inputs, layer, total=False):

out += biases # biases are added here for computation of c and h which depend on correct computation of ifgo
out[out != 0] = 1
# out is vector with i,f,g,o
ifgo = out.reshape(4, -1) # each row is one of i,f,g,o

# out is vector with i,f,g,o, shape is 4*h, batch
hidden = out.shape[0] // 4
ifgo = out.reshape(4, hidden, -1) # 4, h, B
if in_states:
c_1 = ifgo[1, :] * inputs[1][1] + ifgo[0, :] * ifgo[2, :]
# inputs[1][1] shape is [B, h]
# element-wise multiply (vector products f*c + i*g)
c_1 = (
ifgo[1, :] * inputs[1][1].transpose(0, -1) + ifgo[0, :] * ifgo[2, :]
)
else:
c_1 = ifgo[0, :] * ifgo[2, :]

Expand Down Expand Up @@ -323,9 +330,13 @@ def single_layer_MACs(inputs, layer, total=False):
# Wir*x, Whr*h, Wiz*x, Whz*h, Win*x, Whn*h
macs += rzn.sum() # multiplications of all weights and inputs/hidden states
rzn += biases # add biases
rzn = rzn.reshape(3, -1) # each row is one of r,z,n

out_hh_n = out_hh.reshape(3, -1)[2, :] + bias_hh.reshape(3, -1)[2, :]
hidden = rzn.shape[0] // 3
rzn = rzn.reshape(3, hidden, -1) # 3, h, B
out_hh = out_hh.reshape(3, hidden, -1)
bias_hh = bias_hh.reshape(3, hidden, -1)

out_hh_n = out_hh[2, :] + bias_hh[2, :]
r = rzn[0, :] # get r
z = rzn[1, :]

Expand All @@ -334,23 +345,21 @@ def single_layer_MACs(inputs, layer, total=False):

n_hh_term_macs = (
r * out_hh_n
) # elementwise_multiplication to find macs ofr*(Whn*h + bhn) specifically
) # elementwise_multiplication to find macs of r*(Whn*h + bhn) specifically
n_hh_term_macs[n_hh_term_macs != 0] = 1
macs += n_hh_term_macs.sum()

# note hh part of n is already binarized, does not influence calculation of macs for n
n = (
out_hh.reshape(3, -1)[2, :]
+ bias_ih.reshape(3, -1)[2, :]
+ n_hh_term_macs
)
n = out_hh[2, :] + bias_ih[2, :] + n_hh_term_macs
n[n != 0] = 1
z_a = 1 - z
# only do this now because affects z_a
z[z != 0] = 1
z_a[z_a != 0] = 1
t_1 = z_a * n
t_2 = z * inputs[1]
t_2 = z * inputs[1].transpose(
0, -1
) # inputs are shape [B, h], all else is [h, B]

t_1[t_1 != 0] = 1
t_2[t_2 != 0] = 1
Expand Down

0 comments on commit 195bec9

Please sign in to comment.