Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/WillReynolds5/ComfyUI
Browse files Browse the repository at this point in the history
  • Loading branch information
WillReynolds5 committed Oct 29, 2024
2 parents 5115e6e + b262104 commit 1d85a2a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
31 changes: 23 additions & 8 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,9 @@ def forward_core_with_concat(
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if self.register_length > 0:
context = torch.cat(
(
Expand All @@ -961,14 +963,25 @@ def forward_core_with_concat(

# context is B, L', D
# x is B, L, D
blocks_replace = patches_replace.get("dit", {})
blocks = len(self.joint_blocks)
for i in range(blocks):
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out

out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
context, x = self.joint_blocks[i](
context,
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
)
if control is not None:
control_o = control.get("output")
if i < len(control_o):
Expand All @@ -986,6 +999,7 @@ def forward(
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
) -> torch.Tensor:
"""
Forward pass of DiT.
Expand All @@ -1007,7 +1021,7 @@ def forward(
if context is not None:
context = self.context_embedder(context)

x = self.forward_core_with_concat(x, c, context, control)
x = self.forward_core_with_concat(x, c, context, control, transformer_options)

x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]]
Expand All @@ -1021,7 +1035,8 @@ def forward(
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
control = None,
transformer_options = {},
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)
return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

61 changes: 60 additions & 1 deletion comfy_extras/nodes_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import comfy.model_management
import nodes
import torch

import re
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -95,11 +95,70 @@ def INPUT_TYPES(s):
CATEGORY = "conditioning/controlnet"
DEPRECATED = True

class SkipLayerGuidanceSD3:
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"

CATEGORY = "advanced/guidance"


def skip_guidance(self, model, layers, scale, start_percent, end_percent):
if layers == "" or layers == None:
return (model, )
# check if layer is comma separated integers
def skip(args, extra_args):
return args

model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)

def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()

for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent)

sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result

layers = re.findall(r'\d+', layers)
layers = [int(i) for i in layers]
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)

return (m, )


NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
"EmptySD3LatentImage": EmptySD3LatentImage,
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand Down

0 comments on commit 1d85a2a

Please sign in to comment.