diff --git a/flux_train_control_net.py b/flux_train_control_net.py index c9d38afbb..85082c257 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -271,7 +271,7 @@ def unwrap_model(model): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet_model_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0b68abb7d..e1ec88dc3 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -570,7 +570,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409a..ffbf03cab 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def unwrap_model(model): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_control_net.py b/train_control_net.py index 91d0f3800..4bd895836 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -223,8 +223,8 @@ def __contains__(self, name): controlnet = ControlNetModel.from_unet(unet) - if args.controlnet_model_path: - filename = args.controlnet_model_path + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path if os.path.isfile(filename): if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) @@ -691,7 +691,7 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス",