diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 0aacddf34d00..4c970d062d64 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -100,6 +100,7 @@ def __init__(
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -157,6 +158,7 @@ def __init__(
                     activation_fn=activation_fn,
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
+                    only_cross_attention=only_cross_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -387,14 +389,17 @@ def __init__(
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
+        only_cross_attention: bool = False,
     ):
         super().__init__()
+        self.only_cross_attention = only_cross_attention
         self.attn1 = CrossAttention(
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
         self.attn2 = CrossAttention(
@@ -461,7 +466,11 @@ def forward(self, hidden_states, context=None, timestep=None):
         norm_hidden_states = (
             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
         )
-        hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+        if self.only_cross_attention:
+            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+        else:
+            hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
         # 2. Cross-Attention
         norm_hidden_states = (
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 5a8a97187f11..e919d21f4a03 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -34,6 +34,7 @@ def get_down_block(
     downsample_padding=None,
     dual_cross_attention=False,
     use_linear_projection=False,
+    only_cross_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -78,6 +79,7 @@ def get_down_block(
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -143,6 +145,7 @@ def get_up_block(
     cross_attention_dim=None,
     dual_cross_attention=False,
     use_linear_projection=False,
+    only_cross_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -174,6 +177,7 @@ def get_up_block(
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -530,6 +534,7 @@ def __init__(
         add_downsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -564,6 +569,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
@@ -1129,6 +1135,7 @@ def __init__(
         add_upsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1165,6 +1172,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 206097149331..97a26ced5400 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -98,6 +98,7 @@ def __init__(
             "DownBlock2D",
         ),
         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
@@ -109,6 +110,7 @@ def __init__(
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
+        num_class_embeds: Optional[int] = None,
     ):
         super().__init__()
 
@@ -124,10 +126,17 @@ def __init__(
 
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
+        # class embedding
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
         self.down_blocks = nn.ModuleList([])
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -153,6 +162,7 @@ def __init__(
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.down_blocks.append(down_block)
 
@@ -177,6 +187,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -207,6 +218,7 @@ def __init__(
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -258,6 +270,7 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -310,6 +323,12 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb)
 
+        if self.config.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            emb = emb + class_emb
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 6d521228e31b..24e79729a599 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -166,6 +166,7 @@ def __init__(
             "CrossAttnUpBlockFlat",
             "CrossAttnUpBlockFlat",
         ),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
@@ -177,6 +178,7 @@ def __init__(
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
+        num_class_embeds: Optional[int] = None,
     ):
         super().__init__()
 
@@ -192,10 +194,17 @@ def __init__(
 
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
+        # class embedding
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
         self.down_blocks = nn.ModuleList([])
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -221,6 +230,7 @@ def __init__(
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.down_blocks.append(down_block)
 
@@ -245,6 +255,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -275,6 +286,7 @@ def __init__(
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -326,6 +338,7 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -378,6 +391,12 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb)
 
+        if self.config.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            emb = emb + class_emb
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
@@ -648,6 +667,7 @@ def __init__(
         add_downsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -682,6 +702,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
@@ -861,6 +882,7 @@ def __init__(
         add_upsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -897,6 +919,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
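Taken together, the patch exposes two new knobs on the conditioned UNets: a per-down-block `only_cross_attention` flag that turns `attn1` in `BasicTransformerBlock` into a cross-attention layer over the text context, and an optional `num_class_embeds` that adds an `nn.Embedding` class embedding to the timestep embedding, fed through the new `class_labels` argument of `forward`. A minimal usage sketch follows; the channel counts, block layout, class count, and tensor shapes are made up for illustration and are not taken from the diff:

import torch
from diffusers import UNet2DConditionModel

# Hypothetical small UNet: the first two down blocks run cross-attention only,
# and class conditioning is enabled with 1000 classes (illustrative values).
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(128, 256, 256),
    cross_attention_dim=768,
    only_cross_attention=(True, True, False),  # one flag per down block; a single bool is broadcast
    num_class_embeds=1000,                     # enables the nn.Embedding class embedding
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # matches cross_attention_dim
class_labels = torch.tensor([42])                # required because num_class_embeds is set

# The class embedding is added to the timestep embedding before the blocks run.
output = unet(sample, timestep, encoder_hidden_states, class_labels=class_labels).sample
print(output.shape)  # torch.Size([1, 4, 32, 32])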