From a9aa52658a0d9ba7910a1d1983b650bc9de7153e Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 28 Sep 2024 17:12:56 +0900
Subject: [PATCH 1/2] fix sample generation is not working in FLUX1 fine tuning #1647

---
 library/flux_models.py      | 5 +++--
 library/flux_train_utils.py | 4 +++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/library/flux_models.py b/library/flux_models.py
index a35dbc106..0bc1c02b9 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -999,8 +999,9 @@ def get_unit_index(self, is_double: bool, index: int):

     def prepare_block_swap_before_forward(self):
         # make: first n blocks are on cuda, and last n blocks are on cpu
-        if self.blocks_to_swap is None:
-            raise ValueError("Block swap is not enabled.")
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            # raise ValueError("Block swap is not enabled.")
+            return
         for i in range(self.num_block_units - self.blocks_to_swap):
             for b in self.get_block_unit(i):
                 b.to(self.device)
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index f77d4b585..1d1eb9d24 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -313,6 +313,7 @@ def denoise(
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        model.prepare_block_swap_before_forward()
         pred = model(
             img=img,
             img_ids=img_ids,
@@ -325,7 +326,8 @@
         )

         img = img + (t_prev - t_curr) * pred
-
+
+    model.prepare_block_swap_before_forward()
     return img


From 822fe578591e44ac949830e03a8841e222483052 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 28 Sep 2024 20:57:27 +0900
Subject: [PATCH 2/2] add workaround for 'Some tensors share memory' error #1614

---
 networks/convert_flux_lora.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py
index bd4c1cf78..fe6466ebc 100644
--- a/networks/convert_flux_lora.py
+++ b/networks/convert_flux_lora.py
@@ -412,6 +412,10 @@ def main(args):
         state_dict = convert_ai_toolkit_to_sd_scripts(state_dict)
     elif args.src == "sd-scripts" and args.dst == "ai-toolkit":
         state_dict = convert_sd_scripts_to_ai_toolkit(state_dict)
+
+        # eliminate 'shared tensors'
+        for k in list(state_dict.keys()):
+            state_dict[k] = state_dict[k].detach().clone()
     else:
         raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported")
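
A note on the second patch (not part of the patch series itself): the 'Some tensors share memory' error reported in #1614 comes from the safetensors library, whose save_file refuses to serialize a state dict in which two keys point at the same underlying storage. The minimal sketch below, in which the key names and output path are hypothetical, reproduces the error and applies the same detach().clone() workaround as the patch:

    import torch
    from safetensors.torch import save_file

    w = torch.randn(4, 4)
    state_dict = {
        "lora_down.weight": w,
        "lora_up.weight": w,  # same tensor under a second key -> shared storage
    }

    # save_file(state_dict, "lora.safetensors") would raise:
    # RuntimeError: Some tensors share memory, ...

    # same workaround as the patch: give each entry its own storage
    for k in list(state_dict.keys()):
        state_dict[k] = state_dict[k].detach().clone()

    save_file(state_dict, "lora.safetensors")  # now succeeds

Cloning briefly holds a second copy of each tensor in memory, but for LoRA-sized state dicts the overhead is negligible, and afterwards every saved tensor owns its own storage.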