From 2e3c4d01f7efe2e641e28da848950a804a8ba887 Mon Sep 17 00:00:00 2001
From: zengxian
Date: Mon, 25 Nov 2024 00:51:55 -0500
Subject: [PATCH] enable autoquant for cpu

---
 torchbenchmark/util/backends/torchdynamo.py | 49 +++++++++++++--------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 6a7559c9c..94db3c416 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
     )
     parser.add_argument(
         "--quantization",
-        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        choices=["int8dynamic", "int8weightonly", "int4weightonly", "auto_quant"],
         help="Apply quantization to the model before running it",
     )
     parser.add_argument(
@@ -183,25 +183,36 @@ def apply_torchdynamo_args(
     if args.quantization:
         import torchao
 
-        from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
-        )
+        if model.device == "cuda":
+            from torchao.quantization import (
+                change_linear_weights_to_int4_woqtensors,
+                change_linear_weights_to_int8_dqtensors,
+                change_linear_weights_to_int8_woqtensors,
+            )
 
-        torch._dynamo.config.automatic_dynamic_shapes = False
-        torch._dynamo.config.force_parameter_static_shapes = False
-        torch._dynamo.config.cache_size_limit = 1000
-        assert "cuda" in model.device
-        module, example_inputs = model.get_module()
-        if args.quantization == "int8dynamic":
-            torch._inductor.config.force_fuse_int_mm_with_mul = True
-            change_linear_weights_to_int8_dqtensors(module)
-        elif args.quantization == "int8weightonly":
-            torch._inductor.config.use_mixed_mm = True
-            change_linear_weights_to_int8_woqtensors(module)
-        elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
+            torch._dynamo.config.automatic_dynamic_shapes = False
+            torch._dynamo.config.force_parameter_static_shapes = False
+            torch._dynamo.config.cache_size_limit = 1000
+            assert "cuda" in model.device
+            module, example_inputs = model.get_module()
+            if args.quantization == "int8dynamic":
+                torch._inductor.config.force_fuse_int_mm_with_mul = True
+                change_linear_weights_to_int8_dqtensors(module)
+            elif args.quantization == "int8weightonly":
+                torch._inductor.config.use_mixed_mm = True
+                change_linear_weights_to_int8_woqtensors(module)
+            elif args.quantization == "int4weightonly":
+                change_linear_weights_to_int4_woqtensors(module)
+        elif model.device == "cpu" and model.test == "eval":
+            if args.quantization == "auto_quant":
+                module, example_inputs = model.get_module()
+                with torch.no_grad():
+                    module = torchao.autoquant(torch.compile(module, mode="max-autotune"))
+                    if isinstance(example_inputs, dict):
+                        module(**example_inputs)
+                    else:
+                        module(*example_inputs)
+                model.set_module(module)
 
     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
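
Note for reviewers: torchao.autoquant wraps the compiled module and picks
per-layer quantization during the first forward pass(es), which is why the
new CPU branch runs example_inputs through the module once under
torch.no_grad() before handing it back via model.set_module(). Below is a
minimal standalone sketch of the same flow; the toy model and input shapes
are illustrative assumptions, not part of the benchmark suite.

    import torch
    import torchao

    # Toy model, assumed for illustration only.
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 8),
    ).eval()
    example_inputs = (torch.randn(4, 64),)

    with torch.no_grad():
        # Wrap the compiled module; quantization decisions are made lazily.
        model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
        # One warm-up forward pass triggers autoquant's kernel selection.
        model(*example_inputs)

Through the harness, the new path is reached on a CPU eval run by passing
--quantization auto_quant, matching the argparse choice added above.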