Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
korneelf1 committed Jul 24, 2024
1 parent 0d2f25b commit 61eee43
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@


class TestStaticMetrics(unittest.TestCase):

def setUp(self):

self.dummy_net = nn.Module()

self.net = models.net
Expand Down Expand Up @@ -353,13 +351,16 @@ def test_synaptic_ops(self):

syn = synaptic_operations()
syn_ops = syn(model, out, inp)

self.assertEqual(syn_ops["Effective_MACs"], 615)
self.assertEqual(syn_ops["Effective_ACs"], 0)

# test RNN network
batch_size = 2
inp = [torch.ones(batch_size, 25), torch.ones(batch_size, 5)] # input, (hidden, cell)
inp = [
torch.ones(batch_size, 25),
torch.ones(batch_size, 5),
] # input, (hidden, cell)
inp[0][0, 0] = 4 # avoid getting classified as snn
model = TorchModel(self.net_RNN)

Expand All @@ -375,7 +376,10 @@ def test_synaptic_ops(self):

# test GRU network
batch_size = 2
inp = [torch.ones(batch_size, 25), torch.ones(batch_size, 5)] # input, (hidden, cell)
inp = [
torch.ones(batch_size, 25),
torch.ones(batch_size, 5),
] # input, (hidden, cell)
inp[0][0, 0] = 4 # avoid getting classified as snn
model = TorchModel(self.net_GRU)

Expand Down Expand Up @@ -403,6 +407,7 @@ def test_membrane_potential_updates(self):

self.assertEqual(tot_mem_updates, 50)


# TODO: refactor this metric if needed
# def test_neuron_update_metric():
# net_relu_0 = nn.Sequential(
Expand Down

0 comments on commit 61eee43

Please sign in to comment.