Propagate kernel size through Attention UNet (Project-MONAI#7734)

Fixes Project-MONAI#7726.

### Description
Passes the `kernel_size` parameter to the `ConvBlock` layers within Attention
UNet, so the network is built with the expected number of parameters.

Using the example in Project-MONAI#7726 on this branch:
```
from monai.networks.nets import AttentionUnet

model = AttentionUnet(
        spatial_dims = 2,
        in_channels = 1,
        out_channels = 1,
        channels = (2, 4, 8, 16),
        strides = (2,2,2),
        kernel_size = 5,
        up_kernel_size = 5
)
```
outputs the expected values: 
```
Total params: 18,846
Trainable params: 18,846
Non-trainable params: 0
Total mult-adds (M): 0.37
```
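
For reference, the parameter total above can be cross-checked with plain PyTorch by summing element counts over `model.parameters()`, which is the same check performed by the `get_net_parameters` helper added to the tests below (a minimal sketch, assuming this branch of MONAI is installed):
```
from monai.networks.nets import AttentionUnet

model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(2, 4, 8, 16),
    strides=(2, 2, 2),
    kernel_size=5,
    up_kernel_size=5,
)

# Sum element counts over all parameters; with this fix the result
# matches the 18,846 total reported above.
print(sum(p.numel() for p in model.parameters()))
```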

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Peter Kaplinsky <[email protected]>
Co-authored-by: Peter Kaplinsky <[email protected]>
Pkaps25 and Peter Kaplinsky authored May 7, 2024
1 parent e1a69b0 commit fe733b0
Showing 2 changed files with 30 additions and 2 deletions.
12 changes: 10 additions & 2 deletions monai/networks/nets/attentionunet.py
@@ -29,7 +29,7 @@ def __init__(
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
kernel_size: Sequence[int] | int = 3,
strides: int = 1,
dropout=0.0,
):
@@ -219,7 +219,13 @@ def __init__(
self.kernel_size = kernel_size
self.dropout = dropout

head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
head = ConvBlock(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=channels[0],
dropout=dropout,
kernel_size=self.kernel_size,
)
reduce_channels = Convolution(
spatial_dims=spatial_dims,
in_channels=channels[0],
@@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
out_channels=channels[1],
strides=strides[0],
dropout=self.dropout,
kernel_size=self.kernel_size,
),
subblock,
),
@@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -
out_channels=out_channels,
strides=strides,
dropout=self.dropout,
kernel_size=self.kernel_size,
),
up_kernel_size=self.up_kernel_size,
strides=strides,
20 changes: 20 additions & 0 deletions tests/test_attentionunet.py
@@ -14,11 +14,17 @@
import unittest

import torch
import torch.nn as nn

import monai.networks.nets.attentionunet as att
from tests.utils import skip_if_no_cuda, skip_if_quick


def get_net_parameters(net: nn.Module) -> int:
"""Returns the total number of parameters in a Module."""
return sum(param.numel() for param in net.parameters())


class TestAttentionUnet(unittest.TestCase):

def test_attention_block(self):
@@ -50,6 +56,20 @@ def test_attentionunet(self):
self.assertEqual(output.shape[0], input.shape[0])
self.assertEqual(output.shape[1], 2)

def test_attentionunet_kernel_size(self):
args_dict = {
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 2,
"channels": (3, 4, 5),
"up_kernel_size": 5,
"strides": (1, 2),
}
model_a = att.AttentionUnet(**args_dict, kernel_size=5)
model_b = att.AttentionUnet(**args_dict, kernel_size=7)
self.assertEqual(get_net_parameters(model_a), 3534)
self.assertEqual(get_net_parameters(model_b), 5574)

@skip_if_no_cuda
def test_attentionunet_gpu(self):
for dims in [2, 3]:
