-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: vgrau98 <[email protected]>
- Loading branch information
Showing
3 changed files
with
209 additions
and
181 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from monai.utils import optional_import | ||
|
||
rearrange, _ = optional_import("einops", name="rearrange") | ||
|
||
|
||
def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]: | ||
""" | ||
Partition into non-overlapping windows with padding if needed. Support 2D and 3D. | ||
Args: | ||
x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size) | ||
input_size (Tuple): input spatial dimension: (H, W) or (H, W, D) | ||
window_size (int): window size | ||
Returns: | ||
windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C]. | ||
with n = 1...len(input_size) and window_size_i == window_size. | ||
(S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size) | ||
""" | ||
if x.shape[1] != int(torch.prod(torch.tensor(input_size))): | ||
raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product") | ||
|
||
if len(input_size) == 2: | ||
x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1]) | ||
x, pad_hw = window_partition_2d(x, window_size) | ||
x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size) | ||
return x, pad_hw | ||
elif len(input_size) == 3: | ||
x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2]) | ||
x, pad_hwd = window_partition_3d(x, window_size) | ||
x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size) | ||
return x, pad_hwd | ||
else: | ||
raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ") | ||
|
||
|
||
def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: | ||
""" | ||
Partition into non-overlapping windows with padding if needed. Support only 2D. | ||
Args: | ||
x (tensor): input tokens with [B, H, W, C]. | ||
window_size (int): window size. | ||
Returns: | ||
windows: windows after partition with [B * num_windows, window_size, window_size, C]. | ||
(Hp, Wp): padded height and width before partition | ||
""" | ||
batch, h, w, c = x.shape | ||
|
||
pad_h = (window_size - h % window_size) % window_size | ||
pad_w = (window_size - w % window_size) % window_size | ||
if pad_h > 0 or pad_w > 0: | ||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) | ||
hp, wp = h + pad_h, w + pad_w | ||
|
||
x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c) | ||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) | ||
return windows, (hp, wp) | ||
|
||
|
||
def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]: | ||
""" | ||
Partition into non-overlapping windows with padding if needed. 3d implementation. | ||
Args: | ||
x (tensor): input tokens with [B, H, W, D, C]. | ||
window_size (int): window size. | ||
Returns: | ||
windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C]. | ||
(Hp, Wp, Dp): padded height, width and depth before partition | ||
""" | ||
batch, h, w, d, c = x.shape | ||
|
||
pad_h = (window_size - h % window_size) % window_size | ||
pad_w = (window_size - w % window_size) % window_size | ||
pad_d = (window_size - d % window_size) % window_size | ||
if pad_h > 0 or pad_w > 0 or pad_d > 0: | ||
x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h)) | ||
hp, wp, dp = h + pad_h, w + pad_w, d + pad_d | ||
|
||
x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c) | ||
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c) | ||
return windows, (hp, wp, dp) | ||
|
||
|
||
def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor: | ||
""" | ||
Window unpartition into original sequences and removing padding. | ||
Args: | ||
windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C]. | ||
with n = 1...len(spatial_dims) and window_size == window_size_i | ||
window_size (int): window size. | ||
pad (Tuple): padded spatial dims (H, W) or (H, W, D) | ||
spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding. | ||
Returns: | ||
x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C]. | ||
""" | ||
x: torch.Tensor | ||
if len(spatial_dims) == 2: | ||
x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size) | ||
x = window_unpartition_2d(x, window_size, pad, spatial_dims) | ||
x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1]) | ||
return x | ||
elif len(spatial_dims) == 3: | ||
x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size) | ||
x = window_unpartition_3d(x, window_size, pad, spatial_dims) | ||
x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2]) | ||
return x | ||
else: | ||
raise ValueError() | ||
|
||
|
||
def window_unpartition_2d( | ||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] | ||
) -> torch.Tensor: | ||
""" | ||
Window unpartition into original sequences and removing padding. | ||
Args: | ||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. | ||
window_size (int): window size. | ||
pad_hw (Tuple): padded height and width (hp, wp). | ||
hw (Tuple): original height and width (H, W) before padding. | ||
Returns: | ||
x: unpartitioned sequences with [B, H, W, C]. | ||
""" | ||
hp, wp = pad_hw | ||
h, w = hw | ||
batch = windows.shape[0] // (hp * wp // window_size // window_size) | ||
x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1) | ||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1) | ||
|
||
if hp > h or wp > w: | ||
x = x[:, :h, :w, :].contiguous() | ||
return x | ||
|
||
|
||
def window_unpartition_3d( | ||
windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int] | ||
) -> torch.Tensor: | ||
""" | ||
Window unpartition into original sequences and removing padding. 3d implementation. | ||
Args: | ||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C]. | ||
window_size (int): window size. | ||
pad_hwd (Tuple): padded height, width and depth (hp, wp, dp). | ||
hwd (Tuple): original height, width and depth (H, W, D) before padding. | ||
Returns: | ||
x: unpartitioned sequences with [B, H, W, D, C]. | ||
""" | ||
hp, wp, dp = pad_hwd | ||
h, w, d = hwd | ||
batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size) | ||
x = windows.view( | ||
batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1 | ||
) | ||
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1) | ||
|
||
if hp > h or wp > w or dp > d: | ||
x = x[:, :h, :w, :d, :].contiguous() | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.