[Quant] [PT2] Enable QAT Quantization flow in X86InductorQuantizer (#111280)

**Summary**
This PR enables the PT2 QAT quantization flow in `X86InductorQuantizer`: `get_default_x86_inductor_quantization_config` gains an `is_qat` flag that switches the activation and weight observers to fake-quantize variants, and the quantizer now annotates the QAT `conv2d` + `batch_norm` (optionally followed by `relu`) fusion patterns.
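
Below is a minimal end-to-end sketch of the enabled flow, pieced together from the test helpers in this PR; the toy module and example inputs are placeholders for illustration, not part of the change:
```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


# Toy Conv2d + BatchNorm2d module, mirroring the modules used in the tests below.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
        self.bn = torch.nn.BatchNorm2d(6)

    def forward(self, x):
        return self.bn(self.conv(x))


mod = M().train()
example_inputs = (torch.randn(2, 3, 16, 16),)

# Capture the model, then annotate it with the QAT config added in this PR.
exported_model = capture_pre_autograd_graph(mod, example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# The QAT training loop would run on prepared_model; a single forward pass stands in for it here.
prepared_model(*example_inputs)

# Convert to the quantized model, switch to eval, and lower through Inductor.
converted_model = convert_pt2e(prepared_model, fold_quantize=True)
torch.ao.quantization.move_exported_model_to_eval(converted_model)
optimized_model = torch.compile(converted_model)
_ = optimized_model(*example_inputs)
```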

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_with_quantizer_api
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_unary_with_quantizer_api
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_relu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111280
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
Author: leslie-fang-intel
Date: 2023-11-01 17:00:38 +08:00
Committed by: PyTorch MergeBot
Parent: 8191fb3e06
Commit: 56ca0043f6
3 changed files with 311 additions and 25 deletions

test/inductor/test_mkldnn_pattern_matcher.py

@@ -11,7 +11,11 @@ from torch._dynamo.utils import counters
from torch._export import capture_pre_autograd_graph
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.nn import functional as F
from torch.testing._internal.common_quantization import (
@@ -79,18 +83,26 @@ class TestPatternMatcherBase(TestCase):
return tuple(clone(x) for x in inputs)
def _generate_reference_quantized_model(self, mod, inputs):
export_model = capture_pre_autograd_graph(
mod,
inputs,
)
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model)
torch.ao.quantization.move_exported_model_to_eval(convert_model)
return convert_model
def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
export_model = capture_pre_autograd_graph(
mod,
inputs,
)
quantizer = X86InductorQuantizer()
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=is_qat)
)
prepare_model = (
prepare_qat_pt2e(export_model, quantizer)
if is_qat
else prepare_pt2e(export_model, quantizer)
)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model, fold_quantize=True)
torch.ao.quantization.move_exported_model_to_eval(convert_model)
return convert_model
def _test_common(
self,
@@ -102,6 +114,7 @@ class TestPatternMatcherBase(TestCase):
rtol=1.3e-6,
check_autocast=False,
check_quantization=False,
is_qat=False,
):
counters.clear()
torch._dynamo.reset()
@@ -110,8 +123,8 @@ class TestPatternMatcherBase(TestCase):
maybe_autocast = torch.cpu.amp.autocast()
atol, rtol = 1e-2, 1e-2
if check_quantization:
convert_model = self._generate_qdq_quantized_model(mod, inputs, is_qat)
with torch.no_grad():
convert_model = self._generate_reference_quantized_model(mod, inputs)
_ = torch.compile(convert_model)(*inputs)
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
@@ -148,7 +161,7 @@ class TestPatternMatcherBase(TestCase):
with torch.no_grad():
clone_inputs = self._clone_inputs(inputs)
if check_quantization:
mod = self._generate_reference_quantized_model(mod, inputs)
mod = self._generate_qdq_quantized_model(mod, inputs)
expected = mod(*inputs)
actual, (source_code,) = run_and_get_code(
torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
@@ -626,6 +639,93 @@ class TestPatternMatcher(TestPatternMatcherBase):
check_quantization=True,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qat_qconv2d(self):
r"""
This test case quantizes a single Conv2d module with the QAT flow.
"""
class M(torch.nn.Module):
def __init__(
self,
**kwargs,
):
super().__init__()
self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
self.bn = torch.nn.BatchNorm2d(128)
def forward(self, x):
return self.bn(self.conv(x))
mod = M().train()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
# In total: pattern_matcher_count 4,
# pattern_matcher_nodes 17
# 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-conv pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
# 4. Quantization fusion in post-grad fusion pass
# [qconv2d_pointwise_default, div_1, round_2, add_1,
# clamp_min_1, clamp_max_1, convert_element_type_2]
self._test_common(
mod,
(v,),
4,
17,
check_quantization=True,
is_qat=True,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qat_qconv2d_relu(self):
r"""
This test case quantizes the Conv2d->ReLU pattern with the QAT flow.
"""
class M(torch.nn.Module):
def __init__(
self,
**kwargs,
):
super().__init__()
self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
self.unary_fn = torch.nn.ReLU()
self.bn = torch.nn.BatchNorm2d(128)
def forward(self, x):
return self.unary_fn(self.bn(self.conv(x)))
mod = M()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
# In total: pattern_matcher_count 4,
# pattern_matcher_nodes 18
# 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-conv pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
# 4. Quantization fusion in post-grad fusion pass
# [qconv2d_pointwise_default, relu, div_1, round_2, add_1,
# clamp_min_1, clamp_max_1, convert_element_type_2]
self._test_common(
mod,
(v,),
4,
18,
check_quantization=True,
is_qat=True,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm

test/quantization/pt2e/test_x86inductor_quantizer.py

@@ -8,6 +8,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -29,21 +30,32 @@ class Conv2DType(Enum):
class TestHelperModules:
class SingleConv2dModule(torch.nn.Module):
def __init__(self, ) -> None:
def __init__(self, with_bn=False) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
self.bn = torch.nn.BatchNorm2d(6)
self.with_bn = with_bn
def forward(self, x):
return self.conv(x)
x = self.conv(x)
if self.with_bn:
x = self.bn(x)
return x
class Conv2dReLUModule(torch.nn.Module):
def __init__(self, inplace_relu: bool = False, use_bias: bool = False) -> None:
def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
self.relu = nn.ReLU(inplace=inplace_relu)
self.bn = torch.nn.BatchNorm2d(6)
self.with_bn = with_bn
def forward(self, x):
return self.relu(self.conv(x))
x = self.conv(x)
if self.with_bn:
x = self.bn(x)
x = self.relu(x)
return x
class Conv2dAddModule(torch.nn.Module):
def __init__(self,
@@ -238,8 +250,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
quantizer,
expected_node_occurrence,
expected_node_list=None,
is_qat=False,
):
m_eager = model.eval()
m_eager = model.train() if is_qat else model.eval()
# program capture
m = copy.deepcopy(m_eager)
@@ -248,8 +261,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
example_inputs,
)
export_model = copy.deepcopy(m)
m = prepare_pt2e(m, quantizer)
# The QAT model fails to deepcopy, so reuse it directly
export_model = m if is_qat else copy.deepcopy(m)
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
@@ -873,3 +887,81 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
node_occurrence,
node_list,
)
@skipIfNoX86
def test_qat_conv2d_with_quantizer_api(self):
"""
Test QAT pattern of conv2d_bn with X86InductorQuantizer.
"""
with override_quantized_engine("x86"):
m = TestHelperModules.SingleConv2dModule(with_bn=True)
example_inputs = (torch.randn(2, 3, 16, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
node_occurrence = {
# one for input and weight of the conv, one for output for the conv
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
is_qat=True,
)
@skipIfNoX86
def test_qat_conv2d_unary_with_quantizer_api(self):
"""
Test the QAT pattern of conv2d_bn with unary post ops (such as relu or sigmoid) with X86InductorQuantizer.
Currently, only relu is supported as the unary post op.
"""
inplace_relu_list = [True, False]
with override_quantized_engine("x86"):
for inplace_relu in inplace_relu_list:
m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True)
example_inputs = (torch.randn(2, 3, 16, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
node_occurrence = {
# one for input and weight of the conv, one for output for the relu
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
is_qat=True,
)

torch/ao/quantization/quantizer/x86_inductor_quantizer.py

@@ -7,8 +7,10 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
import torch
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
@@ -72,6 +74,14 @@ int8_in_int8_out_ops_pt2e: Set = {
QUANT_ANNOTATION_KEY = "quantization_annotation"
def _mark_nodes_as_annotated(nodes: List[Node]):
for node in nodes:
if node is not None:
if QUANT_ANNOTATION_KEY not in node.meta:
node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation()
node.meta[QUANT_ANNOTATION_KEY]._annotated = True
def _is_node_annotated(_node):
"""
return True if the node is annotated, otherwise return False
@@ -156,9 +166,9 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
@functools.lru_cache
def get_default_x86_inductor_quantization_config():
def get_default_x86_inductor_quantization_config(is_qat: bool = False):
act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
HistogramObserver
FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
)
# Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
@@ -174,9 +184,13 @@ def get_default_x86_inductor_quantization_config():
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PerChannelMinMaxObserver
FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
)
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
# Only support per channel quant for now
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
@@ -199,6 +213,7 @@ def get_default_x86_inductor_quantization_config():
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
return quantization_config
@@ -379,6 +394,9 @@ class X86InductorQuantizer(Quantizer):
config = self.global_config
if config.is_qat:
self._annotate_qat_conv2d_fusion_pattern(model, config)
# Step1: Recipe of fusion patterns like conv/linear.
self._annotate_conv2d_fusion_pattern(model, config)
@@ -398,6 +416,82 @@ class X86InductorQuantizer(Quantizer):
return model
def _annotate_qat_conv2d_fusion_pattern(
self, model: torch.fx.GraphModule, config: QuantizationConfig
):
# Annotate QAT Specific patterns
self._annotate_qat_conv2d_bn_unary(model, config)
self._annotate_qat_conv2d_bn(model, config)
def _annotate_qat_conv2d_bn_unary(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
)
for fused_partition in fused_partitions:
conv_partition, bn_partition, unary_partition = fused_partition
(
conv_node,
bn_output_node,
unary_node,
) = self._get_output_nodes_of_partitions(
[conv_partition, bn_partition, unary_partition]
)
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
):
continue
if _is_annotated([unary_node, bn_output_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
_is_output_of_quantized_pattern=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
nodes_to_mark_annotated.extend(list(unary_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
def _annotate_qat_conv2d_bn(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
)
for fused_partition in fused_partitions:
conv_partition, bn_partition = fused_partition
conv_node, bn_output_node = self._get_output_nodes_of_partitions(
[conv_partition, bn_partition]
)
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.conv2d.default
):
continue
if _is_annotated([bn_output_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
bn_output_node.meta[
QUANT_ANNOTATION_KEY
] = _X86InductorQuantizationAnnotation(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
_is_output_of_quantized_pattern=True,
)
nodes_to_mark_annotated = list(conv_partition.nodes)
nodes_to_mark_annotated.extend(list(bn_partition.nodes))
_mark_nodes_as_annotated(nodes_to_mark_annotated)
def _annotate_conv2d_fusion_pattern(
self, model: torch.fx.GraphModule, config: QuantizationConfig
):