Skip to content

Commit

Permalink
improved vram managment for encode_batch_size
Browse files Browse the repository at this point in the history
  • Loading branch information
cubiq committed May 21, 2024
1 parent 20125bf commit d33265a
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 593 deletions.
198 changes: 0 additions & 198 deletions CrossAttentionPatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,201 +188,3 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No
#out = out + out_ip

return out_ip.to(dtype=dtype)

"""
class CrossAttentionPatch:
# forward for patching
def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
self.weights = [weight]
self.ipadapters = [ipadapter]
self.conds = [cond]
self.conds_alt = [cond_alt]
self.unconds = [uncond]
self.weight_types = [weight_type]
self.masks = [mask]
self.sigma_starts = [sigma_start]
self.sigma_ends = [sigma_end]
self.unfold_batch = [unfold_batch]
self.embeds_scaling = [embeds_scaling]
self.number = number
self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 # TODO: check if this is a valid condition to detect all models
self.k_key = str(self.number*2+1) + "_to_k_ip"
self.v_key = str(self.number*2+1) + "_to_v_ip"
def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
self.weights.append(weight)
self.ipadapters.append(ipadapter)
self.conds.append(cond)
self.conds_alt.append(cond_alt)
self.unconds.append(uncond)
self.weight_types.append(weight_type)
self.masks.append(mask)
self.sigma_starts.append(sigma_start)
self.sigma_ends.append(sigma_end)
self.unfold_batch.append(unfold_batch)
self.embeds_scaling.append(embeds_scaling)
def __call__(self, q, k, v, extra_options):
dtype = q.dtype
cond_or_uncond = extra_options["cond_or_uncond"]
sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9
block_type = extra_options["block"][0]
#block_id = extra_options["block"][1]
t_idx = extra_options["transformer_index"]
# extra options for AnimateDiff
ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None
b = q.shape[0]
seq_len = q.shape[1]
batch_prompt = b // len(cond_or_uncond)
out = optimized_attention(q, k, v, extra_options["n_heads"])
_, _, oh, ow = extra_options["original_shape"]
for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.embeds_scaling):
if sigma <= sigma_start and sigma >= sigma_end:
if weight_type == 'ease in':
weight = weight * (0.05 + 0.95 * (1 - t_idx / self.layers))
elif weight_type == 'ease out':
weight = weight * (0.05 + 0.95 * (t_idx / self.layers))
elif weight_type == 'ease in-out':
weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'reverse in-out':
weight = weight * (0.05 + 0.95 * (abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'weak input' and block_type == 'input':
weight = weight * 0.2
elif weight_type == 'weak middle' and block_type == 'middle':
weight = weight * 0.2
elif weight_type == 'weak output' and block_type == 'output':
weight = weight * 0.2
elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'):
weight = weight * 0.2
elif isinstance(weight, dict):
if t_idx not in weight:
continue
weight = weight[t_idx]
if cond_alt is not None and t_idx in cond_alt:
cond = cond_alt[t_idx]
del cond_alt
if unfold_batch:
# Check AnimateDiff context window
if ad_params is not None and ad_params["sub_idxs"] is not None:
if isinstance(weight, torch.Tensor):
weight = tensor_to_size(weight, ad_params["full_length"])
weight = torch.Tensor(weight[ad_params["sub_idxs"]])
if torch.all(weight == 0):
continue
weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
elif weight == 0:
continue
# if image length matches or exceeds full_length get sub_idx images
if cond.shape[0] >= ad_params["full_length"]:
cond = torch.Tensor(cond[ad_params["sub_idxs"]])
uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
# otherwise get sub_idxs images
else:
cond = tensor_to_size(cond, ad_params["full_length"])
uncond = tensor_to_size(uncond, ad_params["full_length"])
cond = cond[ad_params["sub_idxs"]]
uncond = uncond[ad_params["sub_idxs"]]
else:
if isinstance(weight, torch.Tensor):
weight = tensor_to_size(weight, batch_prompt)
if torch.all(weight == 0):
continue
weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
elif weight == 0:
continue
cond = tensor_to_size(cond, batch_prompt)
uncond = tensor_to_size(uncond, batch_prompt)
k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond)
k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond)
v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond)
v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond)
else:
# TODO: should we always convert the weights to a tensor?
if isinstance(weight, torch.Tensor):
weight = tensor_to_size(weight, batch_prompt)
if torch.all(weight == 0):
continue
weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
elif weight == 0:
continue
k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)
ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
if embeds_scaling == 'K+mean(V) w/ C penalty':
scaling = float(ip_k.shape[2]) / 1280.0
weight = weight * scaling
ip_k = ip_k * weight
ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
del ip_v_mean
elif embeds_scaling == 'K+V w/ C penalty':
scaling = float(ip_k.shape[2]) / 1280.0
weight = weight * scaling
ip_k = ip_k * weight
ip_v = ip_v * weight
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
elif embeds_scaling == 'K+V':
ip_k = ip_k * weight
ip_v = ip_v * weight
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
else:
#ip_v = ip_v * weight
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
out_ip = out_ip * weight # I'm doing this to get the same results as before
if mask is not None:
mask_h = oh / math.sqrt(oh * ow / seq_len)
mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
mask_w = seq_len // mask_h
# check if using AnimateDiff and sliding context window
if (mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None):
# if mask length matches or exceeds full_length, get sub_idx masks
if mask.shape[0] >= ad_params["full_length"]:
mask = torch.Tensor(mask[ad_params["sub_idxs"]])
mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
else:
mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
mask = tensor_to_size(mask, ad_params["full_length"])
mask = mask[ad_params["sub_idxs"]]
else:
mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1)
mask = tensor_to_size(mask, batch_prompt)
mask = mask.repeat(len(cond_or_uncond), 1, 1)
mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])
# covers cases where extreme aspect ratios can cause the mask to have a wrong size
mask_len = mask_h * mask_w
if mask_len < seq_len:
pad_len = seq_len - mask_len
pad1 = pad_len // 2
pad2 = pad_len - pad1
mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0)
elif mask_len > seq_len:
crop_start = (mask_len - seq_len) // 2
mask = mask[:, crop_start:crop_start+seq_len, :]
out_ip = out_ip * mask
out = out + out_ip
return out.to(dtype=dtype)
"""
65 changes: 56 additions & 9 deletions IPAdapterPlus.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,60 @@ def init_proj_faceid(self):
return image_proj_model

