From 5d5f39b6e6bccfc4d4265f8a7ce651acec132f39 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 25 Feb 2024 00:50:36 +0900 Subject: [PATCH] Replaced print with logger --- gen_img.py | 194 +++++++++++++++++++++------------------- gen_img_diffusers.py | 22 ++--- library/device_utils.py | 11 ++- networks/dylora.py | 4 +- sdxl_gen_img.py | 32 +++---- 5 files changed, 138 insertions(+), 125 deletions(-) diff --git a/gen_img.py b/gen_img.py index daf88d2a1..4fe898716 100644 --- a/gen_img.py +++ b/gen_img.py @@ -61,6 +61,12 @@ from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -82,12 +88,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -95,7 +101,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -112,7 +118,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -168,7 +174,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -224,7 +230,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -386,10 +392,10 @@ def set_control_net_lllites(self, ctrl_net_lllites): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) @torch.no_grad() @@ -467,7 +473,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -576,7 +582,7 @@ def __call__( text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) @@ -742,8 +748,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.warning("gradual_latent is not supported for this scheduler. Ignoring.") + logger.warning(f"{self.scheduler.__class__.__name__}") else: enable_gradual_latent = True step_elapsed = 1000 @@ -792,7 +798,7 @@ def __call__( if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -1013,7 +1019,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1043,7 +1049,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1344,7 +1350,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1488,9 +1494,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -1510,7 +1516,7 @@ def main(args): else: # if `text_encoder_2` subdirectory exists, sdxl is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) - print(f"SDXL: {is_sdxl}") + logger.info(f"SDXL: {is_sdxl}") if is_sdxl: if args.clip_skip is None: @@ -1526,10 +1532,10 @@ def main(args): args.clip_skip = 2 if args.v2 else 1 if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -1553,7 +1559,7 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1562,7 +1568,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if is_sdxl: tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) tokenizers = [tokenizer1, tokenizer2] @@ -1654,7 +1660,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1715,7 +1721,7 @@ def __getattr__(self, item): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) vae.eval() @@ -1739,10 +1745,10 @@ def __getattr__(self, item): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info("import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1760,7 +1766,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1768,7 +1774,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs @@ -1778,20 +1784,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoders, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1805,7 +1811,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info("import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1814,7 +1820,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1833,7 +1839,7 @@ def __getattr__(self, item): control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1867,7 +1873,7 @@ def __getattr__(self, item): ), "ControlNet and ControlNet-LLLite cannot be used at the same time" if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") for text_encoder in text_encoders: text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1894,7 +1900,7 @@ def __getattr__(self, item): ) pipe.set_control_nets(control_nets) pipe.set_control_net_lllites(control_net_lllites) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -1965,7 +1971,7 @@ def __getattr__(self, item): token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -2002,7 +2008,7 @@ def __getattr__(self, item): # promptを取得する prompt_list = None if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] @@ -2019,7 +2025,7 @@ def load_module_from_path(module_name, file_path): spec.loader.exec_module(module) return module - print(f"reading prompts from module: {args.from_module}") + logger.info(f"reading prompts from module: {args.from_module}") prompt_module = load_module_from_path("prompt_module", args.from_module) prompter = prompt_module.get_prompter(args, pipe, networks) @@ -2050,7 +2056,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2066,14 +2072,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -2081,22 +2087,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and prompter is None and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") prompt_list = [] for img in init_images: if "prompt" in img.text: @@ -2127,17 +2133,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2162,14 +2168,14 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print( + logger.warning( f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" ) guide_images = None @@ -2200,7 +2206,7 @@ def fixed_seed(*args, **kwargs): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") if args.iter_same_seed: iter_seed = seed_random.randint(0, 2**32 - 1) else: @@ -2219,7 +2225,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -2264,7 +2270,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2437,7 +2443,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -2520,7 +2526,7 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print( + logger.warning( "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" ) @@ -2535,7 +2541,7 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("\nType prompt:") try: raw_prompt = input() except EOFError: @@ -2595,74 +2601,74 @@ def scale_and_round(x): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] length = len(prompter) if hasattr(prompter, "__len__") else 0 - print(f"prompt {prompt_index+1}/{length}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{length}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2671,25 +2677,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2697,89 +2703,89 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") continue # Gradual Latent m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2825,7 +2831,7 @@ def scale_and_round(x): if seed is None: seed = seed_random.randint(0, 2**32 - 1) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2841,7 +2847,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2903,12 +2909,14 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c5f84a93..2c40f1a06 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -489,10 +489,10 @@ def set_control_nets(self, ctrl_nets): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) # region xformersとか使う部分:独自に書き換えるので関係なし @@ -971,8 +971,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') else: enable_gradual_latent = True step_elapsed = 1000 @@ -3314,42 +3314,42 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: @@ -3369,7 +3369,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if gl_unsharp_params is not None: unsharp_params = gl_unsharp_params.split(",") us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - print(unsharp_params) + logger.info(f'{unsharp_params}') us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) us_ksize = int(us_ksize) else: diff --git a/library/device_utils.py b/library/device_utils.py index 8823c5d9a..f6641ab5c 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -3,6 +3,11 @@ import torch +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + try: HAS_CUDA = torch.cuda.is_available() except Exception: @@ -59,7 +64,7 @@ def get_preferred_device() -> torch.device: device = torch.device("mps") else: device = torch.device("cpu") - print(f"get_preferred_device() -> {device}") + logger.info(f"get_preferred_device() -> {device}") return device @@ -77,8 +82,8 @@ def init_ipex(): is_initialized, error_message = ipex_init() if not is_initialized: - print("failed to initialize ipex:", error_message) + logger.error("failed to initialize ipex: {error_message}") else: return except Exception as e: - print("failed to initialize ipex:", e) + logger.error("failed to initialize ipex: {e}") diff --git a/networks/dylora.py b/networks/dylora.py index d71279c55..637f33450 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -327,10 +327,10 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}") + logger.info(f"create LoRA for Text Encoder {index}") else: index = None - print(f"create LoRA for Text Encoder") + logger.info("create LoRA for Text Encoder") text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 641b3209f..d52f85a8f 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -380,10 +380,10 @@ def set_control_nets(self, ctrl_nets): def set_gradual_latent(self, gradual_latent): if gradual_latent is None: - print("gradual_latent is disabled") + logger.info("gradual_latent is disabled") self.gradual_latent = None else: - print(f"gradual_latent is enabled: {gradual_latent}") + logger.info(f"gradual_latent is enabled: {gradual_latent}") self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) @torch.no_grad() @@ -789,8 +789,8 @@ def __call__( enable_gradual_latent = False if self.gradual_latent: if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') else: enable_gradual_latent = True step_elapsed = 1000 @@ -2614,84 +2614,84 @@ def scale_and_round(x): m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue # Gradual Latent m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent timesteps gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") + logger.info(f"gradual latent timesteps: {gl_timesteps}") continue m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio gl_ratio = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") + logger.info(f"gradual latent ratio: {ds_ratio}") continue m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent every n steps gl_every_n_steps = int(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") continue m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent ratio step gl_ratio_step = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") + logger.info(f"gradual latent ratio step: {gl_ratio_step}") continue m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) if m: # gradual latent s noise gl_s_noise = float(m.group(1)) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") + logger.info(f"gradual latent s noise: {gl_s_noise}") continue m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) if m: # gradual latent unsharp params gl_unsharp_params = m.group(1) gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: