From 30c0c81351a14e6820c98ee22c24f3edc9062e55 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Oct 2024 00:48:32 -0400 Subject: [PATCH 1/3] Add a way to patch blocks in SD3. --- comfy/ldm/modules/diffusionmodules/mmdit.py | 31 +++++++++++++++------ 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 43a269fa04b..6f8f506ce02 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -949,7 +949,9 @@ def forward_core_with_concat( c_mod: torch.Tensor, context: Optional[torch.Tensor] = None, control = None, + transformer_options = {}, ) -> torch.Tensor: + patches_replace = transformer_options.get("patches_replace", {}) if self.register_length > 0: context = torch.cat( ( @@ -961,14 +963,25 @@ def forward_core_with_concat( # context is B, L', D # x is B, L, D + blocks_replace = patches_replace.get("dit", {}) blocks = len(self.joint_blocks) for i in range(blocks): - context, x = self.joint_blocks[i]( - context, - x, - c=c_mod, - use_checkpoint=self.use_checkpoint, - ) + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + return out + + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + context = out["txt"] + x = out["img"] + else: + context, x = self.joint_blocks[i]( + context, + x, + c=c_mod, + use_checkpoint=self.use_checkpoint, + ) if control is not None: control_o = control.get("output") if i < len(control_o): @@ -986,6 +999,7 @@ def forward( y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, control = None, + transformer_options = {}, ) -> torch.Tensor: """ Forward pass of DiT. 
@@ -1007,7 +1021,7 @@ def forward( if context is not None: context = self.context_embedder(context) - x = self.forward_core_with_concat(x, c, context, control) + x = self.forward_core_with_concat(x, c, context, control, transformer_options) x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) return x[:,:,:hw[-2],:hw[-1]] @@ -1021,7 +1035,8 @@ def forward( context: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, control = None, + transformer_options = {}, **kwargs, ) -> torch.Tensor: - return super().forward(x, timesteps, context=context, y=y, control=control) + return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options) From 954683d0dbd8f098c5485422a1e27f33fe951c32 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Tue, 29 Oct 2024 21:59:21 +0800 Subject: [PATCH 2/3] SLG first implementation for SD3.5 (#5404) * SLG first implementation for SD3.5 * * Simplify and align with comfy style --- comfy_extras/nodes_sd3.py | 61 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index ddf538deb9d..6bd06f4a3f6 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,7 +3,7 @@ import comfy.model_management import nodes import torch - +import re class TripleCLIPLoader: @classmethod def INPUT_TYPES(s): @@ -95,11 +95,70 @@ def INPUT_TYPES(s): CATEGORY = "conditioning/controlnet" DEPRECATED = True +class SkipLayerGuidanceSD3: + ''' + Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers. + Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) + Experimental implementation by Dango233@StabilityAI. 
+ ''' + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "layers": ("STRING", {"default": "7,8,9", "multiline": False}), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "skip_guidance" + + CATEGORY = "advanced/guidance" + + + def skip_guidance(self, model, layers, scale, start_percent, end_percent): + if layers == "" or layers == None: + return (model, ) + # check if layer is comma separated integers + assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers" + def skip(args, extra_args): + return args + + model_sampling = model.get_model_object("model_sampling") + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + x = args["input"] + model_options = args["model_options"].copy() + + for layer in layers: + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) + model_sampling.percent_to_sigma(start_percent) + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + sigma_ = sigma[0].item() + if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: + (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + cfg_result = cfg_result + (cond_pred - slg) * scale + return cfg_result + + layers = re.findall(r'\d+', layers) + layers = [int(i) for i in layers] + m = model.clone() + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m, ) + + NODE_CLASS_MAPPINGS = { "TripleCLIPLoader": TripleCLIPLoader, "EmptySD3LatentImage": EmptySD3LatentImage, "CLIPTextEncodeSD3": 
CLIPTextEncodeSD3, "ControlNetApplySD3": ControlNetApplySD3, + "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, } NODE_DISPLAY_NAME_MAPPINGS = { From 770ab200f296d8d0269d37fdca84bb742cee38b1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Oct 2024 10:11:46 -0400 Subject: [PATCH 3/3] Cleanup SkipLayerGuidanceSD3 node. --- comfy_extras/nodes_sd3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 6bd06f4a3f6..4d664093cd4 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -104,7 +104,7 @@ class SkipLayerGuidanceSD3: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), - "layers": ("STRING", {"default": "7,8,9", "multiline": False}), + "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) @@ -119,11 +119,12 @@ def skip_guidance(self, model, layers, scale, start_percent, end_percent): if layers == "" or layers == None: return (model, ) # check if layer is comma separated integers - assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers" def skip(args, extra_args): return args model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) def post_cfg_function(args): model = args["model"] @@ -137,10 +138,9 @@ def post_cfg_function(args): for layer in layers: model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) model_sampling.percent_to_sigma(start_percent) - sigma_start = model_sampling.percent_to_sigma(start_percent) - sigma_end = 
model_sampling.percent_to_sigma(end_percent) + sigma_ = sigma[0].item() - if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: + if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start: (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) cfg_result = cfg_result + (cond_pred - slg) * scale return cfg_result