Adapt UNet2D for super-resolution (huggingface#1385)
* allow disabling self attention

* add class_embedding

* fix copies

* fix condition

* fix copies

* do_self_attention -> only_cross_attention

* fix copies

* num_classes -> num_class_embeds

* fix default value
patil-suraj authored Nov 24, 2022
1 parent 30f6f44 commit cecdd8b
Showing 4 changed files with 60 additions and 1 deletion.
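
Taken together, the two new config options introduced by this commit, only_cross_attention and num_class_embeds, let the conditional UNet be set up for super-resolution-style training. Below is a minimal, hypothetical construction sketch in Python; the parameter values are illustrative and not taken from any released checkpoint.

from diffusers import UNet2DConditionModel

# Illustrative only: a UNet whose attention blocks attend solely to the conditioning
# context and which additionally conditions on an integer label (e.g. a noise level).
unet = UNet2DConditionModel(
    in_channels=7,                                    # e.g. 4 latent + 3 low-res image channels
    out_channels=4,
    only_cross_attention=(True, True, True, False),   # one flag per down block
    num_class_embeds=1000,                            # enables the nn.Embedding added below
)
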
11 changes: 10 additions & 1 deletion src/diffusers/models/attention.py
@@ -100,6 +100,7 @@ def __init__(
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -157,6 +158,7 @@ def __init__(
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
)
for d in range(num_layers)
]
@@ -387,14 +389,17 @@ def __init__(
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
@@ -461,7 +466,11 @@ def forward(self, hidden_states, context=None, timestep=None):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
-hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+if self.only_cross_attention:
+    hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+else:
+    hidden_states = self.attn1(norm_hidden_states) + hidden_states

# 2. Cross-Attention
norm_hidden_states = (
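
At the block level, only_cross_attention=True makes attn1 (normally a self-attention layer) be built with cross_attention_dim and consume the context tensor, so both attention layers attend to the conditioning. A small sketch of the changed behaviour follows; the BasicTransformerBlock argument names and shapes are assumed from this version of attention.py and are illustrative only.

import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
    only_cross_attention=True,   # attn1 now takes cross_attention_dim and uses `context`
)
hidden_states = torch.randn(1, 64, 320)   # (batch, tokens, channels)
context = torch.randn(1, 77, 768)         # e.g. text-encoder hidden states
out = block(hidden_states, context=context)
print(out.shape)                          # torch.Size([1, 64, 320])
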
8 changes: 8 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -34,6 +34,7 @@ def get_down_block(
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -78,6 +79,7 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -143,6 +145,7 @@ def get_up_block(
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -174,6 +177,7 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -530,6 +534,7 @@ def __init__(
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -564,6 +569,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -1129,6 +1135,7 @@ def __init__(
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -1165,6 +1172,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
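
These get_down_block/get_up_block changes only thread the flag through; the per-block values come from the UNet in the next file, which broadcasts a single bool to every block and reverses the list for the up path. A plain-Python sketch of that convention (the block names are just the library's defaults):

down_block_types = (
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D",
)
only_cross_attention = True   # or a per-block tuple such as (True, True, True, False)

if isinstance(only_cross_attention, bool):
    only_cross_attention = [only_cross_attention] * len(down_block_types)

for block_type, flag in zip(down_block_types, only_cross_attention):
    print(block_type, flag)

print(list(reversed(only_cross_attention)))   # order consumed by the up blocks
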
19 changes: 19 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -98,6 +98,7 @@ def __init__(
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -109,6 +110,7 @@ def __init__(
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()

@@ -124,10 +126,17 @@ def __init__(

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -153,6 +162,7 @@ def __init__(
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)

@@ -177,6 +187,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -207,6 +218,7 @@ def __init__(
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -258,6 +270,7 @@ def forward(
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -310,6 +323,12 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

# 2. pre-process
sample = self.conv_in(sample)

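
With num_class_embeds set, the model looks up an embedding for class_labels and adds it to the time embedding before the down/mid/up passes, so forward() must now be given labels. A hypothetical end-to-end sketch with a deliberately tiny config (all sizes are illustrative, chosen so the example runs quickly on CPU):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64, 64, 64),   # shrunk from the defaults for a quick test
    cross_attention_dim=32,
    attention_head_dim=8,
    num_class_embeds=10,
)
sample = torch.randn(1, unet.config.in_channels, 64, 64)
encoder_hidden_states = torch.randn(1, 77, 32)
class_labels = torch.tensor([3])           # one integer label per batch element

# Internally: emb = time_embedding(t_emb) + class_embedding(class_labels)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states,
           class_labels=class_labels).sample
print(out.shape)                           # torch.Size([1, 4, 64, 64])
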
23 changes: 23 additions & 0 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -166,6 +166,7 @@ def __init__(
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -177,6 +178,7 @@ def __init__(
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()

@@ -192,10 +194,17 @@ def __init__(

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -221,6 +230,7 @@ def __init__(
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)

@@ -245,6 +255,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -275,6 +286,7 @@ def __init__(
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -326,6 +338,7 @@ def forward(
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -378,6 +391,12 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

# 2. pre-process
sample = self.conv_in(sample)

@@ -648,6 +667,7 @@ def __init__(
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -682,6 +702,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -861,6 +882,7 @@ def __init__(
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -897,6 +919,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
