From 81c0c965a24ce4f0f86dfa980f803d7616ca46d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 21:22:42 +0900 Subject: [PATCH 1/9] faster block swap --- flux_train.py | 107 ++++++++++---------- library/flux_models.py | 138 ++++++++++++++----------- library/utils.py | 222 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 352 insertions(+), 115 deletions(-) diff --git a/flux_train.py b/flux_train.py index 79c44d7b4..afddc897f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -17,12 +17,14 @@ import os from multiprocessing import Value import time -from typing import List +from typing import List, Optional, Tuple, Union import toml from tqdm import tqdm import torch +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -466,45 +468,28 @@ def train(args): # memory efficient block swapping - def get_block_unit(dbl_blocks, sgl_blocks, index: int): - if index < len(dbl_blocks): - return (dbl_blocks[index],) - else: - index -= len(dbl_blocks) - index *= 2 - return (sgl_blocks[index], sgl_blocks[index + 1]) - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - for block in blocks_to_cpu: - block = block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - for block in blocks_to_cuda: - block = block.to(dvc, non_blocking=True) - - torch.cuda.synchronize() - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) - blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device - ) + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # start_time = time.perf_counter() + # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: + def wait_blocks_move(block_id, futures): + if block_id not in futures: return - # print(f"Backward: Wait for block {block_idx}") + # print(f"Backward: Wait for block {block_id}") # start_time = time.perf_counter() - future = futures.pop(block_idx) - future.result() - # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") - # torch.cuda.synchronize() + future = futures.pop(block_id) + _, bidx_to_cuda = future.result() + assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" + # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") if args.fused_backward_pass: @@ -513,11 +498,11 @@ def wait_blocks_move(block_idx, futures): library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - handled_unit_indices = set() + handled_block_ids = set() n = 1 # only asynchronous purpose, no need to increase this number # n = 2 @@ -530,28 +515,37 @@ def wait_blocks_move(block_idx, futures): if parameter.requires_grad: grad_hook = None - if blocks_to_swap: + if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: is_double = param_name.startswith("double_blocks") is_single = param_name.startswith("single_blocks") - if is_double or is_single: + if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: block_idx = int(param_name.split(".")[1]) - unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 - if unit_idx not in handled_unit_indices: + block_id = (is_double, block_idx) # double or single, block index + if block_id not in handled_block_ids: # swap following (already backpropagated) block - handled_unit_indices.add(unit_idx) + handled_block_ids.add(block_id) # if n blocks were already backpropagated - num_blocks_propagated = num_block_units - unit_idx - 1 + if is_double: + num_blocks = num_double_blocks + blocks_to_swap = double_blocks_to_swap + else: + num_blocks = num_single_blocks + blocks_to_swap = single_blocks_to_swap + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = num_blocks - block_idx - 2 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= blocks_to_swap + if swapping or waiting: - block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cpu = num_blocks - num_blocks_propagated block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = unit_idx - 1 + block_idx_to_wait = block_idx - 1 # create swap hook def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool ): def __grad_hook(tensor: torch.Tensor): if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -559,24 +553,25 @@ def __grad_hook(tensor: torch.Tensor): optimizer.step_param(tensor, param_group) tensor.grad = None - # print(f"Backward: {uidx}, {swpng}, {wtng}") + # print( + # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" + # ) if swpng: submit_move_blocks( futures, thread_pool, bidx_to_cpu, bidx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, + flux.double_blocks if is_dbl else flux.single_blocks, + (is_dbl, bidx_to_cuda), # wait for this block ) if wtng: - wait_blocks_move(bidx_to_wait, futures) + wait_blocks_move((is_dbl, bidx_to_wait), futures) return __grad_hook grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting ) if grad_hook is None: diff --git a/library/flux_models.py b/library/flux_models.py index 0bc1c02b9..48dea4fc9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -7,8 +7,9 @@ import math import os import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -923,7 +924,8 @@ def __init__(self, params: FluxParams): self.blocks_to_swap = None self.thread_pool: Optional[ThreadPoolExecutor] = None - self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) @property def device(self): @@ -963,14 +965,17 @@ def disable_gradient_checkpointing(self): def enable_block_swap(self, num_blocks: int): self.blocks_to_swap = num_blocks + self.double_blocks_to_swap = num_blocks // 2 + self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + ) n = 1 # async block swap. 1 is enough - # n = 2 - # n = max(1, os.cpu_count() // 2) self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_double_blocks = self.double_blocks save_single_blocks = self.single_blocks @@ -983,31 +988,55 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - def get_block_unit(self, index: int): - if index < len(self.double_blocks): - return (self.double_blocks[index],) - else: - index -= len(self.double_blocks) - index *= 2 - return self.single_blocks[index], self.single_blocks[index + 1] + # def get_block_unit(self, index: int): + # if index < len(self.double_blocks): + # return (self.double_blocks[index],) + # else: + # index -= len(self.double_blocks) + # index *= 2 + # return self.single_blocks[index], self.single_blocks[index + 1] - def get_unit_index(self, is_double: bool, index: int): - if is_double: - return index - else: - return len(self.double_blocks) + index // 2 + # def get_unit_index(self, is_double: bool, index: int): + # if is_double: + # return index + # else: + # return len(self.double_blocks) + index // 2 def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu + # # make: first n blocks are on cuda, and last n blocks are on cpu + # if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # # raise ValueError("Block swap is not enabled.") + # return + # for i in range(self.num_block_units - self.blocks_to_swap): + # for b in self.get_block_unit(i): + # b.to(self.device) + # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + # for b in self.get_block_unit(i): + # b.to("cpu") + # clean_memory_on_device(self.device) + + # all blocks are on device, but some weights are on cpu + # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: # raise ValueError("Block swap is not enabled.") return - for i in range(self.num_block_units - self.blocks_to_swap): - for b in self.get_block_unit(i): - b.to(self.device) - for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - for b in self.get_block_unit(i): - b.to("cpu") + + for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() + clean_memory_on_device(self.device) + + for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() clean_memory_on_device(self.device) def forward( @@ -1044,27 +1073,22 @@ def forward( for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - futures = {} - - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - for block in blocks_to_cpu: - block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() + # device = self.device - # print(f"Moving {bidx_to_cuda} to cuda.") - for block in blocks_to_cuda: - block.to(self.device, non_blocking=True) - - torch.cuda.synchronize() + def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + start_time = time.perf_counter() + # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) - blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + return block_idx_to_cpu, block_idx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) def wait_for_blocks_move(block_idx, ftrs): if block_idx not in ftrs: @@ -1073,37 +1097,35 @@ def wait_for_blocks_move(block_idx, ftrs): # start_time = time.perf_counter() ftr = ftrs.pop(block_idx) ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") + double_futures = {} for block_idx, block in enumerate(self.double_blocks): # print(f"Double block {block_idx}") - unit_idx = self.get_unit_index(is_double=True, index=block_idx) - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, double_futures) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.double_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx + future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) + double_futures[block_idx_to_cuda] = future img = torch.cat((txt, img), 1) + single_futures = {} for block_idx, block in enumerate(self.single_blocks): # print(f"Single block {block_idx}") - unit_idx = self.get_unit_index(is_double=False, index=block_idx) - if block_idx % 2 == 0: - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, single_futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.single_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) + single_futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] diff --git a/library/utils.py b/library/utils.py index ca0f904d2..aed510074 100644 --- a/library/utils.py +++ b/library/utils.py @@ -6,6 +6,7 @@ import struct import torch +import torch.nn as nn from torchvision import transforms from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete @@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils +# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") +# # cpu_tensor = module_to_cuda.weight.data +# # cuda_tensor = module_to_cpu.weight.data +# # assert cuda_tensor.device.type == "cuda" +# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) +# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor +# cuda_tensor_view = module_to_cpu.weight.data +# cpu_tensor_view = module_to_cuda.weight.data +# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() +# module_to_cuda.weight.data = cuda_tensor_view +# module_to_cuda.weight.data.copy_(cpu_tensor_view) + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + events = [] + with torch.cuda.stream(stream_to_cpu): + # cuda to offload + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + with torch.cuda.stream(stream_to_cuda): + # cpu to cuda + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): + event.synchronize() + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( + weight_swap_jobs, offloaded_weights + ): + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + # cuda to offload + events = [] + with torch.cuda.stream(stream_to_cpu): + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream_to_cpu) + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + # cpu to cuda + with torch.cuda.stream(stream_to_cuda): + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( + weight_swap_jobs, events, offloaded_weights + ): + event.synchronize() + cuda_data_view.record_stream(stream_to_cuda) + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + # torch.cuda.current_stream().wait_stream(stream_to_cuda) + # for job in weight_swap_jobs: + # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor + + +def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): + # one of the modules must have the tensor to offload + module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.offloaded_weight.pin_memory() + offloaded_weight = ( + module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight + ) + assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" + weight_swap_jobs.append( + (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) + ) + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to offload + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.record_stream(stream) + offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + module_to_cpu.weight.data = offloaded_weight + offloaded_weight = cpu_data_view + module_to_cpu.offloaded_weight = offloaded_weight + module_to_cuda.offloaded_weight = offloaded_weight + + stream.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): + # one of the modules must have the tensor to cache + module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.__cached_cpu_weight.pin_memory() + + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) + module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) + + torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish + torch.cuda.empty_cache() + + +# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" +# weight_on_cuda = module_to_cpu.weight +# weight_on_cpu = module_to_cuda.weight +# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) +# event = torch.cuda.current_stream().record_event() +# event.synchronize() +# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) +# weight_on_cpu.data = cuda_to_cpu_data +# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad + +# module_to_cpu.weight = weight_on_cpu +# module_to_cuda.weight = weight_on_cuda + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ @@ -313,6 +533,7 @@ def _convert_float8(byte_tensor, dtype_str, shape): # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") + def load_safetensors( path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 ) -> dict[str, torch.Tensor]: @@ -336,7 +557,6 @@ def load_safetensors( return state_dict - # endregion # region Image utils From aab943cea3eb8a91041c857771f1642581133608 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 23:27:41 +0900 Subject: [PATCH 2/9] remove unused weight swapping functions from utils.py --- library/utils.py | 185 ----------------------------------------------- 1 file changed, 185 deletions(-) diff --git a/library/utils.py b/library/utils.py index aed510074..07079c6d9 100644 --- a/library/utils.py +++ b/library/utils.py @@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils -# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") -# # cpu_tensor = module_to_cuda.weight.data -# # cuda_tensor = module_to_cpu.weight.data -# # assert cuda_tensor.device.type == "cuda" -# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) -# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor -# cuda_tensor_view = module_to_cpu.weight.data -# cpu_tensor_view = module_to_cuda.weight.data -# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() -# module_to_cuda.weight.data = cuda_tensor_view -# module_to_cuda.weight.data.copy_(cpu_tensor_view) - def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ @@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): torch.cuda.current_stream().synchronize() # this prevents the illegal loss value -def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - events = [] - with torch.cuda.stream(stream_to_cpu): - # cuda to offload - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - with torch.cuda.stream(stream_to_cuda): - # cpu to cuda - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): - event.synchronize() - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( - weight_swap_jobs, offloaded_weights - ): - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - # cuda to offload - events = [] - with torch.cuda.stream(stream_to_cpu): - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - cuda_data_view.record_stream(stream_to_cpu) - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - # cpu to cuda - with torch.cuda.stream(stream_to_cuda): - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( - weight_swap_jobs, events, offloaded_weights - ): - event.synchronize() - cuda_data_view.record_stream(stream_to_cuda) - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - # torch.cuda.current_stream().wait_stream(stream_to_cuda) - # for job in weight_swap_jobs: - # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor - - -def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): - # one of the modules must have the tensor to offload - module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.offloaded_weight.pin_memory() - offloaded_weight = ( - module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight - ) - assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" - weight_swap_jobs.append( - (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) - ) - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # cuda to offload - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.record_stream(stream) - offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) - - stream.synchronize() - - # cpu to cuda - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - module_to_cpu.weight.data = offloaded_weight - offloaded_weight = cpu_data_view - module_to_cpu.offloaded_weight = offloaded_weight - module_to_cuda.offloaded_weight = offloaded_weight - - stream.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): - # one of the modules must have the tensor to cache - module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.__cached_cpu_weight.pin_memory() - - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: - module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) - module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) - - torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish - torch.cuda.empty_cache() - - -# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" -# weight_on_cuda = module_to_cpu.weight -# weight_on_cpu = module_to_cuda.weight -# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) -# event = torch.cuda.current_stream().record_event() -# event.synchronize() -# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) -# weight_on_cpu.data = cuda_to_cpu_data -# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad - -# module_to_cpu.weight = weight_on_cpu -# module_to_cuda.weight = weight_on_cuda - - def weighs_to_device(layer: nn.Module, device: torch.device): for module in layer.modules(): if hasattr(module, "weight") and module.weight is not None: From 186aa5b97d43700706bd8e986e2d5ac3f5d4c9b7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 22:16:05 +0900 Subject: [PATCH 3/9] fix illeagal block is swapped #1764 --- library/flux_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 48dea4fc9..4721fa02e 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1077,7 +1077,7 @@ def forward( def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - start_time = time.perf_counter() + # start_time = time.perf_counter() # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") @@ -1123,7 +1123,7 @@ def wait_for_blocks_move(block_idx, ftrs): if block_idx < self.single_blocks_to_swap: block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) single_futures[block_idx_to_cuda] = future From 02bd76e6c719ad85c108a177405846c5c958bd78 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 11 Nov 2024 21:15:36 +0900 Subject: [PATCH 4/9] Refactor block swapping to utilize custom offloading utilities --- flux_train.py | 228 ++++++++--------------------- library/custom_offloading_utils.py | 216 +++++++++++++++++++++++++++ library/flux_models.py | 113 ++------------ 3 files changed, 295 insertions(+), 262 deletions(-) create mode 100644 library/custom_offloading_utils.py diff --git a/flux_train.py b/flux_train.py index afddc897f..02dede45e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -295,7 +295,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # load VAE here if not cached @@ -338,15 +338,15 @@ def train(args): # determine target layer and block index for each parameter block_type = "other" # double, single or other if np[0].startswith("double_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "double" elif np[0].startswith("single_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "single" else: - block_idx = -1 + block_index = -1 - param_group_key = (block_type, block_idx) + param_group_key = (block_type, block_index) if param_group_key not in param_group: param_group[param_group_key] = [] param_group[param_group_key].append(p) @@ -466,123 +466,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") - return bidx_to_cpu, bidx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_blocks_move(block_id, futures): - if block_id not in futures: - return - # print(f"Backward: Wait for block {block_id}") - # start_time = time.perf_counter() - future = futures.pop(block_id) - _, bidx_to_cuda = future.result() - assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" - # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") - # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - handled_block_ids = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - - if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: - block_idx = int(param_name.split(".")[1]) - block_id = (is_double, block_idx) # double or single, block index - if block_id not in handled_block_ids: - # swap following (already backpropagated) block - handled_block_ids.add(block_id) - - # if n blocks were already backpropagated - if is_double: - num_blocks = num_double_blocks - blocks_to_swap = double_blocks_to_swap - else: - num_blocks = num_single_blocks - blocks_to_swap = single_blocks_to_swap - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = num_blocks - block_idx - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # print( - # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" - # ) - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - flux.double_blocks if is_dbl else flux.single_blocks, - (is_dbl, bidx_to_cuda), # wait for this block - ) - if wtng: - wait_blocks_move((is_dbl, bidx_to_wait), futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -601,66 +499,66 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - # swap blocks if necessary - if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): - unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 - num_blocks_propagated = num_block_units - unit_idx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_block_units - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = unit_idx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 + # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook + if is_swapping_blocks: + import library.custom_offloading_utils as custom_offloading_utils + + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 + + offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) + offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) + + param_name_pairs = [] + if not args.blockwise_fused_optimizers: + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + param_name_pairs.extend(zip(param_group["params"], param_name_group)) + else: + # named_parameters is a list of (name, parameter) pairs + param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) + + for parameter, param_name in param_name_pairs: + if not parameter.requires_grad: + continue + + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if not is_double and not is_single: + continue + + block_index = int(param_name.split(".")[1]) + if is_double: + blocks = flux.double_blocks + offloader = offloader_double + else: + blocks = flux.single_blocks + offloader = offloader_single + + grad_hook = offloader.create_grad_hook(blocks, block_index) + if grad_hook is not None: + parameter.register_post_accumulate_grad_hook(grad_hook) + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 000000000..33a413004 --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,216 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from library.device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices(block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class TrainOffloader(Offloader): + """ + supports backward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + self.hook_added = set() + + def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + if block_index in self.hook_added: + return None + self.hook_added.add(block_index) + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = self.num_blocks - block_index - 2 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + if self.debug: + print( + f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" + ) + if swapping: + + def grad_hook(tensor: torch.Tensor): + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + return grad_hook + + else: + + def grad_hook(tensor: torch.Tensor): + self._wait_blocks_move(block_idx_to_wait) + + return grad_hook + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/flux_models.py b/library/flux_models.py index 4721fa02e..e0bee160f 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint +from library import custom_offloading_utils # USE_REENTRANT = True @@ -923,7 +924,8 @@ def __init__(self, params: FluxParams): self.cpu_offload_checkpointing = False self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader_double = None + self.offloader_single = None self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) @@ -963,17 +965,17 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - self.double_blocks_to_swap = num_blocks // 2 - self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) + self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) print( - f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) - def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: @@ -988,56 +990,11 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - # def get_block_unit(self, index: int): - # if index < len(self.double_blocks): - # return (self.double_blocks[index],) - # else: - # index -= len(self.double_blocks) - # index *= 2 - # return self.single_blocks[index], self.single_blocks[index + 1] - - # def get_unit_index(self, is_double: bool, index: int): - # if is_double: - # return index - # else: - # return len(self.double_blocks) + index // 2 - def prepare_block_swap_before_forward(self): - # # make: first n blocks are on cuda, and last n blocks are on cpu - # if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # # raise ValueError("Block swap is not enabled.") - # return - # for i in range(self.num_block_units - self.blocks_to_swap): - # for b in self.get_block_unit(i): - # b.to(self.device) - # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - # for b in self.get_block_unit(i): - # b.to("cpu") - # clean_memory_on_device(self.device) - - # all blocks are on device, but some weights are on cpu - # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - - for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) - - for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) def forward( self, @@ -1073,59 +1030,21 @@ def forward( for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # device = self.device - - def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - return block_idx_to_cpu, block_idx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") - - double_futures = {} for block_idx, block in enumerate(self.double_blocks): - # print(f"Double block {block_idx}") - wait_for_blocks_move(block_idx, double_futures) + self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.double_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx - future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) - double_futures[block_idx_to_cuda] = future + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) img = torch.cat((txt, img), 1) - single_futures = {} for block_idx, block in enumerate(self.single_blocks): - # print(f"Single block {block_idx}") - wait_for_blocks_move(block_idx, single_futures) + self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.single_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx - future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) - single_futures[block_idx_to_cuda] = future + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) img = img[:, txt.shape[1] :, ...] From cde90b8903870b6b28dae274d07ed27978055e3c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 08:49:05 +0900 Subject: [PATCH 5/9] feat: implement block swapping for FLUX.1 LoRA (WIP) --- flux_train.py | 2 +- flux_train_network.py | 33 ++++++++++++++++++++++++ library/custom_offloading_utils.py | 40 +++++++++++++++++++++++++++++- library/flux_models.py | 8 ++++-- train_network.py | 9 ++++++- 5 files changed, 87 insertions(+), 5 deletions(-) diff --git a/flux_train.py b/flux_train.py index 02dede45e..346fe8fbd 100644 --- a/flux_train.py +++ b/flux_train.py @@ -519,7 +519,7 @@ def grad_hook(parameter: torch.Tensor): num_parameters_per_group[opt_idx] += 1 # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if is_swapping_blocks: + if False: # is_swapping_blocks: import library.custom_offloading_utils as custom_offloading_utils num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) diff --git a/flux_train_network.py b/flux_train_network.py index 2b71a8979..376cc1597 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -25,6 +25,7 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -78,6 +79,12 @@ def load_target_model(self, args, weight_dtype, accelerator): if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() @@ -285,6 +292,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) if not args.split_mode: + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) @@ -539,6 +548,19 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + + return flux + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 33a413004..70da93902 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -183,9 +183,47 @@ class ModelOffloader(Offloader): supports forward offloading """ - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): super().__init__(num_blocks, blocks_to_swap, device, debug) + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return diff --git a/library/flux_models.py b/library/flux_models.py index e0bee160f..4fa272522 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,8 +970,12 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) - self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) diff --git a/train_network.py b/train_network.py index b90aa420e..d70f14ad3 100644 --- a/train_network.py +++ b/train_network.py @@ -18,6 +18,7 @@ init_ipex() from accelerate.utils import set_seed +from accelerate import Accelerator from diffusers import DDPMScheduler from library import deepspeed_utils, model_util, strategy_base, strategy_sd @@ -272,6 +273,11 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass @@ -627,7 +633,8 @@ def train(self, args): training_model = ds_model else: if train_unet: - unet = accelerator.prepare(unet) + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: From 2cb7a6db02ae001355f4830581b9fc2ffffe01c6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 21:39:13 +0900 Subject: [PATCH 6/9] feat: add block swap for FLUX.1/SD3 LoRA training --- README.md | 212 ++++++---------------------- flux_train.py | 56 +------- flux_train_network.py | 95 +++++++------ library/custom_offloading_utils.py | 75 ++++------ library/flux_models.py | 19 ++- library/flux_train_utils.py | 48 +------ library/sd3_models.py | 71 +++------- library/sd3_train_utils.py | 49 +------ library/train_util.py | 74 +++++++++- sd3_train.py | 186 +++--------------------- sd3_train_network.py | 30 ++++ tools/cache_latents.py | 1 + tools/cache_text_encoder_outputs.py | 1 + train_network.py | 6 +- 14 files changed, 291 insertions(+), 632 deletions(-) diff --git a/README.md b/README.md index 14328607e..1e63b5830 100644 --- a/README.md +++ b/README.md @@ -14,150 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 9, 2024: +Nov 12, 2024: -- Fixed an issue where the image size could not be obtained when caching latents was enabled and a specific file name existed, causing the latent size to be incorrect. See PR [#1770](https://github.com/kohya-ss/sd-scripts/pull/1770) for details. Thanks to feffy380! - -Nov 7, 2024: - -- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - - The effect of a shift in uniform distribution is shown in the figure below. - - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) - -Oct 31, 2024: - -- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. - -Oct 19, 2024: - -- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. - - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. - - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". - - The weights of DOP is specified by `--prior_loss_weight` option (not dataset config). - - The appropriate value is still unknown. For FLUX, according to the comments in the [PR](https://github.com/kohya-ss/sd-scripts/pull/1710), the value may be 1 (thanks to dxqbYD!). For SDXL, a larger value may be needed (10-100 may be good starting points). - - It may be good to adjust the value so that the loss is about half to three-quarters of the loss when DOP is not applied. -``` -[[datasets.subsets]] -image_dir = "path/to/image/dir" -num_repeats = 1 -is_reg = true -custom_attributes.diff_output_preservation = true # Add this -``` - - -Oct 13, 2024: - -- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - -- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. - - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. - - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. - - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). - - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. -- `--skip_cache_check` option is added to each training script. - - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - - Specify this option if you have a large number of cache files and the consistency check takes time. - - Even if this option is specified, the cache will be created if the file does not exist. - - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. - -Oct 12, 2024 (update 1): - -- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. - - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. - - The model is automatically determined based on the keys in *.safetensors. - - Specifications for compact model safetensors: - - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. - - LoRA training is unverified. - - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. - -Oct 12, 2024: - -- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - - Set up multi-GPU training with `accelerate config`. - - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. - ``` - accelerate launch --rdzv_backend=c10d sdxl_train_network.py ... - ``` - - In multi-GPU training, the memory of multiple GPUs is not integrated. In other words, even if you have two 12GB VRAM GPUs, you cannot train the model that requires 24GB VRAM. Training that can be done with 12GB VRAM is executed at (up to) twice the speed. - -Oct 11, 2024: -- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. - - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). - - The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`. - - If you want to generate sample images, specify the control image as `--cn path/to/control/image`. - - The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI. -- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`. - - The default is `False`. It is same as before, and the parentheses are used as normal text. - - If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`. - -Oct 6, 2024: -- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. -- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. - -Sep 26, 2024: -The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - - -Sep 18, 2024 (update 1): -Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. - -Sep 18, 2024: - -- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. - - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - - `schedulefree` is added to the dependencies. Please update the library if necessary. - - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - - Wrapper classes are not available for now. - - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. - -Sep 16, 2024: - - Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. - -Sep 15, 2024: - -Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. - -The implementation is based on 2kpr's code. Thanks to 2kpr! - -Sep 14, 2024: -- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. -- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. - -Sep 11, 2024: -Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! - -Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. - -Sep 9, 2024: -Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. - -Sep 5, 2024 (update 1): - -Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. - -Sep 5, 2024: - -The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. -- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. -- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. - -Sep 1, 2024: -- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. - -Aug 29, 2024: -Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. +- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. +- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. +- There may be bugs due to the significant changes. Feedback is welcome. ## FLUX.1 training @@ -190,7 +51,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_flux --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name @@ -198,23 +60,39 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +When training LoRA for Text Encoder (without `--network_train_unet_only`), more VRAM is required. Please refer to the settings below to reduce VRAM usage. + +__Options for GPUs with less VRAM:__ + +By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + +Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. + +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. + +Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: ``` --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: +The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. + +The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--blocks_to_swap 16 ``` -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. +For GPUs with less than 10GB of VRAM, it is recommended to use an fp8 checkpoint for T5XXL. You can download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (please use without `scaled`). -We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. +10GB VRAM GPUs will work with 22 blocks swapped, and 8GB VRAM GPUs will work with 28 blocks swapped. -The trained LoRA model can be used with ComfyUI. +__`--split_mode` is deprecated. This option is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. If this option is specified and `--blocks_to_swap` is not specified, `--blocks_to_swap 18` is automatically enabled.__ #### Key Options for FLUX.1 LoRA training @@ -239,6 +117,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `additive`: add to noisy input - `sigma_scaled`: apply sigma scaling, same as SD3 - `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). +- `--blocks_to_swap`. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. @@ -426,9 +305,9 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). The maximum value is 35. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. This option cannot be used with `--blocks_to_swap`. All these options are experimental and may change in the future. @@ -448,13 +327,13 @@ There are two possible ways to use block swap. It is unknown which is better. 2. Swap many blocks to increase the batch size and shorten the training speed per data. - For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + For example, swapping 35 blocks seems to increase the batch size to about 5. In this case, the training speed per data will be relatively faster than 1. #### Training with <24GB VRAM GPUs Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. -T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. #### Key Features for FLUX.1 fine-tuning @@ -465,17 +344,19 @@ T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement f - Since the transfer between CPU and GPU takes time, the training will be slower. - `--blocks_to_swap` specify the number of blocks to swap. - About 640MB of memory can be saved per block. - - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. - - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. - - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. - - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. - - After the backward pass, the blocks are back to their original locations. + - (Update 1: Nov 12, 2024) + - The maximum number of blocks that can be swapped is 35. + - We are exchanging only the data of the weights (weight.data) in reference to the implementation of OneTrainer (thanks to OneTrainer). However, the mechanism of the exchange is a custom implementation. + - Since it takes time to free CUDA memory (torch.cuda.empty_cache()), we reuse the CUDA memory allocated to weight.data as it is and exchange the weights between modules. + - This shortens the time it takes to exchange weights between modules. + - Since the weights must be almost identical to be exchanged, FLUX.1 exchanges the weights between double blocks and single blocks. + - In SD3, all blocks are similar, but some weights are different, so there are weights that always remain on the GPU. 2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - - Note: It will be very slow when `--split_mode` is specified. + - Note: It will be very slow when `--blocks_to_swap` is specified. 3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). @@ -621,20 +502,19 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_tr --pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name sd3-lora-name ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +Like FLUX.1 training, the `--blocks_to_swap` option for memory reduction is available. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M. -``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 -``` +Adafactor optimizer is also available. -`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. +`--cpu_offload_checkpointing` option is not available. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. diff --git a/flux_train.py b/flux_train.py index 346fe8fbd..ad2c7722b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -78,6 +78,10 @@ def train(args): ) args.gradient_checkpointing = True + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -518,47 +522,6 @@ def grad_hook(parameter: torch.Tensor): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if False: # is_swapping_blocks: - import library.custom_offloading_utils as custom_offloading_utils - - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - - offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) - offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) - - param_name_pairs = [] - if not args.blockwise_fused_optimizers: - for param_group, param_name_group in zip(optimizer.param_groups, param_names): - param_name_pairs.extend(zip(param_group["params"], param_name_group)) - else: - # named_parameters is a list of (name, parameter) pairs - param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) - - for parameter, param_name in param_name_pairs: - if not parameter.requires_grad: - continue - - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if not is_double and not is_single: - continue - - block_index = int(param_name.split(".")[1]) - if is_double: - blocks = flux.double_blocks - offloader = offloader_double - else: - blocks = flux.single_blocks - offloader = offloader_single - - grad_hook = offloader.create_grad_hook(blocks, block_index) - if grad_hook is not None: - parameter.register_post_accumulate_grad_hook(grad_hook) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( @@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--double_blocks_to_swap", type=int, diff --git a/flux_train_network.py b/flux_train_network.py index 376cc1597..9bcd59282 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -52,10 +52,23 @@ def assert_extra_args(self, args, train_dataset_group): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") - assert not args.split_mode or not args.cpu_offload_checkpointing, ( - "split_mode and cpu_offload_checkpointing cannot be used together" - " / split_modeとcpu_offload_checkpointingは同時に使用できません" - ) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this @@ -75,9 +88,15 @@ def load_target_model(self, args, weight_dtype, accelerator): raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) - if args.split_mode: - model = self.prepare_split_model(model, weight_dtype, accelerator) + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if self.is_swapping_blocks: @@ -108,6 +127,7 @@ def load_target_model(self, args, weight_dtype, accelerator): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + """ def prepare_split_model(self, model, weight_dtype, accelerator): from accelerate import init_empty_weights @@ -144,6 +164,7 @@ def prepare_split_model(self, model, weight_dtype, accelerator): logger.info("split model prepared") return flux_lower + """ def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) @@ -291,14 +312,12 @@ def sample_images(self, accelerator, args, epoch, global_step, device, ae, token text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) - if not args.split_mode: - if self.is_swapping_blocks: - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - return + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + """ class FluxUpperLowerWrapper(torch.nn.Module): def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): super().__init__() @@ -325,6 +344,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) + """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -383,20 +403,21 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ else: # split forward to reduce memory usage assert network.train_blocks == "single", "train_blocks must be single for split mode" @@ -430,6 +451,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t vec.requires_grad_(True) pe.requires_grad_(True) model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ return model_pred @@ -558,30 +580,23 @@ def prepare_unet_with_accelerator( flux: flux_models.Flux = unet flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() return flux def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( "--split_mode", action="store_true", - help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - ) - - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 70da93902..84c2b743e 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -16,13 +16,29 @@ def synchronize_device(device: torch.device): torch.mps.synchronize() -def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) torch.cuda.current_stream().synchronize() # this prevents the illegal loss value @@ -92,7 +108,7 @@ def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, d def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): if self.cuda_available: - swap_weight_devices(block_to_cpu, block_to_cuda) + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) else: swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) @@ -132,52 +148,6 @@ def _wait_blocks_move(self, block_idx): print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") -class TrainOffloader(Offloader): - """ - supports backward offloading - """ - - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - super().__init__(num_blocks, blocks_to_swap, device, debug) - self.hook_added = set() - - def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: - if block_index in self.hook_added: - return None - self.hook_added.add(block_index) - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = self.num_blocks - block_index - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap - waiting = block_index > 0 and block_index <= self.blocks_to_swap - - if not swapping and not waiting: - return None - - # create hook - block_idx_to_cpu = self.num_blocks - num_blocks_propagated - block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_index - 1 - - if self.debug: - print( - f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" - ) - if swapping: - - def grad_hook(tensor: torch.Tensor): - self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) - - return grad_hook - - else: - - def grad_hook(tensor: torch.Tensor): - self._wait_blocks_move(block_idx_to_wait) - - return grad_hook - - class ModelOffloader(Offloader): """ supports forward offloading @@ -228,6 +198,9 @@ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return + if self.debug: + print("Prepare block devices before forward") + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: b.to(self.device) weighs_to_device(b, self.device) # make sure weights are on device diff --git a/library/flux_models.py b/library/flux_models.py index 4fa272522..fa3c7ad2b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,11 +970,16 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1061,10 +1066,11 @@ def forward( return img +""" class FluxUpper(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1168,9 +1174,9 @@ def forward( class FluxLower(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1228,3 +1234,4 @@ def forward( img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img +""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index fa673a2f0..d90644a25 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -257,14 +257,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption def time_shift(mu: float, sigma: float, t: torch.Tensor): @@ -324,7 +319,7 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + model.prepare_block_swap_before_forward() return img @@ -549,44 +544,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) parser.add_argument( "--guidance_scale", type=float, diff --git a/library/sd3_models.py b/library/sd3_models.py index 89225fe4d..8b90205db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from library import custom_offloading_utils from library.device_utils import clean_memory_on_device from .utils import setup_logging @@ -862,7 +863,8 @@ def __init__( # self.initialize_weights() self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): self.use_scaled_pos_embed = use_scaled_pos_embed @@ -1055,14 +1057,20 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b # ) return spatial_pos_embed - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_blocks = self.joint_blocks self.joint_blocks = None @@ -1073,16 +1081,9 @@ def move_to_device_except_swap_blocks(self, device: torch.device): self.joint_blocks = save_blocks def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - num_blocks = len(self.joint_blocks) - for i in range(num_blocks - self.blocks_to_swap): - self.joint_blocks[i].to(self.device) - for i in range(num_blocks - self.blocks_to_swap, num_blocks): - self.joint_blocks[i].to("cpu") - clean_memory_on_device(self.device) + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) def forward( self, @@ -1122,57 +1123,19 @@ def forward( if self.register_length > 0: context = torch.cat( - ( - einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), - default(context, torch.Tensor([]).type_as(x)), - ), - 1, + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 ) if not self.blocks_to_swap: for block in self.joint_blocks: context, x = block(context, x, c) else: - futures = {} - - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Moving {bidx_to_cuda} to cuda.") - block_to_cuda.to(self.device, non_blocking=True) - - torch.cuda.synchronize() - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - - block_to_cpu = self.joint_blocks[block_idx_to_cpu] - block_to_cuda = self.joint_blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - for block_idx, block in enumerate(self.joint_blocks): - wait_for_blocks_move(block_idx, futures) + self.offloader.wait_for_block(block_idx) context, x = block(context, x, c) - if block_idx < self.blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 38f3c25f4..c40798846 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -142,27 +142,6 @@ def sd_saver(ckpt_file, epoch_no, global_step): def add_sd3_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - parser.add_argument( "--clip_l", type=str, @@ -253,32 +232,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clarity. - parser.add_argument( - "--weighting_scheme", - type=str, - default="uniform", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], - help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", - ) - parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", - ) - parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", - ) + # Dependencies of Diffusers noise sampler has been removed for clarity in training + parser.add_argument( "--training_shift", type=float, diff --git a/library/train_util.py b/library/train_util.py index a5d6fdd21..e1dfeecdb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1887,7 +1887,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # make image path to npz path mapping npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix npz_path_index = 0 size_set_count = 0 @@ -3537,8 +3539,8 @@ def int_or_float(value): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" - + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能", ) parser.add_argument( "--lr_scheduler_timescale", @@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # Training arguments. partial copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + + def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None. # when `--log_config is enabled, filter out sensitive values from args diff --git a/sd3_train.py b/sd3_train.py index 24ecbfb7d..a4fc2eec8 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -201,21 +201,6 @@ def train(args): # モデルを読み込む # t5xxl_dtype = weight_dtype - # if args.t5xxl_dtype is not None: - # if args.t5xxl_dtype == "fp16": - # t5xxl_dtype = torch.float16 - # elif args.t5xxl_dtype == "bf16": - # t5xxl_dtype = torch.bfloat16 - # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - # t5xxl_dtype = torch.float32 - # else: - # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - # clip_dtype = weight_dtype # if not args.train_text_encoder else None - - # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) if args.clip_l is None: sd3_state_dict = utils.load_safetensors( @@ -384,7 +369,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - mmdit.enable_block_swap(args.blocks_to_swap) + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # move to accelerator device @@ -611,108 +596,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) - torch.cuda.synchronize() - # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device - ) - - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: - return - future = futures.pop(block_idx) - future.result() - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - handled_block_indices = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - - if blocks_to_swap: - is_block = param_name.startswith("joint_blocks") - if is_block: - block_idx = int(param_name.split(".")[1]) - if block_idx not in handled_block_indices: - # swap following (already backpropagated) block - handled_block_indices.add(block_idx) - - # if n blocks were already backpropagated - num_blocks_propagated = num_blocks - block_idx - 1 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - if wtng: - wait_blocks_move(bidx_to_wait, futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -731,59 +629,22 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - # swap blocks if necessary - if blocks_to_swap and btype == "joint": - num_blocks_propagated = num_blocks - bidx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = bidx > 0 and bidx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = bidx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -1130,6 +991,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( @@ -1190,16 +1052,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--num_last_block_to_freeze", type=int, diff --git a/sd3_train_network.py b/sd3_train_network.py index bb02c7ac7..1726e325f 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -51,6 +51,10 @@ def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings @@ -83,6 +87,17 @@ def load_target_model(self, args, weight_dtype, accelerator): raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") elif mmdit.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) clip_l = sd3_utils.load_clip_l( args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict @@ -432,9 +447,24 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) return parser diff --git a/tools/cache_latents.py b/tools/cache_latents.py index e2faa58a7..c034f949a 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -164,6 +164,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7be9ad781..5888b8e3d 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -191,6 +191,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/train_network.py b/train_network.py index d70f14ad3..bbf381f99 100644 --- a/train_network.py +++ b/train_network.py @@ -601,8 +601,10 @@ def train(self, args): # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") - unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) From 2bb0f547d72cd0256cafebd46d0f61fbe54012ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:33:12 +0900 Subject: [PATCH 7/9] update grad hook creation to fix TE lr in sd3 fine tuning --- flux_train.py | 19 ++++++++++++------- library/train_util.py | 1 + sd3_train.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/flux_train.py b/flux_train.py index ad2c7722b..a89e2f139 100644 --- a/flux_train.py +++ b/flux_train.py @@ -80,7 +80,9 @@ def train(args): assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -480,13 +482,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/train_util.py b/library/train_util.py index e1dfeecdb..25cf7640d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index a4fc2eec8..96ec951b9 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -606,13 +606,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook - parameter.register_post_accumulate_grad_hook(grad_hook) + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers From 5c5b544b91ac434c12a372cbf1dc123a367ec878 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:35:43 +0900 Subject: [PATCH 8/9] refactor: remove unused prepare_split_model method from FluxNetworkTrainer --- flux_train_network.py | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 9bcd59282..704c4d32e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -127,45 +127,6 @@ def load_target_model(self, args, weight_dtype, accelerator): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model - """ - def prepare_split_model(self, model, weight_dtype, accelerator): - from accelerate import init_empty_weights - - logger.info("prepare split model") - with init_empty_weights(): - flux_upper = flux_models.FluxUpper(model.params) - flux_lower = flux_models.FluxLower(model.params) - sd = model.state_dict() - - # lower (trainable) - logger.info("load state dict for lower") - flux_lower.load_state_dict(sd, strict=False, assign=True) - flux_lower.to(dtype=weight_dtype) - - # upper (frozen) - logger.info("load state dict for upper") - flux_upper.load_state_dict(sd, strict=False, assign=True) - - logger.info("prepare upper model") - target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype - flux_upper.to(accelerator.device, dtype=target_dtype) - flux_upper.eval() - - if args.fp8_base: - # this is required to run on fp8 - flux_upper = accelerator.prepare(flux_upper) - - flux_upper.to("cpu") - - self.flux_upper = flux_upper - del model # we don't need model anymore - clean_memory_on_device(accelerator.device) - - logger.info("split model prepared") - - return flux_lower - """ - def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From fd2d879ac883b8bdf1e03b6ca545c33200dbdff2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:43:08 +0900 Subject: [PATCH 9/9] docs: update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1e63b5830..81a3199bc 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 12, 2024: +Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. - During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved.