Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of SPADE Network + tests and modification of SPADE normalisation #7775

Merged
merged 11 commits into from
Jun 3, 2024
7 changes: 3 additions & 4 deletions monai/networks/blocks/spade_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import ADN, Convolution
from monai.networks.blocks import Convolution
from monai.networks.layers.utils import get_norm_layer


class SPADE(nn.Module):
Expand Down Expand Up @@ -50,9 +51,7 @@ def __init__(
norm_params = {}
if len(norm_params) != 0:
norm = (norm, norm_params)
self.param_free_norm = ADN(
act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc
)
self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)
self.mlp_shared = Convolution(
spatial_dims=spatial_dims,
in_channels=label_nc,
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
)
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .spade_network import SPADENet
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
Expand Down
8 changes: 4 additions & 4 deletions monai/networks/nets/spade_diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
__all__ = ["SPADEDiffusionModelUNet"]


class SPADEResnetBlock(nn.Module):
class SPADEDiffResBlock(nn.Module):
"""
Residual block with timestep conditioning and SPADE norm.
Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -356,7 +356,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -491,7 +491,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down
Loading
Loading