diff --git a/networks/lora.py b/networks/lora.py index f7e935952..cd70cf3f8 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -247,14 +247,13 @@ def get_mask_for_x(self, x): area = x.size()[1] mask = self.network.mask_dic.get(area, None) - if mask is None: - # raise ValueError(f"mask is None for resolution {area}") + if mask is None or len(x.size()) == 2: # emb_layers in SDXL doesn't have mask # if "emb" not in self.lora_name: # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts - if len(x.size()) != 4: + if len(x.size()) == 3: mask = torch.reshape(mask, (1, -1, 1)) return mask