Skip to content

Commit

Permalink
Merge branch 'kohya-ss:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds authored Mar 30, 2024
2 parents 0ab18db + 6ba8428 commit 9d3792f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 25 deletions.
69 changes: 49 additions & 20 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def main(args):
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param

if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
Expand Down Expand Up @@ -169,13 +169,14 @@ def main(args):

with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]

# 画像を読み込む

Expand All @@ -202,17 +203,13 @@ def run_batch(path_imgs):
probs = probs.numpy()

for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""

# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
character_tag_text = ""
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
Expand All @@ -231,7 +228,24 @@ def run_batch(path_imgs):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0,tag_name)
else:
combined_tags.append(tag_name)

#最初の4つはratingなので無視する
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags:
ratings_names = prob[:4]
rating_index = ratings_names.argmax()
found_rating = rating_tags[rating_index]
if args.remove_underscore and len(found_rating) > 3:
found_rating = found_rating.replace("_", " ")

if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
combined_tags.insert(0,found_rating) # insert to the beginning

# 先頭のカンマを取る
if len(general_tag_text) > 0:
Expand Down Expand Up @@ -264,6 +278,7 @@ def run_batch(path_imgs):
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")

Expand Down Expand Up @@ -321,7 +336,9 @@ def run_batch(path_imgs):

def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
Expand All @@ -339,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
Expand Down Expand Up @@ -378,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
Expand All @@ -388,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags",
)
parser.add_argument(
"--caption_separator",
type=str,
Expand Down
7 changes: 4 additions & 3 deletions library/ipex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
Expand Down Expand Up @@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements

# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0

# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
Expand Down
17 changes: 15 additions & 2 deletions library/ipex/hijacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
else:
return original_Tensor_cuda(self, device, *args, **kwargs)

original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)

original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
Expand Down Expand Up @@ -259,17 +269,20 @@ def torch_Generator(device=None):
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, map_location=map_location, *args, **kwargs)
return original_torch_load(f, *args, map_location=map_location, **kwargs)


# Hijack Functions:
def ipex_hijacks():
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
Expand Down

0 comments on commit 9d3792f

Please sign in to comment.