From f5f3bb01fa69e307c283705f3554a3074d26648c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Sun, 7 Apr 2024 17:57:32 +0800
Subject: [PATCH] use no lr scheduler for schedule-free optimizers

---
 fine_tune.py                         | 27 ++++++++++++++++++++-------
 sdxl_train.py                        | 16 ++++++++++++----
 sdxl_train_control_net_lllite.py     |  5 ++++-
 sdxl_train_control_net_lllite_old.py | 11 ++++++++---
 train_controlnet.py                  | 11 ++++++++---
 train_db.py                          | 27 ++++++++++++++++++++-------
 train_network.py                     | 24 +++++++++++++++++-------
 train_textual_inversion.py           | 22 ++++++++++++++++------
 train_textual_inversion_XTI.py       | 11 ++++++++---
 9 files changed, 113 insertions(+), 41 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index cecb41b19..17d091408 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -255,18 +255,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
diff --git a/sdxl_train.py b/sdxl_train.py
index 6acd8a6ac..2590d36c1 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -415,9 +415,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time.  # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
@@ -428,7 +433,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index d788bacf5..6cbb4741b 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -286,7 +286,10 @@ def train(args):
         unet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 3e81f2c94..bb48fcb14 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,9 +254,14 @@ def train(args):
         network.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, network, optimizer, train_dataloader = accelerator.prepare(
+            unet, network, optimizer, train_dataloader
+        )
+    else:
+        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     network: control_net_lllite.ControlNetLLLite
 
     if args.gradient_checkpointing:
diff --git a/train_controlnet.py b/train_controlnet.py
index f4c94e8d9..6b71799dc 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -276,9 +276,14 @@ def train(args):
         controlnet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        controlnet, optimizer, train_dataloader = accelerator.prepare(
+            controlnet, optimizer, train_dataloader
+        )
+    else:
+        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            controlnet, optimizer, train_dataloader, lr_scheduler
+        )
 
     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
diff --git a/train_db.py b/train_db.py
index 62f9852f0..ad55d6ce0 100644
--- a/train_db.py
+++ b/train_db.py
@@ -229,19 +229,32 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
             training_models = [unet]
 
     if not train_text_encoder:
diff --git a/train_network.py b/train_network.py
index d47042805..e7db8168c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -420,9 +420,14 @@ def train(self, args):
                 text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
                 network=network,
             )
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                ds_model, optimizer, train_dataloader = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader
+                )
+            else:
+                ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = ds_model
         else:
             if train_unet:
@@ -437,10 +442,15 @@ def train(self, args):
                 text_encoders = [text_encoder]
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                network, optimizer, train_dataloader = accelerator.prepare(
+                    network, optimizer, train_dataloader
+                )
+            else:
+                network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    network, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = network
 
         if args.gradient_checkpointing:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 10fce2677..4adbc642f 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -416,14 +416,24 @@ def train(self, args):
 
         # acceleratorがなんかよろしくやってくれるらしい
         if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader
+                )
+            else:
+                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
+                )
 
         elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader
+                )
+            else:
+                text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
+                )
 
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
 
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 032a36e21..701fd1467 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -335,9 +335,14 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        text_encoder, optimizer, train_dataloader = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader
+        )
+    else:
+        text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))
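
Note, for context rather than as part of the diff above: the new branches skip passing lr_scheduler to accelerator.prepare() when args.optimizer_type ends with "schedulefree", because schedule-free optimizers manage the learning-rate trajectory internally and need no scheduler object. A minimal usage sketch, assuming the third-party `schedulefree` package (facebookresearch/schedule_free), which is not part of this patch; the model and hyperparameters below are illustrative only:

    import torch
    import schedulefree

    model = torch.nn.Linear(16, 1)
    # No lr_scheduler is created; the optimizer takes its place.
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

    optimizer.train()  # schedule-free optimizers require explicit train/eval mode switches
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    optimizer.eval()  # switch back before evaluation or before saving weights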