
Commit: clean
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Apr 28, 2024
1 parent eec3308 commit 7d82d8a
Showing 1 changed file with 62 additions and 32 deletions:

monai/networks/blocks/transformerblock.py

@@ -85,44 +85,47 @@ def forward(self, x: torch.Tensor):
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
-            if x.shape[1] != int(torch.prod(torch.tensor(self.input_size))):
-                raise ValueError(
-                    f"Input tensor spatial dimension {x.shape[1]} should be equal to {self.input_size} product"
-                )
-
-            if len(self.input_size) == 2:
-                x = rearrange(x, "b (h w) c -> b h w c", h=self.input_size[0], w=self.input_size[1])
-                x, pad_hw = window_partition(x, self.window_size)
-                x = rearrange(x, "b h w c -> b (h w) c", h=self.window_size, w=self.window_size)
-            elif len(self.input_size) == 3:
-                x = rearrange(
-                    x, "b (h w d) c -> b h w d c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2]
-                )
-                x, pad_hwd = window_partition_3d(x, self.window_size)
-                x = rearrange(x, "b h w d c -> b (h w d) c", h=self.window_size, w=self.window_size, d=self.window_size)
-
+            x, pad = window_partition(x, self.window_size, self.input_size)
        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
-            if len(self.input_size) == 2:
-                x = rearrange(x, "b (h w) c -> b h w c", h=self.window_size, w=self.window_size)
-                x = window_unpartition(x, self.window_size, pad_hw, (self.input_size[0], self.input_size[1]))
-                x = rearrange(x, "b h w c -> b (h w) c", h=self.input_size[0], w=self.input_size[1])
-            elif len(self.input_size) == 3:
-                x = rearrange(x, "b (h w d) c -> b h w d c", h=self.window_size, w=self.window_size, d=self.window_size)
-                x = window_unpartition_3d(
-                    x, self.window_size, pad_hwd, (self.input_size[0], self.input_size[1], self.input_size[2])
-                )
-                x = rearrange(
-                    x, "b h w d c -> b (h w d) c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2]
-                )
-
+            x = window_unpartition(x, self.window_size, pad, self.input_size)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
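
The refactor collapses the inline 2D/3D branching above into a single dispatch pair. A minimal round-trip sketch, assuming this branch of MONAI is installed so the helpers defined below are importable:

    import torch
    from monai.networks.blocks.transformerblock import window_partition, window_unpartition

    b, c = 2, 32
    input_size = (5, 7)  # (H, W), deliberately not multiples of the window size
    x = torch.randn(b, input_size[0] * input_size[1], c)

    # Partition pads (5, 7) up to (8, 8), giving 4 windows of 4 * 4 tokens each.
    windows, pad = window_partition(x, 4, input_size=input_size)
    assert pad == (8, 8)
    assert windows.shape == (b * 4, 16, c)

    # Unpartition slices the padding off again, so the round trip is exact.
    out = window_unpartition(windows, 4, pad, input_size)
    assert torch.equal(out, x)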


-def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
"""
Partition into non-overlapping windows with padding if needed. Support 2D and 3D.
Args:
x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size)
input_size (Tuple): input spatial dimension: (H, W) or (H, W, D)
window_size (int): window size
Returns:
windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C].
with n = 1...len(input_size) and window_size_i == window_size.
(S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size)
"""
if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product")

if len(input_size) == 2:
x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
x, pad_hw = window_partition_2d(x, window_size)
x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
return x, pad_hw
elif len(input_size) == 3:
x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
x, pad_hwd = window_partition_3d(x, window_size)
x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
return x, pad_hwd
else:
raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ")


+def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. Supports only 2D.
+    Args:
@@ -169,10 +172,37 @@ def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
    return windows, (hp, wp, dp)
...
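
The view/permute pair in window_partition_3d above is what actually carves out the windows; a standalone sketch of the same reshaping on an already-padded volume:

    import torch

    ws, c = 4, 3
    x = torch.randn(1, 8, 8, 8, c)  # already padded to multiples of ws
    x = x.view(1, 8 // ws, ws, 8 // ws, ws, 8 // ws, ws, c)
    # Move the three per-window indices next to the batch dim, then flatten them.
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, ws, ws, ws, c)
    assert windows.shape == (2 * 2 * 2, ws, ws, ws, c)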


-def window_unpartition(
+def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C].
with n = 1...len(spatial_dims) and window_size == window_size_i
window_size (int): window size.
pad (Tuple): padded spatial dims (H, W) or (H, W, D)
spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.
Returns:
x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C].
"""
x: torch.Tensor
if len(spatial_dims) == 2:
x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
x = window_unpartition_2d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
return x
elif len(spatial_dims) == 3:
x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
x = window_unpartition_3d(x, window_size, pad, spatial_dims)
x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
return x
else:
raise ValueError()
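
And the 3D round trip through both dispatchers, mirroring the 2D sketch earlier (same installation assumption):

    import torch
    from monai.networks.blocks.transformerblock import window_partition, window_unpartition

    x = torch.randn(2, 5 * 6 * 7, 16)
    windows, pad = window_partition(x, 4, input_size=(5, 6, 7))
    out = window_unpartition(windows, 4, pad, spatial_dims=(5, 6, 7))
    assert torch.equal(out, x)  # padding added by partition is removed again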


+def window_unpartition_2d(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
