[Quant] [PT2] Enable QAT Quantization flow in X86InductorQuantizer (#111280)

**Summary**
This PR enables the PT2 QAT quantization flow in `X86InductorQuantizer`.

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_with_quantizer_api
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_unary_with_quantizer_api
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_relu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111280
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
parent 8191fb3e06
commit 56ca0043f6
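At a high level, the flow this change enables mirrors the `_generate_qdq_quantized_model` test helper added in the diff below: export with `capture_pre_autograd_graph`, annotate with an `X86InductorQuantizer` configured via `get_default_x86_inductor_quantization_config(is_qat=True)`, insert fake-quant with `prepare_qat_pt2e`, then convert and compile. The sketch below is illustrative only; the toy `ToyConvBN` module, the input shape, and the `xiq` import alias are assumptions, not part of the diff.

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ToyConvBN(torch.nn.Module):
    # Illustrative Conv2d -> BatchNorm2d module; any QAT-trainable model works here.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
        self.bn = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))


mod = ToyConvBN().train()
example_inputs = (torch.randn(1, 3, 8, 8),)

# Export the eager model to an FX graph before quantization.
exported = capture_pre_autograd_graph(mod, example_inputs)

# Configure the quantizer with the QAT variant of the default x86 Inductor config.
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))

# Insert fake-quant modules for QAT; a real run would train prepared_model here.
prepared_model = prepare_qat_pt2e(exported, quantizer)
prepared_model(*example_inputs)

# Convert to the quantized reference model and lower it with Inductor.
converted_model = convert_pt2e(prepared_model, fold_quantize=True)
torch.ao.quantization.move_exported_model_to_eval(converted_model)
optimized_model = torch.compile(converted_model)
_ = optimized_model(*example_inputs)
```

In a real QAT run the prepared model would be trained for some steps before `convert_pt2e`; the single forward call above only stands in for that stage.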
test_mkldnn_pattern_matcher.py

@@ -11,7 +11,11 @@ from torch._dynamo.utils import counters
 from torch._export import capture_pre_autograd_graph
 from torch._inductor import config
 from torch._inductor.utils import run_and_get_code
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.nn import functional as F
 from torch.testing._internal.common_quantization import (
@@ -79,18 +83,26 @@ class TestPatternMatcherBase(TestCase):
 
         return tuple(clone(x) for x in inputs)
 
-    def _generate_reference_quantized_model(self, mod, inputs):
-        export_model = capture_pre_autograd_graph(
-            mod,
-            inputs,
-        )
-        quantizer = X86InductorQuantizer()
-        quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
-        prepare_model = prepare_pt2e(export_model, quantizer)
-        prepare_model(*inputs)
-        convert_model = convert_pt2e(prepare_model)
-        torch.ao.quantization.move_exported_model_to_eval(convert_model)
-        return convert_model
+    def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
+        maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
+        with maybe_no_grad:
+            export_model = capture_pre_autograd_graph(
+                mod,
+                inputs,
+            )
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(is_qat=is_qat)
+            )
+            prepare_model = (
+                prepare_qat_pt2e(export_model, quantizer)
+                if is_qat
+                else prepare_pt2e(export_model, quantizer)
+            )
+            prepare_model(*inputs)
+            convert_model = convert_pt2e(prepare_model, fold_quantize=True)
+            torch.ao.quantization.move_exported_model_to_eval(convert_model)
+            return convert_model
 
     def _test_common(
         self,
@@ -102,6 +114,7 @@ class TestPatternMatcherBase(TestCase):
         rtol=1.3e-6,
         check_autocast=False,
         check_quantization=False,
+        is_qat=False,
     ):
         counters.clear()
         torch._dynamo.reset()
@@ -110,8 +123,8 @@ class TestPatternMatcherBase(TestCase):
             maybe_autocast = torch.cpu.amp.autocast()
             atol, rtol = 1e-2, 1e-2
         if check_quantization:
+            convert_model = self._generate_qdq_quantized_model(mod, inputs, is_qat)
             with torch.no_grad():
-                convert_model = self._generate_reference_quantized_model(mod, inputs)
                 _ = torch.compile(convert_model)(*inputs)
                 self.assertEqual(
                     counters["inductor"]["pattern_matcher_count"], matcher_count
@@ -148,7 +161,7 @@ class TestPatternMatcherBase(TestCase):
         with torch.no_grad():
             clone_inputs = self._clone_inputs(inputs)
             if check_quantization:
-                mod = self._generate_reference_quantized_model(mod, inputs)
+                mod = self._generate_qdq_quantized_model(mod, inputs)
             expected = mod(*inputs)
             actual, (source_code,) = run_and_get_code(
                 torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
@@ -626,6 +639,93 @@ class TestPatternMatcher(TestPatternMatcherBase):
             check_quantization=True,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_qconv2d(self):
+        r"""
+        This testcase will quantize a single Conv2d module with qat flow.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                **kwargs,
+            ):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
+                self.bn = torch.nn.BatchNorm2d(128)
+
+            def forward(self, x):
+                return self.bn(self.conv(x))
+
+        mod = M().train()
+        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
+
+        # Totally pattern_matcher_count 4,
+        # pattern_matcher_nodes 17
+        # 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
+        # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
+        # 2. dequant-conv pattern matched in quantization weight prepack
+        # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+        # 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
+        # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
+        # 4. Quantization fusion in post-grad fusion pass
+        # [qconv2d_pointwise_default, div_1, round_2, add_1,
+        # clamp_min_1, clamp_max_1, convert_element_type_2]
+        self._test_common(
+            mod,
+            (v,),
+            4,
+            17,
+            check_quantization=True,
+            is_qat=True,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_qconv2d_relu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU pattern with qat flow.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                **kwargs,
+            ):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
+                self.unary_fn = torch.nn.ReLU()
+                self.bn = torch.nn.BatchNorm2d(128)
+
+            def forward(self, x):
+                return self.unary_fn(self.bn(self.conv(x)))
+
+        mod = M()
+        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
+
+        # Totally pattern_matcher_count 4,
+        # pattern_matcher_nodes 18
+        # 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
+        # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
+        # 2. dequant-conv pattern matched in quantization weight prepack
+        # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+        # 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
+        # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
+        # 4. Quantization fusion in post-grad fusion pass
+        # [qconv2d_pointwise_default, relu, div_1, round_2, add_1,
+        # clamp_min_1, clamp_max_1, convert_element_type_2]
+        self._test_common(
+            mod,
+            (v,),
+            4,
+            18,
+            check_quantization=True,
+            is_qat=True,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
test_x86inductor_quantizer.py

@@ -8,6 +8,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
 from torch.ao.quantization.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
+    prepare_qat_pt2e,
 )
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
@@ -29,21 +30,32 @@ class Conv2DType(Enum):
 
 class TestHelperModules:
     class SingleConv2dModule(torch.nn.Module):
-        def __init__(self, ) -> None:
+        def __init__(self, with_bn=False) -> None:
             super().__init__()
             self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
+            self.bn = torch.nn.BatchNorm2d(6)
+            self.with_bn = with_bn
 
         def forward(self, x):
-            return self.conv(x)
+            x = self.conv(x)
+            if self.with_bn:
+                x = self.bn(x)
+            return x
 
     class Conv2dReLUModule(torch.nn.Module):
-        def __init__(self, inplace_relu: bool = False, use_bias: bool = False) -> None:
+        def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None:
             super().__init__()
             self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
             self.relu = nn.ReLU(inplace=inplace_relu)
+            self.bn = torch.nn.BatchNorm2d(6)
+            self.with_bn = with_bn
 
         def forward(self, x):
-            return self.relu(self.conv(x))
+            x = self.conv(x)
+            if self.with_bn:
+                x = self.bn(x)
+            x = self.relu(x)
+            return x
 
     class Conv2dAddModule(torch.nn.Module):
         def __init__(self,
@@ -238,8 +250,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
         quantizer,
         expected_node_occurrence,
         expected_node_list=None,
+        is_qat=False,
     ):
-        m_eager = model.eval()
+        m_eager = model.train() if is_qat else model.eval()
 
         # program capture
         m = copy.deepcopy(m_eager)
@@ -248,8 +261,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
             example_inputs,
         )
 
-        export_model = copy.deepcopy(m)
-        m = prepare_pt2e(m, quantizer)
+        # QAT Model failed to deepcopy
+        export_model = m if is_qat else copy.deepcopy(m)
+        m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
        # Calibrate
         m(*example_inputs)
         prepare_model = copy.deepcopy(m)
@@ -873,3 +887,81 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             node_occurrence,
             node_list,
         )
+
+    @skipIfNoX86
+    def test_qat_conv2d_with_quantizer_api(self):
+        """
+        Test QAT pattern of conv2d_bn with X86InductorQuantizer.
+        """
+        with override_quantized_engine("x86"):
+            m = TestHelperModules.SingleConv2dModule(with_bn=True)
+            example_inputs = (torch.randn(2, 3, 16, 16),)
+            quantizer = X86InductorQuantizer().set_global(
+                xiq.get_default_x86_inductor_quantization_config(is_qat=True)
+            )
+            node_occurrence = {
+                # one for input and weight of the conv, one for output for the conv
+                torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                # note: quantize op for weights are const propagated
+                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                # BN should be folded into Conv
+                torch.ops.aten._native_batch_norm_legit.default: 0,
+            }
+            node_list = [
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            ]
+            self._test_quantizer(
+                m,
+                example_inputs,
+                quantizer,
+                node_occurrence,
+                node_list,
+                is_qat=True,
+            )
+
+    @skipIfNoX86
+    def test_qat_conv2d_unary_with_quantizer_api(self):
+        """
+        Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
+        Currently, only relu as unary post op is supported.
+        """
+        inplace_relu_list = [True, False]
+        with override_quantized_engine("x86"):
+            for inplace_relu in itertools.product(inplace_relu_list):
+                m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True)
+                example_inputs = (torch.randn(2, 3, 16, 16),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
+                )
+                node_occurrence = {
+                    # one for input and weight of the conv, one for output for the relu
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+                    # note: quantize op for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                    # BN should be folded into Conv
+                    torch.ops.aten._native_batch_norm_legit.default: 0,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                    is_qat=True,
+                )
torch/ao/quantization/quantizer/x86_inductor_quantizer.py

@@ -7,8 +7,10 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 import torch.nn.functional as F
+from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
 from torch.ao.quantization.observer import (
     HistogramObserver,
+    MovingAveragePerChannelMinMaxObserver,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
@@ -72,6 +74,14 @@ int8_in_int8_out_ops_pt2e: Set = {
 QUANT_ANNOTATION_KEY = "quantization_annotation"
 
 
+def _mark_nodes_as_annotated(nodes: List[Node]):
+    for node in nodes:
+        if node is not None:
+            if QUANT_ANNOTATION_KEY not in node.meta:
+                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation()
+            node.meta[QUANT_ANNOTATION_KEY]._annotated = True
+
+
 def _is_node_annotated(_node):
     """
     return True if the node is annotated, otherwise return False
@@ -156,9 +166,9 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
 
 
 @functools.lru_cache
-def get_default_x86_inductor_quantization_config():
+def get_default_x86_inductor_quantization_config(is_qat: bool = False):
     act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
-        HistogramObserver
+        FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
     )
 
     # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
@@ -174,9 +184,13 @@ def get_default_x86_inductor_quantization_config():
     )
 
     weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
-        PerChannelMinMaxObserver
+        FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
     )
+
     extra_args: Dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        # Only support per channel quant for now
+        extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
         quant_min=-128,
@@ -199,6 +213,7 @@ def get_default_x86_inductor_quantization_config():
         act_quantization_spec,
         weight_quantization_spec,
         bias_quantization_spec,
+        is_qat,
     )
     return quantization_config
 
@@ -379,6 +394,9 @@ class X86InductorQuantizer(Quantizer):
 
         config = self.global_config
 
+        if config.is_qat:
+            self._annotate_qat_conv2d_fusion_pattern(model, config)
+
         # Step1: Recipe of fusion patterns like conv/linear.
         self._annotate_conv2d_fusion_pattern(model, config)
 
@@ -398,6 +416,82 @@ class X86InductorQuantizer(Quantizer):
 
         return model
 
+    def _annotate_qat_conv2d_fusion_pattern(
+        self, model: torch.fx.GraphModule, config: QuantizationConfig
+    ):
+        # Annotate QAT Specific patterns
+        self._annotate_qat_conv2d_bn_unary(model, config)
+        self._annotate_qat_conv2d_bn(model, config)
+
+    def _annotate_qat_conv2d_bn_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition, unary_partition = fused_partition
+            (
+                conv_node,
+                bn_output_node,
+                unary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, unary_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([unary_node, bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition = fused_partition
+            conv_node, bn_output_node = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            bn_output_node.meta[
+                QUANT_ANNOTATION_KEY
+            ] = _X86InductorQuantizationAnnotation(
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
     def _annotate_conv2d_fusion_pattern(
         self, model: torch.fx.GraphModule, config: QuantizationConfig
     ):