diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index b56d921a3..e63ec3eb4 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -86,23 +86,26 @@ def main(args): logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: + files = ["selected_tags.csv"] files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(args.model_dir, SUB_DIR), + force_download=True, + force_filename=file, + ) for file in files: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) else: logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: + import torch import onnx import onnxruntime as ort