"""
Training code.
Inputs:
- Path to a data folder that contains a folder of close-up face photos and a folder of full-body photos
- A specified output path for the trained model
- All other parameters are to be set by the participant
Outputs:
- The fine-tuned model together with any additional sub-modules
"""
import argparse
import datetime
import hashlib
import itertools
import json
import os
import warnings
from pathlib import Path

import torch
import tqdm
import yaml
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from absl import logging
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from peft import inject_adapter_in_model, LoraConfig, get_peft_model
# import wandb

import clip
import utils
import libs.autoencoder
from libs.clip import CLIPEmbedder
from libs.caption_decoder import CaptionDecoder
from libs.data import PersonalizedBase, PromptDataset, collate_fn
from libs.schedule import stable_diffusion_beta_schedule, Schedule, LSimple_T2I
from libs.uvit_multi_post_ln_v1 import UViT
logger = get_logger(__name__)


# Save the embeddings of the tokens newly added to the text encoder
def save_new_embed(clip_text_model, modifier_token_id, accelerator, args, outdir):
    """Saves the embeddings of the newly added tokens from the text encoder."""
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(clip_text_model).get_input_embeddings().weight
    for x, y in zip(modifier_token_id, args.modifier_token):
        learned_embeds_dict = {y: learned_embeds[x]}
        torch.save(learned_embeds_dict, f"{outdir}/{y}.bin")
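
# Illustrative sketch (not part of the original script): the per-token .bin files written above
# could be loaded back into a text encoder at inference time roughly like this, assuming the same
# tokenizer and token name are used:
#   learned = torch.load(f"{outdir}/<new1>.bin")
#   token_id = tokenizer.convert_tokens_to_ids("<new1>")
#   text_encoder.get_input_embeddings().weight.data[token_id] = learned["<new1>"]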
def freeze_params(params):
for param in params:
param.requires_grad = False
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """
    Import the text-encoder class that matches a pretrained model name or path.

    Args:
        pretrained_model_name_or_path: name or path of the pretrained model.
        revision: model revision.

    Returns:
        The model class. Supported classes are CLIPTextModel and
        RobertaSeriesModelWithTransformation; any other architecture raises a ValueError.
    """
    # Read the text-encoder configuration from the pretrained model
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    # Name of the model class
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
        return RobertaSeriesModelWithTransformation
    else:
        # The model class is not in the supported list
        raise ValueError(f"{model_class} is not supported.")
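# Illustrative usage (mirrors the call in train() below): for a Stable Diffusion v1-style checkpoint
# the text_encoder config names "CLIPTextModel", so this resolves to transformers.CLIPTextModel:
#   text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision=None)
#   text_encoder = text_encoder_cls.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")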
def train(config):
"""
prepare models
准备各类需要的模型
"""
accelerator, device = utils.setup(config)
args = get_args()
concepts_list = args.concepts_list
if config.with_prior_preservation:
for i, concept in enumerate(concepts_list):
            # Handle the class-image directory
class_images_dir = Path(concept["class_data_dir"])
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True, exist_ok=True)
if config.real_prior:
assert (
class_images_dir / "images"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {config.num_class_images}"
assert (
len(list((class_images_dir / "images").iterdir())) == config.num_class_images
), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {config.num_class_images}"
assert (
class_images_dir / "caption.txt"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {config.num_class_images}"
assert (
class_images_dir / "images.txt"
).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {config.num_class_images}"
concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
concepts_list[i] = concept
accelerator.wait_for_everyone()
# pretrained_model_name_or_path = "/data/hdd3/schengwei/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce"
# pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
pretrained_model_name_or_path = "other_models/stablediffusion/b95be7d6f134c3a9e62ee616f310733567f069ce"
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
revision = None,
use_fast=False,
)
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path , config.revision)
text_encoder = text_encoder_cls.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision
)
text_encoder.to(device)
train_state = utils.initialize_train_state(config, device, uvit_class=UViT,text_encoder = text_encoder)
caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
nnet, optimizer = accelerator.prepare(train_state.nnet, train_state.optimizer)
nnet.to(device)
lr_scheduler = train_state.lr_scheduler
autoencoder = libs.autoencoder.get_model(**config.autoencoder).to(device)
    autoencoder.requires_grad_(False)  # freeze the autoencoder's parameters
# Modify the code of custom diffusion to directly import the clip text encoder
# instead of freezing all parameters.
# clip_text_model = CLIPEmbedder(version=config.clip_text_model, device=device)
clip_img_model, clip_img_model_preprocess = clip.load(config.clip_img_model, jit=False)
# clip_img_model.to(device).eval().requires_grad_(False)
clip_img_model.to(device).requires_grad_(True)
    # Adding a modifier token which is optimized  #### adapted from the Textual Inversion code
# Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
# add modifier token
modifier_token_id = []
initializer_token_id = []
if args.modifier_token is not None:
args.modifier_token = args.modifier_token.split("+")#['<new1>']
args.initializer_token = config.initializer_token.split("+")#['ktn', 'pll', 'ucd']
if len(args.modifier_token) > len(args.initializer_token):
raise ValueError("You must specify + separated initializer token for each modifier token.")
for modifier_token, initializer_token in zip(
args.modifier_token, args.initializer_token[: len(args.modifier_token)]
):
# Add the placeholder token in tokenizer
            # When adding a placeholder token, it is normally appended to the tokenizer vocabulary so
            # that text processing can handle it correctly. Placeholder tokens matter for model
            # training, text generation, sequence padding and similar tasks.
num_added_tokens = tokenizer.add_tokens(modifier_token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {modifier_token}. Please pass a different"
" `modifier_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)
#[42170]
#ktn
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id.append(token_ids[0])
modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))
print("modifier_token_id",modifier_token_id)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
        text_encoder.resize_token_embeddings(len(tokenizer))  # vocabulary grows from 40408 to 40409
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
for x, y in zip(modifier_token_id, initializer_token_id):
token_embeds[x] = token_embeds[y]
# Freeze all parameters except for the token embeddings in text encoder
params_to_freeze = itertools.chain(
text_encoder.text_model.encoder.parameters(),
text_encoder.text_model.final_layer_norm.parameters(),
text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
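        # After this, only text_encoder.text_model.embeddings.token_embedding (which now contains the
        # rows for the newly added placeholder tokens) remains trainable inside the text encoder.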
"""
处理数据部分
"""
# process data
train_dataset = PersonalizedBase(
concepts_list=concepts_list,
num_class_images=config.num_class_images,
        size=config.resolution,  # default is 512
center_crop=config.center_crop,
tokenizer_max_length=77,
tokenizer=tokenizer,
config = config,
hflip=config.hflip,
# mask_size= autoencoder.encode(torch.randn(1, 3, config.resolution, config.resolution).to(dtype=torch.float16).to(accelerator.device)
# )
# .latent_dist.sample()
# .size()[-1],
        mask_size=64  # custom_diffusion uses a mask_size of 64
)
train_dataset_loader = DataLoader(train_dataset,
batch_size=config.batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=config.dataloader_num_workers,
)
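    # Note (assumption about libs.data.collate_fn): with prior preservation enabled, each batch is
    # expected to contain both the instance examples and the corresponding class (regularization)
    # examples, which is why collate_fn receives args.with_prior_preservation.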
train_data_generator = utils.get_data_generator(train_dataset_loader, enable_tqdm=accelerator.is_main_process, desc='train')
logging.info("saving meta data")
os.makedirs(config.meta_dir, exist_ok=True)
    with open(os.path.join(config.meta_dir, "config.yaml"), "w") as f:
        f.write(yaml.dump(config))
_betas = stable_diffusion_beta_schedule()
schedule = Schedule(_betas)
logging.info(f'use {schedule}')
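    # Assumption: stable_diffusion_beta_schedule returns the standard Stable Diffusion betas
    # (1000 steps, "scaled linear" between roughly 0.00085 and 0.012), which Schedule then uses to
    # sample timesteps and noise for the training loss below.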
    # Check which nnet parameters are left unfrozen
    for name, param in nnet.named_parameters():
        if param.requires_grad:
            print(f"unfrozen parameter: {name}")
# total_frozen_params = sum(p.numel() for p in text_encoder.parameters() if p.requires_grad)
# 77560320 lora_adapter+text_embedding 37946112 token_embedding
# INFO - nnet has 1029970000 parameters
# INFO - text_encoder has 123060480 parameters
# text_encoder = accelerator.prepare(text_encoder)
def train_step():
metrics = dict()
text, img, img4clip, mask = next(train_data_generator)
img = img.to(device)
text = text.to(device)
img4clip = img4clip.to(device)
data_type = torch.float32
mask = mask.to(device)
# with torch.no_grad():
z = autoencoder.encode(img)
clip_img = clip_img_model.encode_image(img4clip).unsqueeze(1).contiguous()
text = text_encoder(text)[0]
text = caption_decoder.encode_prefix(text)
#z= false text = true
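        # Assumption: LSimple_T2I (from libs.schedule) implements the standard "simple" diffusion
        # objective (noise-prediction MSE) over the latent, CLIP-image and text branches, with
        # `mask` used to weight the loss.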
bloss = LSimple_T2I(img=z,clip_img=clip_img, text=text, data_type=data_type, nnet=nnet, schedule=schedule, device=device, config=config,mask=mask)
accelerator.backward(bloss)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
        if True:
            # Note from the author: the multi-process branch below was removed and then restored;
            # it makes no real difference, both branches do the same thing.
            # if accelerator.num_processes > 1:
            #     grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
            # else:
            #     grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
            grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
            # Get the indices of the tokens whose gradients we want to zero out
            index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
            for i in range(1, len(modifier_token_id)):
                index_grads_to_zero = index_grads_to_zero & (
                    torch.arange(len(tokenizer)) != modifier_token_id[i]
                )
            grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
                index_grads_to_zero, :
            ].fill_(0)
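        # Worked example: with a single modifier token whose id is, say, 40408, index_grads_to_zero
        # is True for every vocabulary row except 40408, so only the new placeholder embedding keeps
        # a non-zero gradient and gets updated by the optimizer step below.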
params_to_clip = (
itertools.chain(text_encoder.parameters(), nnet.parameters())
if args.modifier_token is not None
else nnet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
        # Update parameters
        optimizer.step()
        lr_scheduler.step()
        # train_state.ema_update(config.get('ema_rate', 0.9999))  # disabled: EMA updates interfere with PEFT training
        train_state.step += 1
        optimizer.zero_grad()
metrics['bloss'] = accelerator.gather(bloss.detach().mean()).mean().item()
# metrics['loss_img'] = accelerator.gather(loss_img.detach().mean()).mean().item()
# metrics['loss_clip_img'] = accelerator.gather(loss_clip_img.detach().mean()).mean().item()
# metrics['scale'] = accelerator.scaler.get_scale()
metrics['lr'] = train_state.optimizer.param_groups[0]['lr']
return metrics
# @torch.no_grad()
# @torch.autocast(device_type='cuda')
# def eval(total_step):
# """
# write evaluation code here
# """
# return
def loop():
log_step = config.log_interval
# eval_step = 1000000
save_step = config.train_step
while True:
nnet.train()
with accelerator.accumulate(nnet),accelerator.accumulate(text_encoder):
metrics = train_step()
print("metrics",metrics)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# nnet.eval()
total_step = train_state.step * config.batch_size
# if total_step >= log_step:
# i = total_step // config.log_interval
# logging.info(utils.dct2str(dict(step=total_step, **metrics)))
# # wandb.log(utils.add_prefix(metrics, 'train'), step=total_step)
# logging.info(f"saving {i}th logging ckpts to {config.root_ckpt}_{i*1000}...")
# os.makedirs(config.root_ckpt, exist_ok=True)
# if not os.path.exists(config.root_ckpt + f"_{i*1000}"):
# os.makedirs(config.root_ckpt + f"_{i*1000}", exist_ok=True)
# logging.info("Mkdir {}".format(config.root_ckpt + f"_{i*1000}"))
# save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + f"_{i*1000}")
# train_state.save_lora(os.path.join(config.root_ckpt + f"_{i*1000}", 'lora.pt.tmp'))
# log_step += config.log_interval
# if total_step == 1000:
# logging.info(f"saving final ckpts to {config.outdir}_11000...")
# save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_11000")
# # train_state.save(os.path.join(config.outdir, 'final.ckpt'))
# train_state.save_lora(os.path.join(config.outdir + "_11000", 'lora.pt.tmp'))
if total_step >= save_step:
if not os.path.exists(config.outdir):
os.makedirs(config.outdir, exist_ok=True)
logging.info("Mkdir {}".format(config.outdir))
logging.info(f"saving final ckpts to {config.outdir}...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir)
train_state.save_lora(os.path.join(config.outdir, 'lora.pt.tmp'))
break
loop()
def get_args():
parser = argparse.ArgumentParser()
# key args
parser.add_argument('-o', "--outdir", type=str, default="model_ouput/girl2", help="output of model")
parser.add_argument("--train_step", type=int, default=2000, help="total training steps")
parser.add_argument("--log_interval", type=int, default=1000, help="log interval")
# args of logging
parser.add_argument("--log_dir", type=str, default="logs", help="the dir to put logs")
parser.add_argument("--nnet_path", type=str, default="models/uvit_v1.pth", help="nnet path to resume")
parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--concepts_list",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=200,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" concepts_list, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--real_prior",
default=True,
action="store_true",
help="real images as prior.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help="Number of subprocesses to use for data loading.",
)
parser.add_argument("--modifier_token", type=str, default="<new1>", help="modifier token")
parser.add_argument(
"--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
)
args = parser.parse_args()
if args.with_prior_preservation:
if args.concepts_list is None:
args.concepts_list = [
{
"instance_prompt": args.instance_prompt, #photo of a <new1> girl
"class_prompt": args.class_prompt,#girl
"instance_data_dir": args.instance_data_dir,#./path-to-images/
"class_data_dir": args.class_data_dir,#./real_reg/samples_person/
}
]
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
# logger is not available yet
if args.concepts_list is not None:
warnings.warn("You need not use --concepts_list without --with_prior_preservation.")
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
return args
def main():
print("main start!")
    # Contestants should modify the config file according to their own needs
from configs.unidiffuserv1 import get_config
config = get_config()
config_name = "unidiffuserv1"
args = get_args()
config.log_dir = args.log_dir
config.outdir = args.outdir
config.data = args.instance_data_dir
config.modifier_token = args.modifier_token
config.initializer_token = args.initializer_token
config.prior_loss_weight = args.prior_loss_weight
config.instance_prompt = args.instance_prompt
config.class_prompt = args.class_prompt
config.dataloader_num_workers = args.dataloader_num_workers
config.gradient_accumulation_steps = args.gradient_accumulation_steps
config.with_prior_preservation = args.with_prior_preservation
config.real_prior = args.real_prior
config.num_class_images = args.num_class_images
config.hflip = args.hflip
config.train_step = args.train_step
config.log_interval = args.log_interval
data_name = Path(config.data).stem
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
config.workdir = os.path.join(config.log_dir, f"{config_name}-{data_name}-{now}")
config.ckpt_root = os.path.join(config.workdir, 'ckpts')
config.meta_dir = os.path.join(config.workdir, "meta")
config.nnet_path = args.nnet_path
os.makedirs(config.workdir, exist_ok=True)
train(config)
if __name__ == "__main__":
main()
"""
accelerate launch train.py \
--instance_data_dir="train_data/newboy1" \
--outdir="model_output/boy1" \
--class_data_dir="real_reg/samples_boyface" \
--with_prior_preservation --prior_loss_weight=1.0 \
--class_prompt="boy" --num_class_images=200 \
--instance_prompt=" a <new1> boy" \
--modifier_token "<new1>"
export LD_LIBRARY_PATH=/home/shiyiming/anaconda3/envs/competition/lib/python3.10/site-packages/torch/lib/
"""