diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py
index fd85cbbb..17730ce2 100644
--- a/exllamav2/model_init.py
+++ b/exllamav2/model_init.py
@@ -22,6 +22,8 @@ def add_args(parser):
     parser.add_argument("-lq4", "--load_q4", action = "store_true", help = "Load weights in Q4 mode")
     parser.add_argument("-fst", "--fast_safetensors", action = "store_true", help = "Use alternative safetensors loader (with direct I/O when available)")
     parser.add_argument("-ic", "--ignore_compatibility", action = "store_true", help = "Do not override model config options in case of compatibility issues")
+    parser.add_argument("-chunk", "--chunk_size", type = int, help = "Chunk size ('input length')")
+


 def print_options(args):
@@ -41,6 +43,7 @@ def print_options(args):
     if args.experts_per_token is not None: print_opts += [f"experts_per_token: {args.experts_per_token}"]
     if args.load_q4: print_opts += ["load_q4"]
     if args.ignore_compatibility: print_opts += ["ignore_compatibility"]
+    if args.chunk_size is not None: print_opts += [f"chunk_size: {args.chunk_size}"]

     print(f" -- Options: {print_opts}")

@@ -107,6 +110,10 @@ def init(args,
     if args.low_mem: config.set_low_mem()
     if args.load_q4: config.load_in_q4 = True

+    if args.chunk_size is not None:
+        config.max_input_len = args.chunk_size
+        config.max_attention_size = args.chunk_size ** 2
+
     # Compatibility warnings
     config.arch_compat_overrides(warn_only = args.ignore_compatibility)
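
For reference, a minimal sketch of how the new flag flows through the three entry points touched above (add_args, print_options, init). The model path is a placeholder, and init's return values are assumed from the surrounding module rather than shown in this diff:

import argparse
from exllamav2 import model_init

parser = argparse.ArgumentParser()
model_init.add_args(parser)  # now also registers -chunk / --chunk_size

# -chunk 2048 caps each processing chunk at 2048 input tokens; init() then
# sets config.max_input_len = 2048 and config.max_attention_size = 2048 ** 2
args = parser.parse_args(["-m", "/path/to/model", "-chunk", "2048"])

model_init.print_options(args)  # options list now includes "chunk_size: 2048"
model, tokenizer = model_init.init(args)  # return values assumed

Squaring the chunk size for max_attention_size mirrors the existing ExLlamaV2Config defaults (max_input_len = 2048, max_attention_size = 2048 ** 2), so the two settings stay consistent for a full-length chunk.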