@torch.inference_mode()
def get_image_embeds(self, clip_embed, clip_embed_zeroed):
image_prompt_embeds = self.image_proj_model(clip_embed)
uncond_image_prompt_embeds = self.image_proj_model(clip_embed_zeroed)
def get_image_embeds(self, clip_embed, clip_embed_zeroed, batch_size):
torch_device = model_management.get_torch_device()
intermediate_device = model_management.intermediate_device()

if batch_size == 0:
batch_size = clip_embed.shape[0]
intermediate_device = torch_device
elif batch_size > clip_embed.shape[0]:
batch_size = clip_embed.shape[0]

clip_embed = torch.split(clip_embed, batch_size, dim=0)
clip_embed_zeroed = torch.split(clip_embed_zeroed, batch_size, dim=0)

image_prompt_embeds = []
uncond_image_prompt_embeds = []

for ce, cez in zip(clip_embed, clip_embed_zeroed):
image_prompt_embeds.append(self.image_proj_model(ce.to(torch_device)).to(intermediate_device))
uncond_image_prompt_embeds.append(self.image_proj_model(cez.to(torch_device)).to(intermediate_device))

del clip_embed, clip_embed_zeroed

image_prompt_embeds = torch.cat(image_prompt_embeds, dim=0)
uncond_image_prompt_embeds = torch.cat(uncond_image_prompt_embeds, dim=0)

