Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update rnncell synops calculation #227

Merged
merged 5 commits on
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions neurobench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,13 @@ def single_layer_MACs(inputs, layer, total=False):
# NOTE: these activation functions are currently NOT included in NeuroBench
# if no explicit states are passed to recurrent layers, then h and c are initialized to zero (pytorch convention)
layer_bin = make_binary_copy(layer, all_ones=total)
# transpose from batches, timesteps, features to features, batches
# print(layer_bin.weight_ih.shape)
# layer_weight_ih is [4*hidden_size, input_size]
# inputs[0].transpose(0, -1) is [input_size, batch_size]
out_ih = torch.matmul(
layer_bin.weight_ih, inputs[0].transpose(0, -1)
) # accounts for i,f,g,o
out_hh = torch.zeros_like(out_ih)
# out shape is 4*h, batch, for hidden feature dim h

biases = 0
bias_ih = 0
Expand Down Expand Up @@ -290,10 +291,16 @@ def single_layer_MACs(inputs, layer, total=False):

out += biases # biases are added here for computation of c and h which depend on correct computation of ifgo
out[out != 0] = 1
# out is vector with i,f,g,o
ifgo = out.reshape(4, -1) # each row is one of i,f,g,o

# out is vector with i,f,g,o, shape is 4*h, batch
hidden = out.shape[0] // 4
ifgo = out.reshape(4, hidden, -1) # 4, h, B
if in_states:
c_1 = ifgo[1, :] * inputs[1][1] + ifgo[0, :] * ifgo[2, :]
# inputs[1][1] shape is [B, h]
# element-wise multiply (vector products f*c + i*g)
c_1 = (
ifgo[1, :] * inputs[1][1].transpose(0, -1) + ifgo[0, :] * ifgo[2, :]
)
else:
c_1 = ifgo[0, :] * ifgo[2, :]

Expand Down Expand Up @@ -323,9 +330,13 @@ def single_layer_MACs(inputs, layer, total=False):
# Wir*x, Whr*h, Wiz*x, Whz*h, Win*x, Whn*h
macs += rzn.sum() # multiplications of all weights and inputs/hidden states
rzn += biases # add biases
rzn = rzn.reshape(3, -1) # each row is one of r,z,n

out_hh_n = out_hh.reshape(3, -1)[2, :] + bias_hh.reshape(3, -1)[2, :]
hidden = rzn.shape[0] // 3
rzn = rzn.reshape(3, hidden, -1) # 3, h, B
out_hh = out_hh.reshape(3, hidden, -1)
bias_hh = bias_hh.reshape(3, hidden, -1)

out_hh_n = out_hh[2, :] + bias_hh[2, :]
r = rzn[0, :] # get r
z = rzn[1, :]

Expand All @@ -334,23 +345,21 @@ def single_layer_MACs(inputs, layer, total=False):

n_hh_term_macs = (
r * out_hh_n
) # elementwise_multiplication to find macs ofr*(Whn*h + bhn) specifically
) # elementwise_multiplication to find macs of r*(Whn*h + bhn) specifically
n_hh_term_macs[n_hh_term_macs != 0] = 1
macs += n_hh_term_macs.sum()

# note hh part of n is already binarized, does not influence calculation of macs for n
n = (
out_hh.reshape(3, -1)[2, :]
+ bias_ih.reshape(3, -1)[2, :]
+ n_hh_term_macs
)
n = out_hh[2, :] + bias_ih[2, :] + n_hh_term_macs
n[n != 0] = 1
z_a = 1 - z
# only do this now because affects z_a
z[z != 0] = 1
z_a[z_a != 0] = 1
t_1 = z_a * n
t_2 = z * inputs[1]
t_2 = z * inputs[1].transpose(
0, -1
) # inputs are shape [B, h], all else is [h, B]

t_1[t_1 != 0] = 1
t_2[t_2 != 0] = 1
Expand Down
23 changes: 14 additions & 9 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@


class TestStaticMetrics(unittest.TestCase):

def setUp(self):

self.dummy_net = nn.Module()

self.net = models.net
Expand Down Expand Up @@ -339,11 +337,11 @@ def test_synaptic_ops(self):
self.assertEqual(syn_ops["Effective_ACs"], 1000)

# test lstm network

batch_size = 2
inp = [
torch.ones(1, 25),
(torch.ones(1, 5), torch.ones(1, 5)),
] # input, (hidden, cell)
torch.ones(batch_size, 25),
(torch.ones(batch_size, 5), torch.ones(batch_size, 5)),
] # input (batch_size, inp_size), (hidden, cell)
inp[0][0, 0] = 4 # avoid getting classified as snn
model = TorchModel(self.net_lstm)

Expand All @@ -358,8 +356,11 @@ def test_synaptic_ops(self):
self.assertEqual(syn_ops["Effective_ACs"], 0)

# test RNN network

inp = [torch.ones(1, 25), torch.ones(1, 5)] # input, (hidden, cell)
batch_size = 2
inp = [
torch.ones(batch_size, 25),
torch.ones(batch_size, 5),
] # input, (hidden, cell)
inp[0][0, 0] = 4 # avoid getting classified as snn
model = TorchModel(self.net_RNN)

Expand All @@ -374,7 +375,11 @@ def test_synaptic_ops(self):
self.assertEqual(syn_ops["Effective_ACs"], 0)

# test GRU network
inp = [torch.ones(1, 25), torch.ones(1, 5)] # input, (hidden, cell)
batch_size = 2
inp = [
torch.ones(batch_size, 25),
torch.ones(batch_size, 5),
] # input, (hidden, cell)
inp[0][0, 0] = 4 # avoid getting classified as snn
model = TorchModel(self.net_GRU)

Expand Down