Skip to content

Commit

Permalink
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
Browse files (browse the repository at this point in the history)
  • Loading branch information
sdbds committed Sep 28, 2024
2 parents 904c85e + 822fe57 commit b81377c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
5 changes: 3 additions & 2 deletions library/flux_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,8 +999,9 @@ def get_unit_index(self, is_double: bool, index: int):

def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu
if self.blocks_to_swap is None:
raise ValueError("Block swap is not enabled.")
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
Expand Down
4 changes: 3 additions & 1 deletion library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
pred = model(
img=img,
img_ids=img_ids,
Expand All @@ -325,7 +326,8 @@ def denoise(
)

img = img + (t_prev - t_curr) * pred


model.prepare_block_swap_before_forward()
return img


Expand Down
4 changes: 4 additions & 0 deletions networks/convert_flux_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ def main(args):
state_dict = convert_ai_toolkit_to_sd_scripts(state_dict)
elif args.src == "sd-scripts" and args.dst == "ai-toolkit":
state_dict = convert_sd_scripts_to_ai_toolkit(state_dict)

# eliminate 'shared tensors'
for k in list(state_dict.keys()):
state_dict[k] = state_dict[k].detach().clone()
else:
raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported")

Expand Down

0 comments on commit b81377c

Please sign in to comment.