diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py index bab152f88..00026d2cc 100644 --- a/sdxl_train_controlnet.py +++ b/sdxl_train_controlnet.py @@ -528,7 +528,7 @@ def remove_model(old_ckpt_name): # orig_size, crop_size, target_size, accelerator.device # ).to(weight_dtype) - embs = torch.cat([orig_size, crop_size, target_size], dim=-1).to(accelerator.device).to(weight_dtype) + embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 # concat embeddings #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) vector_embedding_dict = {