torch.cuda.empty_cache()

#image_prompt_embeds = self.image_proj_model(clip_embed)
#uncond_image_prompt_embeds = self.image_proj_model(clip_embed_zeroed)
return image_prompt_embeds, uncond_image_prompt_embeds

@torch.inference_mode()
def get_image_embeds_faceid_plus(self, face_embed, clip_embed, s_scale, shortcut):
embeds = self.image_proj_model(face_embed, clip_embed, scale=s_scale, shortcut=shortcut)
def get_image_embeds_faceid_plus(self, face_embed, clip_embed, s_scale, shortcut, batch_size):
torch_device = model_management.get_torch_device()
intermediate_device = model_management.intermediate_device()

if batch_size == 0:
batch_size = clip_embed.shape[0]
intermediate_device = torch_device
elif batch_size > clip_embed.shape[0]:
batch_size = clip_embed.shape[0]

face_embed_batch = torch.split(face_embed, batch_size, dim=0)
clip_embed_batch = torch.split(clip_embed, batch_size, dim=0)

embeds = []
for face_embed, clip_embed in zip(face_embed_batch, clip_embed_batch):
embeds.append(self.image_proj_model(face_embed.to(torch_device), clip_embed.to(torch_device), scale=s_scale, shortcut=shortcut).to(intermediate_device))

del face_embed_batch, clip_embed_batch

embeds = torch.cat(embeds, dim=0)
torch.cuda.empty_cache()
#embeds = self.image_proj_model(face_embed, clip_embed, scale=s_scale, shortcut=shortcut)
return embeds

class To_KV(nn.Module):
Expand Down Expand Up @@ -351,16 +397,17 @@ def ipadapter_execute(model,
).to(device, dtype=dtype)

if is_faceid and is_plus:
cond = ipa.get_image_embeds_faceid_plus(face_cond_embeds, img_cond_embeds, weight_faceidv2, is_faceidv2)
cond = ipa.get_image_embeds_faceid_plus(face_cond_embeds, img_cond_embeds, weight_faceidv2, is_faceidv2, encode_batch_size)
# TODO: check if noise helps with the uncond face embeds
uncond = ipa.get_image_embeds_faceid_plus(torch.zeros_like(face_cond_embeds), img_uncond_embeds, weight_faceidv2, is_faceidv2)
uncond = ipa.get_image_embeds_faceid_plus(torch.zeros_like(face_cond_embeds), img_uncond_embeds, weight_faceidv2, is_faceidv2, encode_batch_size)
else:
cond, uncond = ipa.get_image_embeds(img_cond_embeds, img_uncond_embeds)
cond, uncond = ipa.get_image_embeds(img_cond_embeds, img_uncond_embeds, encode_batch_size)
if img_comp_cond_embeds is not None:
cond_comp = ipa.get_image_embeds(img_comp_cond_embeds, img_uncond_embeds)[0]
cond_comp = ipa.get_image_embeds(img_comp_cond_embeds, img_uncond_embeds, encode_batch_size)[0]

cond = cond.to(device, dtype=dtype)
uncond = uncond.to(device, dtype=dtype)

cond_alt = None
if img_comp_cond_embeds is not None:
cond_alt = { 3: cond_comp.to(device, dtype=dtype) }
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Please consider a [Github Sponsorship](https://github.com/sponsors/cubiq) or [Pa

## Important updates

**2024/05/21**: Improved memory allocation when `encode_batch_size`. Useful mostly for very long animations.

**2024/05/02**: Add `encode_batch_size` to the Advanced batch node. This can be useful for animations with a lot of frames to reduce the VRAM usage during the image encoding. Please note that results will be slightly different based on the batch size.

**2024/04/27**: Refactored the IPAdapterWeights mostly useful for AnimateDiff animations.
Expand Down
Loading

0 comments on commit d33265a

Please sign in to comment.