From a6d8c70933785067b40f2be54a876357294a9871 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 22 Mar 2023 12:38:47 +0800
Subject: [PATCH] Init quantization backend config for inductor (#96476)

**Summary**
Initialize the backend config with quantization recipes for the quantization
2.0 inductor path. In this PR, we only add the recipes for `convolution` and
`convolution_relu`.

**Test Plan**
```
clear && python -m pytest test_quantization.py -k test_inductor_backend_config_conv
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96476
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jerryzh168
---
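A condensed sketch of the end-to-end flow the new recipes enable, mirroring the
test added below. The `override_quantized_engine("x86")` and `torch.no_grad()`
wrappers are omitted for brevity, and `torchdynamo` is assumed to be
`torch._dynamo`, matching the test file's existing imports:

```python
import copy

import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
    get_x86_inductor_pt2e_backend_config,
)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        return self.conv(x)

example_inputs = (torch.randn(2, 3, 4, 4),)
# Capture an ATen-level graph (tracing_mode="real"; a "symbolic" variant is
# blocked on https://github.com/pytorch/pytorch/issues/96274, see the test).
export_module, _guards = torchdynamo.export(
    M().eval(),
    *copy.deepcopy(example_inputs),
    aten_graph=True,
    tracing_mode="real",
)

# Prepare, calibrate, and convert against the new x86 inductor recipes.
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("x86"))
prepared = prepare_pt2e(
    export_module, qconfig_mapping, example_inputs,
    get_x86_inductor_pt2e_backend_config(),
)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
```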
 test/quantization/fx/test_quantize_pt2e.py    | 72 +++++++++++++++++++
 test/test_quantization.py                     |  1 +
 .../backend_config/_x86_inductor_pt2e.py      | 45 ++++++++++++
 torch/ao/quantization/quantize_fx.py          |  4 +-
 .../testing/_internal/common_quantization.py  | 16 +++++
 5 files changed, 136 insertions(+), 2 deletions(-)
 create mode 100644 torch/ao/quantization/backend_config/_x86_inductor_pt2e.py

diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py
index 36c5ffd57b7e8..689554e7c8214 100644
--- a/test/quantization/fx/test_quantize_pt2e.py
+++ b/test/quantization/fx/test_quantize_pt2e.py
@@ -7,6 +7,7 @@
     QuantizationTestCase,
     skip_if_no_torchvision,
     skipIfNoQNNPACK,
+    skipIfNoX86,
 )
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.testing._internal.common_quantized import (
@@ -21,6 +22,7 @@
     get_qnnpack_backend_config,
 )
 from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
+from torch.ao.quantization.backend_config._x86_inductor_pt2e import get_x86_inductor_pt2e_backend_config
 from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
 from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.ns.fx.utils import (
@@ -190,6 +192,76 @@ def forward(self, x):
         code_after_recompile = m.code
         self.assertTrue(code_before_recompile == code_after_recompile, error_msg)
 
+@skipIfNoQNNPACK
+class TestQuantizePT2EX86Inductor(QuantizationTestCase):
+    @skipIfNoX86
+    @xfailIfPython311
+    def test_inductor_backend_config_conv(self):
+        class M(torch.nn.Module):
+            def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
+                super().__init__()
+                self.use_relu = use_relu
+                self.conv1 = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
+                self.relu = nn.ReLU(inplace=inplace_relu)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                return self.relu(x) if self.use_relu else x
+
+        use_relu_list = [True, False]
+        inplace_relu_list = [True, False]
+        with override_quantized_engine("x86"):
+            with torch.no_grad():
+                for use_relu, inplace_relu in itertools.product(use_relu_list, inplace_relu_list):
+                    m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
+                    example_inputs = (torch.randn(2, 3, 4, 4),)
+                    # program capture
+                    # **TODO** Add testcase for tracing_mode="symbolic" after fix issue:
+                    # https://github.com/pytorch/pytorch/issues/96274
+                    export_module, guards = torchdynamo.export(
+                        m,
+                        *copy.deepcopy(example_inputs),
+                        aten_graph=True,
+                        tracing_mode="real",
+                    )
+
+                    qconfig = get_default_qconfig("x86")
+                    qconfig_mapping = QConfigMapping().set_global(qconfig)
+                    backend_config = get_x86_inductor_pt2e_backend_config()
+                    prepare_module = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config)
+                    prepare_module(*example_inputs)
+                    convert_module = convert_pt2e(prepare_module)
+                    convert_module(*example_inputs)
+
+                    # Fake quant should only be inserted at start and end
+                    node_occurrence = {
+                        # per-tensor: one for the conv input, one for the conv output
+                        ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
+                        ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
+                        ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
+                        ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
+                    }
+                    if use_relu:
+                        node_list = [
+                            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
+                            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                            ns.call_function(torch.ops.aten.convolution.default),
+                            ns.call_function(torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default),
+                            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
+                            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                        ]
+                    else:
+                        node_list = [
+                            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
+                            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                            ns.call_function(torch.ops.aten.convolution.default),
+                            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
+                            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                        ]
+                    self.checkGraphModuleNodes(convert_module,
+                                               expected_node_occurrence=node_occurrence,
+                                               expected_node_list=node_list)
+
 class TestQuantizePT2EModels(QuantizationTestCase):
     @skip_if_no_torchvision
     @skipIfNoQNNPACK
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 48fe750bb3282..da076db2a0a4d 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -77,6 +77,7 @@
     # Quantization for PyTorch 2.0 Export path
     from quantization.fx.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
     from quantization.fx.test_quantize_pt2e import TestQuantizePT2EModels  # noqa: F401
+    from quantization.fx.test_quantize_pt2e import TestQuantizePT2EX86Inductor  # noqa: F401
 except ImportError:
     # In FBCode we separate FX out into a separate target for the sake of dev
     # velocity. These are covered by a separate test target `quantization_fx`
diff --git a/torch/ao/quantization/backend_config/_x86_inductor_pt2e.py b/torch/ao/quantization/backend_config/_x86_inductor_pt2e.py
new file mode 100644
index 0000000000000..cff299f366d65
--- /dev/null
+++ b/torch/ao/quantization/backend_config/_x86_inductor_pt2e.py
@@ -0,0 +1,45 @@
+import torch
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+    BackendPatternConfig,
+)
+
+weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+def get_conv_configs():
+    conv_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    conv_configs.append(
+        BackendPatternConfig(torch.ops.aten.convolution.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    conv_configs.append(
+        BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu.default))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    # TODO: remove when functionalization is supported in PT2 mode
+    conv_configs.append(
+        BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu_.default))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    return conv_configs
+
+def get_x86_inductor_pt2e_backend_config():
+    return (
+        BackendConfig("inductor_pytorch_2.0_export")
+        .set_backend_pattern_configs(get_conv_configs())
+    )
diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py
index 5a2edbeb29215..aec53e31b313f 100644
--- a/torch/ao/quantization/quantize_fx.py
+++ b/torch/ao/quantization/quantize_fx.py
@@ -10,8 +10,8 @@
     Scope,
     ScopeContextManager
 )
-from .fx import fuse  # noqa: F401
-from .fx import prepare  # noqa: F401
+from .fx.fuse import fuse  # noqa: F401
+from .fx.prepare import prepare  # noqa: F401
 from .fx.convert import convert
 from .backend_config import (  # noqa: F401
     BackendConfig,
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 86d5876801743..e92cdd0e277f2 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -354,6 +354,22 @@ def wrapper(*args, **kwargs):
         fn(*args, **kwargs)
     return wrapper
 
+def skipIfNoX86(fn):
+    reason = 'Quantized operations require X86.'
+    if isinstance(fn, type):
+        if 'x86' not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if 'x86' not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
 try:
     import torchvision  # noqa: F401
     HAS_TORCHVISION = True
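As a quick sanity check of what the new module registers, here is a minimal
inspection sketch, assuming this patch is applied and that the public
`BackendConfig.configs` accessor of the 2.0 backend-config API is available:

```python
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
    get_x86_inductor_pt2e_backend_config,
)

backend_config = get_x86_inductor_pt2e_backend_config()
print(backend_config.name)  # inductor_pytorch_2.0_export

# Expect three entries: conv, conv + relu, and conv + relu_ (the in-place
# variant is kept only until functionalization is supported in PT2 mode).
for pattern_config in backend_config.configs:
    print(
        pattern_config.pattern,
        [(d.input_dtype, d.weight_dtype) for d in pattern_config.dtype_configs],
    )
```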