diff --git a/train_network.py b/train_network.py index dfa51a9c8..b24f89b1e 100644 --- a/train_network.py +++ b/train_network.py @@ -472,7 +472,7 @@ def train(self, args): text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility - if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]