Skip to content

Commit

Permalink
Fix calculate_macs() for Linear layers. (#318)
Browse files Browse the repository at this point in the history
* Fix calculate_macs() for Linear layers.

Fix MACs in lst.out and lstm_half.out.

* Add test for torch.nn.Linear.

* Change groud-truth Total mult-adds in flan_t5_small.out.

 MACs increased from 280.27M to 18.25G because of the Linear layer fix.

---------

Co-authored-by: Andrew Lavin <[email protected]>
  • Loading branch information
andravin and AndrewLavin authored Nov 5, 2024
1 parent 29166cc commit 38ab72b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/test_output/flan_t5_small.out
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ T5ForConditionalGeneration [3, 100, 512]
Total params: 128,743,488
Trainable params: 128,743,488
Non-trainable params: 0
Total mult-adds (M): 280.27
Total mult-adds (G): 18.25
==============================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 326.28
Expand Down
15 changes: 15 additions & 0 deletions tests/test_output/linear.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
========================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param # Mult-Adds
========================================================================================================================
Linear [32, 16, 8] [32, 16, 64] 576 294,912
========================================================================================================================
Total params: 576
Trainable params: 576
Non-trainable params: 0
Total mult-adds (M): 0.29
========================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.26
Params size (MB): 0.00
Estimated Total Size (MB): 0.28
========================================================================================================================
4 changes: 2 additions & 2 deletions tests/test_output/lstm.out
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ LSTMNet (LSTMNet) -- [100, 20]
│ └─weight_hh_l1 [2048, 512] ├─1,048,576
│ └─bias_ih_l1 [2048] ├─2,048
│ └─bias_hh_l1 [2048] └─2,048
├─Linear (decoder) -- [1, 100, 20] 10,260 10,260
├─Linear (decoder) -- [1, 100, 20] 10,260 1,026,000
│ └─weight [512, 20] ├─10,240
│ └─bias [20] └─20
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
Total mult-adds (M): 377.86
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Expand Down
4 changes: 2 additions & 2 deletions tests/test_output/lstm_half.out
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ Layer (type (var_name)) Kernel Shape Output Shape
LSTMNet (LSTMNet) -- [100, 20] -- --
├─Embedding (embedding) -- [1, 100, 300] 6,000 6,000
├─LSTM (encoder) -- [1, 100, 512] 3,768,320 376,832,000
├─Linear (decoder) -- [1, 100, 20] 10,260 10,260
├─Linear (decoder) -- [1, 100, 20] 10,260 1,026,000
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
Total mult-adds (M): 377.86
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.33
Expand Down
14 changes: 14 additions & 0 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ def test_groups() -> None:
)


def test_linear() -> None:
input_shape = (32, 16, 8)
module = nn.Linear(8, 64)
col_names = ("input_size", "output_size", "num_params", "mult_adds")
input_data = torch.randn(*input_shape)
summary(
module,
input_data=input_data,
depth=1,
col_names=col_names,
col_width=20,
)


def test_single_input_batch_dim() -> None:
model = SingleInputNet()
col_names = ("kernel_size", "input_size", "output_size", "num_params", "mult_adds")
Expand Down
2 changes: 2 additions & 0 deletions torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def calculate_macs(self) -> None:
self.macs += int(
cur_params * prod(self.output_size[:1] + self.output_size[2:])
)
elif "Linear" in self.class_name:
self.macs += int(cur_params * prod(self.output_size[:-1]))
else:
self.macs += self.output_size[0] * cur_params
# RNN modules have inner weights such as weight_ih_l0
Expand Down

0 comments on commit 38ab72b

Please sign in to comment.