From dd9763be31805f24255ca722f30bc5f6d99c73f5 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 27 Mar 2024 21:53:40 +0300 Subject: [PATCH 1/4] Rating support for WD Tagger --- finetune/tag_images_by_wd14_tagger.py | 33 ++++++++++++++++++--------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index c19ae9749..1d49afc7f 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -174,8 +174,9 @@ def main(args): rows = l[1:] assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - general_tags = [row[1] for row in rows[1:] if row[2] == "0"] - character_tags = [row[1] for row in rows[1:] if row[2] == "4"] + rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] + general_tags = [row[1] for row in rows[0:] if row[2] == "0"] + character_tags = [row[1] for row in rows[0:] if row[2] == "4"] # 画像を読み込む @@ -202,17 +203,13 @@ def run_batch(path_imgs): probs = probs.numpy() for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + combined_tags = [] + rating_tag_text = "" + character_tag_text = "" + general_tag_text = "" # それ以降はタグなのでconfidenceがthresholdより高いものを追加する # Everything else is tags: pick any where prediction confidence > threshold - combined_tags = [] - general_tag_text = "" - character_tag_text = "" for i, p in enumerate(prob[4:]): if i < len(general_tags) and p >= args.general_threshold: tag_name = general_tags[i] @@ -231,7 +228,20 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) + combined_tags.insert(0,tag_name) # insert to the beggining + + #最初の4つはratingなので無視する + # First 4 labels are actually ratings: pick one with argmax + ratings_names = prob[:4] + rating_index = ratings_names.argmax() + found_rating = rating_tags[rating_index] + if args.remove_underscore and len(found_rating) > 3: + found_rating = found_rating.replace("_", " ") + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + combined_tags.insert(0,found_rating) # insert to the beggining # 先頭のカンマを取る if len(general_tag_text) > 0: @@ -264,6 +274,7 @@ def run_batch(path_imgs): if args.debug: logger.info("") logger.info(f"{image_path}:") + logger.info(f"\tRating tags: {rating_tag_text}") logger.info(f"\tCharacter tags: {character_tag_text}") logger.info(f"\tGeneral tags: {general_tag_text}") From 954731d56402a463a71c0626cb22699bc4e43c3b Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 27 Mar 2024 22:00:59 +0300 Subject: [PATCH 2/4] fix typo --- finetune/tag_images_by_wd14_tagger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 1d49afc7f..4003210ed 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -228,7 +228,7 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += caption_separator + tag_name - combined_tags.insert(0,tag_name) # insert to the beggining + combined_tags.insert(0,tag_name) # insert to the beginning #最初の4つはratingなので無視する # First 4 labels are actually ratings: pick one with argmax @@ -241,7 +241,7 @@ def run_batch(path_imgs): if found_rating not in undesired_tags: tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 rating_tag_text = found_rating - combined_tags.insert(0,found_rating) # insert to the beggining + combined_tags.insert(0,found_rating) # insert to the beginning # 先頭のカンマを取る if len(general_tag_text) > 0: From 4012fd24f684d4d371a8736d7e65bee307077e33 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 28 Mar 2024 21:08:16 +0300 Subject: [PATCH 3/4] IPEX fix pin_memory --- library/ipex/__init__.py | 7 ++++--- library/ipex/hijacks.py | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 972a3bf63..e5aba693c 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.nn.Module.cuda = torch.nn.Module.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized @@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 2024 + ipex._C._DeviceProperties.minor = 0 # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 65089f39e..d3cef8276 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs): else: return original_Tensor_cuda(self, device, *args, **kwargs) +original_Tensor_pin_memory = torch.Tensor.pin_memory +@wraps(torch.Tensor.pin_memory) +def Tensor_pin_memory(self, device=None, *args, **kwargs): + if device is None: + device = "xpu" + if check_device(device): + return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_pin_memory(self, device, *args, **kwargs) + original_UntypedStorage_init = torch.UntypedStorage.__init__ @wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): @@ -259,10 +269,12 @@ def torch_Generator(device=None): original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): + if map_location is None: + map_location = "xpu" if check_device(map_location): - return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs) + return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs) else: - return original_torch_load(f, map_location=map_location, *args, **kwargs) + return original_torch_load(f, *args, map_location=map_location, **kwargs) # Hijack Functions: @@ -270,6 +282,7 @@ def ipex_hijacks(): torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda + torch.Tensor.pin_memory = Tensor_pin_memory torch.UntypedStorage.__init__ = UntypedStorage_init torch.UntypedStorage.cuda = UntypedStorage_cuda torch.empty = torch_empty From bc586ce190e1e85adcd7d9734636fac068bc929e Mon Sep 17 00:00:00 2001 From: Disty0 Date: Fri, 29 Mar 2024 13:56:42 +0300 Subject: [PATCH 4/4] Add --use_rating_tags and --character_tags_first for WD Tagger --- finetune/tag_images_by_wd14_tagger.py | 58 ++++++++++++++++++--------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 4003210ed..16a26179d 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -130,10 +130,10 @@ def main(args): input_name = model.graph.input[0].name try: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value - except: + except Exception: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: + if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0: # some rebatch model may use 'N' as dynamic axes logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" @@ -169,9 +169,9 @@ def main(args): with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] + line = [row for row in reader] + header = line[0] # tag_id,name,category,count + rows = line[1:] assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] @@ -228,20 +228,24 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += caption_separator + tag_name - combined_tags.insert(0,tag_name) # insert to the beginning + if args.character_tags_first: # insert to the beginning + combined_tags.insert(0,tag_name) + else: + combined_tags.append(tag_name) #最初の4つはratingなので無視する # First 4 labels are actually ratings: pick one with argmax - ratings_names = prob[:4] - rating_index = ratings_names.argmax() - found_rating = rating_tags[rating_index] - if args.remove_underscore and len(found_rating) > 3: - found_rating = found_rating.replace("_", " ") - - if found_rating not in undesired_tags: - tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 - rating_tag_text = found_rating - combined_tags.insert(0,found_rating) # insert to the beginning + if args.use_rating_tags: + ratings_names = prob[:4] + rating_index = ratings_names.argmax() + found_rating = rating_tags[rating_index] + if args.remove_underscore and len(found_rating) > 3: + found_rating = found_rating.replace("_", " ") + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + combined_tags.insert(0,found_rating) # insert to the beginning # 先頭のカンマを取る if len(general_tag_text) > 0: @@ -332,7 +336,9 @@ def run_batch(path_imgs): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ" + ) parser.add_argument( "--repo_id", type=str, @@ -350,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ" + ) parser.add_argument( "--max_data_loader_n_workers", type=int, @@ -389,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", ) - parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument( + "--debug", action="store_true", help="debug mode" + ) parser.add_argument( "--undesired_tags", type=str, @@ -399,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" ) - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") + parser.add_argument( + "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" + ) parser.add_argument( "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" ) + parser.add_argument( + "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag", + ) + parser.add_argument( + "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags", + ) parser.add_argument( "--caption_separator", type=str,