Init quantization backend config for inductor (pytorch#96476)
**Summary**
Initialize the backend config file with quantization recipes for the quantization 2.0 Inductor path. In this PR, we only add the recipes for `convolution` and `convolution_relu`.

**Test Plan**
```
clear && python -m pytest test_quantization.py -k test_inductor_backend_config_conv
```
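
For reviewers unfamiliar with the 2.0 export path, here is a minimal usage sketch condensed from the test added in this PR (`test_inductor_backend_config_conv`). The toy module and the `torch._dynamo` import alias are illustrative assumptions; the export/prepare/convert calls mirror the test:

```python
import copy

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
    get_x86_inductor_pt2e_backend_config,
)
from torch.testing._internal.common_quantized import override_quantized_engine

# Toy float conv + relu model in eval mode (stand-in for the test's module M).
m = torch.nn.Sequential(torch.nn.Conv2d(3, 6, 2, padding=1), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(2, 3, 4, 4),)

with override_quantized_engine("x86"), torch.no_grad():
    # Program capture to an ATen graph; symbolic shapes are still TODO
    # (pytorch#96274), so tracing_mode="real" is used.
    export_module, guards = torchdynamo.export(
        m, *copy.deepcopy(example_inputs), aten_graph=True, tracing_mode="real"
    )

    # Prepare/convert against the x86 Inductor backend config added in this PR.
    qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("x86"))
    backend_config = get_x86_inductor_pt2e_backend_config()
    prepared = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config)
    prepared(*example_inputs)   # calibration pass
    converted = convert_pt2e(prepared)
    converted(*example_inputs)  # reference graph with q/dq nodes around the conv
```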

Pull Request resolved: pytorch#96476
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jerryzh168
leslie-fang-intel authored and pytorchmergebot committed Mar 22, 2023
1 parent 517a432 commit a6d8c70
Showing 5 changed files with 136 additions and 2 deletions.
72 changes: 72 additions & 0 deletions test/quantization/fx/test_quantize_pt2e.py
@@ -7,6 +7,7 @@
QuantizationTestCase,
skip_if_no_torchvision,
skipIfNoQNNPACK,
skipIfNoX86,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import (
@@ -21,6 +22,7 @@
get_qnnpack_backend_config,
)
from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
from torch.ao.quantization.backend_config._x86_inductor_pt2e import get_x86_inductor_pt2e_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.ns.fx.utils import (
@@ -190,6 +192,76 @@ def forward(self, x):
code_after_recompile = m.code
self.assertTrue(code_before_recompile == code_after_recompile, error_msg)

@skipIfNoQNNPACK
class TestQuantizePT2EX86Inductor(QuantizationTestCase):
@skipIfNoX86
@xfailIfPython311
def test_inductor_backend_config_conv(self):
class M(torch.nn.Module):
def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
super().__init__()
self.use_relu = use_relu
self.conv1 = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
self.relu = nn.ReLU(inplace=inplace_relu)

def forward(self, x):
x = self.conv1(x)
return self.relu(x) if self.use_relu else x

use_relu_list = [True, False]
inplace_relu_list = [True, False]
with override_quantized_engine("x86"):
with torch.no_grad():
for use_relu, inplace_relu in itertools.product(use_relu_list, inplace_relu_list):
m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
example_inputs = (torch.randn(2, 3, 4, 4),)
# program capture
# **TODO** Add testcase for tracing_mode="symbolic" after fix issue:
# https://github.com/pytorch/pytorch/issues/96274
export_module, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)

qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config = get_x86_inductor_pt2e_backend_config()
prepare_module = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config)
prepare_module(*example_inputs)
convert_module = convert_pt2e(prepare_module)
convert_module(*example_inputs)

# Fake quant should only be inserted at start and end
node_occurrence = {
# one for input and weight of the conv, one for output for the conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
}
if use_relu:
node_list = [
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
]
else:
node_list = [
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
]
self.checkGraphModuleNodes(convert_module,
expected_node_occurrence=node_occurrence,
expected_node_list=node_list)

class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
@skipIfNoQNNPACK
1 change: 1 addition & 0 deletions test/test_quantization.py
@@ -77,6 +77,7 @@
# Quantization for PyTorch 2.0 Export path
from quantization.fx.test_quantize_pt2e import TestQuantizePT2E # noqa: F401
from quantization.fx.test_quantize_pt2e import TestQuantizePT2EModels # noqa: F401
from quantization.fx.test_quantize_pt2e import TestQuantizePT2EX86Inductor # noqa: F401
except ImportError:
# In FBCode we separate FX out into a separate target for the sake of dev
# velocity. These are covered by a separate test target `quantization_fx`
45 changes: 45 additions & 0 deletions torch/ao/quantization/backend_config/_x86_inductor_pt2e.py
@@ -0,0 +1,45 @@
import torch
from torch.ao.quantization.backend_config import (
BackendConfig,
DTypeConfig,
ObservationType,
BackendPatternConfig,
)

weighted_op_quint8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
bias_dtype=torch.float,
)

def get_conv_configs():
conv_configs = []
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
dtype_configs = [weighted_op_quint8_dtype_config]
conv_configs.append(
BackendPatternConfig(torch.ops.aten.convolution.default)
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
conv_configs.append(
BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu.default))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
# TODO: remove when functionalization is supported in PT2 mode
conv_configs.append(
BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu_.default))
.set_observation_type(observation_type) # noqa: E131
.set_dtype_configs(dtype_configs)
._set_input_type_to_index({"weight": 1, "bias": 2})
)
return conv_configs

def get_x86_inductor_pt2e_backend_config():
return (
BackendConfig("inductor_pytorch_2.0_export")
.set_backend_pattern_configs(get_conv_configs())
)
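
As a quick sanity check, the returned config can be inspected to confirm the three patterns registered here (plain conv, conv + relu, and conv + inplace relu_). A small sketch, assuming the standard `BackendConfig.configs` property:

```python
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
    get_x86_inductor_pt2e_backend_config,
)

backend_config = get_x86_inductor_pt2e_backend_config()
for cfg in backend_config.configs:
    # Expected patterns: aten.convolution.default, (convolution, relu), (convolution, relu_)
    print(cfg.pattern, cfg.observation_type, cfg.dtype_configs)
```
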
4 changes: 2 additions & 2 deletions torch/ao/quantization/quantize_fx.py
@@ -10,8 +10,8 @@
Scope,
ScopeContextManager
)
-from .fx import fuse # noqa: F401
-from .fx import prepare # noqa: F401
+from .fx.fuse import fuse # noqa: F401
+from .fx.prepare import prepare # noqa: F401
from .fx.convert import convert
from .backend_config import ( # noqa: F401
BackendConfig,
16 changes: 16 additions & 0 deletions torch/testing/_internal/common_quantization.py
@@ -354,6 +354,22 @@ def wrapper(*args, **kwargs):
fn(*args, **kwargs)
return wrapper

def skipIfNoX86(fn):
reason = 'Quantized operations require X86.'
if isinstance(fn, type):
if 'x86' not in torch.backends.quantized.supported_engines:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = reason
return fn

@functools.wraps(fn)
def wrapper(*args, **kwargs):
if 'x86' not in torch.backends.quantized.supported_engines:
raise unittest.SkipTest(reason)
else:
fn(*args, **kwargs)
return wrapper

try:
import torchvision # noqa: F401
HAS_TORCHVISION = True
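
The new `skipIfNoX86` helper mirrors the existing `skipIfNoQNNPACK` decorator and works on both individual test functions and whole `TestCase` classes (the `isinstance(fn, type)` branch). A hypothetical usage sketch; the class and test names below are illustrative only:

```python
import torch
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skipIfNoX86,
)

@skipIfNoX86  # class-level: the whole TestCase is skipped when 'x86' is unavailable
class X86OnlyQuantTests(QuantizationTestCase):
    def test_engine_available(self):
        self.assertIn("x86", torch.backends.quantized.supported_engines)

class MixedQuantTests(QuantizationTestCase):
    @skipIfNoX86  # function-level: raises unittest.SkipTest at call time instead
    def test_x86_specific_path(self):
        ...
```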