# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841

import torch
from torch import Tensor
from torch.ao.quantization import observer, ObserverOrFakeQuantize, QConfigMapping
from torch.ao.quantization.qconfig import (
    default_per_channel_symmetric_qnnpack_qconfig,
    float_qparams_weight_only_qconfig,
    per_channel_weight_observer_range_neg_127_to_127,
    QConfig,
    weight_observer_range_neg_127_to_127,
)
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    EdgeOrNode,
    FixedQParamsQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.composable_quantizer import (  # noqa: F811
    ComposableQuantizer,
)
from torch.ao.quantization.quantizer.embedding_quantizer import (  # noqa: F811
    EmbeddingQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torch.export import export_for_training
from torch.fx import Node
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    PT2EQuantizationTestCase,
    skipIfNoQNNPACK,
    TestHelperModules,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfHpu,
    TemporaryFileName,
    TEST_CUDA,
    TEST_HPU,
)

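# Note: the tests below all exercise roughly the same PT2E flow, either directly or
# through helpers such as self._quantize / self._test_quantizer from
# PT2EQuantizationTestCase:
#
#   m = export_for_training(m, example_inputs, strict=True).module()
#   m = prepare_pt2e(m, quantizer)      # or prepare_qat_pt2e(m, quantizer) for QAT
#   m(*example_inputs)                  # calibrate observers
#   m = convert_pt2e(m)                 # lower to quantize/dequantize ops
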
@skipIfNoQNNPACK
class TestQuantizePT2E(PT2EQuantizationTestCase):
    def test_simple_quantizer(self):
        # TODO: use OP_TO_ANNOTATOR
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )
                        bias_qspec = QuantizationSpec(
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_occurrence = {
            # two for input of the first conv, one for output for the first conv
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        self._test_quantizer(
            TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
        )

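    # When an annotation omits output_qspec (as in the test below), no observer should
    # be inserted at the op's output, so the converted graph has no quantize/dequantize
    # after the conv.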
    def test_wo_annotate_conv_output_quantizer(self):
        # TODO: use OP_TO_ANNOTATOR
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                weight_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                bias_qspec = QuantizationSpec(
                    dtype=torch.float32,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                )
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = torch.nn.Conv2d(2, 2, 1)
        x = torch.rand(1, 2, 14, 14)
        example_inputs = (x,)
        m = self._quantize(m, BackendAQuantizer(), example_inputs)
        # Ensure the conv has no observer inserted at output
        node_occurrence = {
            # two for input of conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 1,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

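    # The max_pool2d test below annotates the pool output with
    # SharedQuantizationSpec((input_act, maxpool_node)), so the pool output reuses its
    # input's quantization parameters instead of getting independently observed qparams.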
    def test_max_pool2d_quantizer(self):
        # TODO: use OP_TO_ANNOTATOR
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                weight_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                bias_qspec = QuantizationSpec(
                    dtype=torch.float32,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                )
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            _annotated=True,
                        )
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.max_pool2d.default
                    ):
                        maxpool_node = node
                        input_act = maxpool_node.args[0]
                        assert isinstance(input_act, Node)
                        maxpool_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                            },
                            output_qspec=SharedQuantizationSpec(
                                (input_act, maxpool_node)
                            ),
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = TestHelperModules.ConvMaxPool2d()
        x = torch.rand(1, 2, 14, 14)
        example_inputs = (x,)
        m = self._quantize(m, BackendAQuantizer(), example_inputs)
        node_occurrence = {
            # two for input of conv
            # one for input of maxpool
            # one for output of maxpool
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 3,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 4,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.max_pool2d.default),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

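    # DerivedQuantizationSpec computes one tensor's qparams from other tensors'
    # observers/fake-quants. In the test below the conv bias scale is derived as
    # act_scale * weight_scale with zero_point 0, a convention commonly used for
    # int32 bias quantization.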
    def test_derived_qspec(self):
        # TODO: use OP_TO_ANNOTATOR
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )

                        def derive_qparams_fn(
                            obs_or_fqs: list[ObserverOrFakeQuantize],
                        ) -> tuple[Tensor, Tensor]:
                            assert (
                                len(obs_or_fqs) == 2
                            ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
                            act_obs_or_fq = obs_or_fqs[0]
                            weight_obs_or_fq = obs_or_fqs[1]
                            act_scale, act_zp = act_obs_or_fq.calculate_qparams()
                            (
                                weight_scale,
                                weight_zp,
                            ) = weight_obs_or_fq.calculate_qparams()
                            return torch.tensor([act_scale * weight_scale]).to(
                                torch.float32
                            ), torch.tensor([0]).to(torch.int32)

                        bias_qspec = DerivedQuantizationSpec(
                            derived_from=[(input_act, node), (weight, node)],
                            derive_qparams_fn=derive_qparams_fn,
                            dtype=torch.int32,
                            quant_min=-(2**31),
                            quant_max=2**31 - 1,
                            qscheme=torch.per_tensor_symmetric,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)

        m = self._quantize(m, BackendAQuantizer(), example_inputs)
        node_occurrence = {
            # input, weight, bias, output for the conv
            # note: quantize op for weight and bias are const propagated
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 4,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_derived_qspec_per_channel(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_channel_affine,
                            is_dynamic=False,
                            ch_axis=0,
                            observer_or_fake_quant_ctr=observer.default_per_channel_weight_observer,
                        )

                        def derive_qparams_fn(
                            obs_or_fqs: list[ObserverOrFakeQuantize],
                        ) -> tuple[Tensor, Tensor]:
                            assert (
                                len(obs_or_fqs) == 1
                            ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
                            weight_obs_or_fq = obs_or_fqs[0]
                            (
                                weight_scale,
                                weight_zp,
                            ) = weight_obs_or_fq.calculate_qparams()
                            return weight_scale, torch.zeros_like(weight_scale)

                        bias_qspec = DerivedQuantizationSpec(
                            derived_from=[(weight, node)],
                            derive_qparams_fn=derive_qparams_fn,
                            dtype=torch.int32,
                            quant_min=-(2**31),
                            quant_max=2**31 - 1,
                            qscheme=torch.per_channel_symmetric,
                            ch_axis=0,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)

        m = self._quantize(m, BackendAQuantizer(), example_inputs)

        node_occurrence = {
            # input, output for the conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
            # weight and bias for conv
            # note: quantize op for weight and bias are const propagated
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_channel.default
            ): 0,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

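    # FixedQParamsQuantizationSpec pins scale/zero_point ahead of time instead of
    # deriving them from observed statistics; the sigmoid tests below use
    # scale=1/256, zero_point=0 and check that exactly those values end up on the
    # quantize/dequantize ops.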
    def test_fixed_qparams_qspec_ptq(self):
        self._test_fixed_qparams_qspec(is_qat=False)

    # TODO: refactor and move this to test_quantize_pt2_qat.py
    def test_fixed_qparams_qspec_qat(self):
        self._test_fixed_qparams_qspec(is_qat=True)

    def _test_fixed_qparams_qspec(self, is_qat: bool):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.sigmoid.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        act_qspec = FixedQParamsQuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            scale=1.0 / 256.0,
                            zero_point=0,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)

        m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat)
        fixed_scale = 1.0 / 256.0
        fixed_zero_point = 0
        for n in m.graph.nodes:
            if n.op == "call_function":
                if (
                    n.target
                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
                ):
                    scale_0 = n.args[1]
                    zero_point_0 = n.args[2]
                if (
                    n.target
                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ):
                    scale_1 = n.args[1]
                    zero_point_1 = n.args[2]
        self.assertEqual(scale_0, fixed_scale)
        self.assertEqual(zero_point_0, fixed_zero_point)
        self.assertEqual(scale_1, fixed_scale)
        self.assertEqual(zero_point_1, fixed_zero_point)
        node_occurrence = {
            # two for input of the first conv, one for output for the first conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.sigmoid.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_fixed_qparams_qspec_observer_dedup(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = FixedQParamsQuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    scale=1.0 / 256.0,
                    zero_point=0,
                )
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.sigmoid.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.add.Tensor
                    ):
                        input_act0 = node.args[0]
                        assert isinstance(input_act0, Node)
                        input_act1 = node.args[1]
                        assert isinstance(input_act1, Node)
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act0: act_qspec,
                                input_act1: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.sigmoid(x) + y

            def example_inputs(self):
                return (
                    torch.randn(1, 3, 5, 5),
                    torch.randn(1, 3, 5, 5),
                )

        m = M().eval()
        example_inputs = m.example_inputs()
        m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat=False)

        node_occurrence = {
            # two for input of the first conv, one for output for the first conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 4,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 4,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.sigmoid.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.add.Tensor),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

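    # SharedQuantizationSpec ties a tensor's qparams to another edge or node. The cat
    # annotations below key it on the (first_input_node, cat_node) edge, so all cat
    # inputs and the cat output share a single observer.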
    def test_shared_qspec(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )
                        bias_qspec = QuantizationSpec(
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif node.target is torch.ops.aten.cat.default:
                        cat_node = node
                        input_nodes = cat_node.args[0]
                        first_input_node = input_nodes[0]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[first_input_node] = act_qspec
                        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
                            (first_input_node, cat_node)
                        )
                        for input_node in input_nodes[1:]:
                            input_qspec_map[
                                input_node
                            ] = share_qparams_with_input_act0_qspec

                        cat_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act0_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = TestHelperModules.Conv2dWithCat().eval()
        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))

        # program capture
        m = export_for_training(m, example_inputs, strict=True).module()
        m = prepare_pt2e(m, BackendAQuantizer())
        # make sure the two observers for input are shared
        conv_output_obs = []
        for n in m.graph.nodes:
            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
                conv_output_obs.append(getattr(m, next(iter(n.users)).target))
            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
                inputs = n.args[0]
                input0 = inputs[0]
                input1 = inputs[1]
                assert input0.op == "call_module"
                assert input1.op == "call_module"
                obs_ins0 = getattr(m, input0.target)
                obs_ins1 = getattr(m, input1.target)
                assert obs_ins0 == obs_ins1
        assert (
            len(conv_output_obs) == 2
        ), "expecting two observers that follow conv2d ops"
        # checking that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]

        m(*example_inputs)
        m = convert_pt2e(m)

        node_occurrence = {
            # two for input of the first conv, one for output for the first conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 5,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 7,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

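    # Shared helper for the two transitivity tests below: it captures a model with two
    # chained cat ops, checks that prepare_pt2e gives every tensor connected to the cats
    # a single shared observer, and then checks node counts after convert_pt2e.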
    def _test_transitive_sharing_with_cat_helper(self, quantizer):
        m = TestHelperModules.Conv2dWithTwoCat().eval()
        example_inputs = (
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 6, 3, 3),
            torch.randn(1, 6, 3, 3),
        )

        # program capture
        m = export_for_training(m, example_inputs, strict=True).module()
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        # make sure the two input observers and output are shared
        conv_output_obs = []
        for n in m.graph.nodes:
            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
                conv_output_obs.append(getattr(m, next(iter(n.users)).target))
            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
                inputs = n.args[0]
                input0 = inputs[0]
                input1 = inputs[1]
                assert input0.op == "call_module"
                assert input1.op == "call_module"
                obs_ins0 = getattr(m, input0.target)
                obs_ins1 = getattr(m, input1.target)
                assert obs_ins0 == obs_ins1

                output_obs = next(iter(n.users))
                assert output_obs.op == "call_module"
                obs_ins2 = getattr(m, output_obs.target)
                assert obs_ins0 == obs_ins2, "input observer does not match output"

        assert (
            len(conv_output_obs) == 2
        ), "expecting two observers that follow conv2d ops"
        # checking that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]

        m(*example_inputs)
        m = convert_pt2e(m)

        node_occurrence = {
            # two for input of the first conv, one for output for the first conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 7,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_shared_qspec_transitivity(self):
        """This tests the transitivity of SharedQuantizationSpec: if A is shared
        with B and B is shared with C, then C should be shared with A as well.

            x1 -> conv1 -> cat1 -----> cat2
            x2 -> conv2 -/            /
            x3 -> add  --------------/
            x4 -/

        Both cats share their inputs and output, and because the output of cat1 is
        also an input of cat2 (they refer to the same Tensor), there is implicit
        sharing: after transitive sharing, all tensors connected to cat1 and cat2
        end up in the same sharing group.
        """

        # TODO: refactor this to a common util
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )
                        bias_qspec = QuantizationSpec(
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif node.target is torch.ops.aten.cat.default:
                        cat_node = node
                        input_nodes = cat_node.args[0]
                        first_input_node = input_nodes[0]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[first_input_node] = act_qspec
                        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
                            (first_input_node, cat_node)
                        )
                        for input_node in input_nodes[1:]:
                            input_qspec_map[
                                input_node
                            ] = share_qparams_with_input_act0_qspec

                        cat_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act0_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
|
|
|
|
def test_shared_qspec_transitivity_case_2(self):
|
|
"""This tests the transitivity of SharedQuantizationSpec, that is
|
|
if A is shared with B, B is shared with C, then C should be shared with A as well
|
|
|
|
x1 -> conv1 -> cat1 -----> cat2
|
|
x2 -> conv2 -/ /
|
|
x3 -> add /
|
|
x4 /
|
|
|
|
both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor
|
|
so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same
|
|
sharing group after transitive sharing
|
|
|
|
the difference is that for this one, all edges and nodes are shared with the second input edge of cat
|
|
instead of the first input edge of cat as in previous example
|
|
"""
|
|
|
|
# TODO: refactor this to a common util
|
|
class BackendAQuantizer(Quantizer):
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
for node in model.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.aten.conv2d.default
|
|
):
|
|
input_act = node.args[0]
|
|
assert isinstance(input_act, Node)
|
|
weight = node.args[1]
|
|
assert isinstance(weight, Node)
|
|
bias = node.args[2]
|
|
assert isinstance(bias, Node)
|
|
act_qspec = QuantizationSpec(
|
|
dtype=torch.uint8,
|
|
quant_min=0,
|
|
quant_max=255,
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_observer,
|
|
)
|
|
weight_qspec = QuantizationSpec(
|
|
dtype=torch.int8,
|
|
quant_min=-128,
|
|
quant_max=127,
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_weight_observer,
|
|
)
|
|
bias_qspec = QuantizationSpec(
|
|
dtype=torch.float32,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
|
|
)
|
|
node.meta["quantization_annotation"] = QuantizationAnnotation(
|
|
input_qspec_map={
|
|
input_act: act_qspec,
|
|
weight: weight_qspec,
|
|
bias: bias_qspec,
|
|
},
|
|
output_qspec=act_qspec,
|
|
_annotated=True,
|
|
)
|
|
elif node.target is torch.ops.aten.cat.default:
|
|
cat_node = node
|
|
input_nodes = cat_node.args[0]
|
|
first_input_node = input_nodes[0]
|
|
second_input_node = input_nodes[1]
|
|
input_qspec_map = {}
|
|
act_qspec = QuantizationSpec(
|
|
dtype=torch.uint8,
|
|
quant_min=0,
|
|
quant_max=255,
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_observer,
|
|
)
|
|
input_qspec_map[second_input_node] = act_qspec
|
|
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
|
(second_input_node, cat_node)
|
|
)
|
|
input_qspec_map[
|
|
first_input_node
|
|
] = share_qparams_with_input_act1_qspec
|
|
|
|
cat_node.meta[
|
|
"quantization_annotation"
|
|
] = QuantizationAnnotation(
|
|
input_qspec_map=input_qspec_map,
|
|
output_qspec=share_qparams_with_input_act1_qspec,
|
|
_annotated=True,
|
|
)
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
|
|
|
|
def test_allow_implicit_sharing(self):
|
|
"""This tests the allow_transitive_sharing flag of QuantizationAnnotation, that is
|
|
if a node is configured with allow_implicit_sharing=False, we will not have implicit sharing
|
|
for node and (node, consumer) even they refer to the same Tensor
|
|
|
|
x1 -> add1 -----> add3
|
|
x2 -/ /
|
|
x3 -> add2 /
|
|
x4 -/
|
|
|
|
all add has shared input and output, and second input is using shared quantization spec pointing
|
|
to first input, but we set allow_implicit_sharing to False for all add nodes so input and output of add1,
|
|
add2 and add3 will each belong to one sharing group, so we'll have:
|
|
|
|
x1 -> obs1 -> add1 -> obs1 -> obs3--> add3 -> obs3
|
|
x2 -> obs1 -/ /
|
|
x3 -> obs2 -> add2 -> obs2 -> obs3
|
|
x4 -> obs2 -/
|
|
"""
|
|
|
|
# TODO: refactor this to a common util
|
|
class BackendAQuantizer(Quantizer):
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
for node in model.graph.nodes:
|
|
if node.target is torch.ops.aten.add.Tensor:
|
|
add_node = node
|
|
first_input_node = add_node.args[0]
|
|
second_input_node = add_node.args[1]
|
|
input_qspec_map = {}
|
|
act_qspec = QuantizationSpec(
|
|
dtype=torch.uint8,
|
|
quant_min=0,
|
|
quant_max=255,
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_observer,
|
|
)
|
|
input_qspec_map[second_input_node] = act_qspec
|
|
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
|
(second_input_node, add_node)
|
|
)
|
|
input_qspec_map[
|
|
first_input_node
|
|
] = share_qparams_with_input_act1_qspec
|
|
|
|
add_node.meta[
|
|
"quantization_annotation"
|
|
] = QuantizationAnnotation(
|
|
input_qspec_map=input_qspec_map,
|
|
output_qspec=share_qparams_with_input_act1_qspec,
|
|
allow_implicit_sharing=False,
|
|
_annotated=True,
|
|
)
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
m = TestHelperModules.ThreeAdd().eval()
|
|
example_inputs = (
|
|
torch.randn(1, 3, 5, 5),
|
|
torch.randn(1, 3, 5, 5),
|
|
torch.randn(1, 3, 5, 5),
|
|
torch.randn(1, 3, 5, 5),
|
|
)
|
|
|
|
# program capture
|
|
m = export_for_training(m, example_inputs, strict=True).module()
|
|
quantizer = BackendAQuantizer()
|
|
m = prepare_pt2e(m, quantizer)
|
|
m(*example_inputs)
|
|
observers = []
|
|
for n in m.graph.nodes:
|
|
if n.target == torch.ops.aten.add.Tensor:
|
|
input_obs1 = getattr(m, n.args[0].target)
|
|
input_obs2 = getattr(m, n.args[1].target)
|
|
output_obs = getattr(m, next(iter(n.users)).target)
|
|
self.assertIs(input_obs1, input_obs2)
|
|
self.assertIs(input_obs1, output_obs)
|
|
observers.append(input_obs1)
|
|
assert len(observers) == 3
|
|
self.assertIsNot(observers[0], observers[1])
|
|
self.assertIsNot(observers[0], observers[2])
|
|
self.assertIsNot(observers[1], observers[2])
|
|
|
|
@skipIfHpu
|
|
@parametrize("dtype", (torch.float32, torch.bfloat16))
|
|
@parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
|
|
def test_quantization_dtype(self, dtype, quant_dtype):
|
|
class DtypeActQuantizer(Quantizer):
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
|
|
activate_qspec = QuantizationSpec(
|
|
dtype=quant_dtype,
|
|
quant_min=int(info_fun(quant_dtype).min),
|
|
quant_max=int(info_fun(quant_dtype).max),
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_observer,
|
|
)
|
|
int8_qspec = QuantizationSpec(
|
|
dtype=torch.int8,
|
|
quant_min=-128,
|
|
quant_max=127,
|
|
qscheme=torch.per_tensor_symmetric,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_weight_observer,
|
|
)
|
|
quantization_config = QuantizationConfig(
|
|
input_activation=activate_qspec,
|
|
weight=int8_qspec,
|
|
bias=None,
|
|
output_activation=activate_qspec,
|
|
)
|
|
OP_TO_ANNOTATOR["conv"](model, quantization_config)
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, dtype):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
quantizer = DtypeActQuantizer()
|
|
node_occurrence = {
|
|
# one for input of the first conv, one for output for the first conv
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
|
|
}
|
|
node_list = [
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
|
torch.ops.aten.conv2d.default,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
|
]
|
|
example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),)
|
|
m = self._test_quantizer(
|
|
M(dtype).eval(),
|
|
example_inputs,
|
|
quantizer,
|
|
node_occurrence,
|
|
node_list,
|
|
)
|
|
|
|
def verify_quant_dequant_iotypes(m):
|
|
for node in m.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target.__name__ == "dequantize_per_tensor.default"
|
|
):
|
|
# Check dequantize node
|
|
dequant_node = node
|
|
dequant_in_dtype = dequant_node.args[5]
|
|
dequant_out_dtype = torch.float32
|
|
if "out_dtype" in dequant_node.kwargs:
|
|
dequant_out_dtype = dequant_node.kwargs["out_dtype"]
|
|
|
|
# Check preceding quantize node
|
|
# Depending on fold_quantize flag, quantize node may be absent
|
|
quant_node = node.args[0]
|
|
if (
|
|
quant_node.op == "call_function"
|
|
and quant_node.target.__name__ == "quantize_per_tensor.default"
|
|
):
|
|
quant_in_dtype = torch.float32
|
|
if "val" in quant_node.args[0].meta:
|
|
quant_in_dtype = quant_node.args[0].meta["val"].dtype
|
|
quant_out_dtype = quant_node.args[5]
|
|
assert (
|
|
quant_in_dtype == dequant_out_dtype
|
|
and quant_out_dtype == dequant_in_dtype
|
|
), "quant dequant io dtype check failed!"
|
|
|
|
verify_quant_dequant_iotypes(m)
|
|
|
|
def test_input_edge_sanity_check(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 6
|
|
|
|
class BackendAQuantizer(Quantizer):
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
for node in model.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.aten.add.Tensor
|
|
):
|
|
input_act1 = node.args[0]
|
|
# this is a constant, so not valid for annotation
|
|
input_act2 = node.args[1]
|
|
act_qspec = QuantizationSpec(
|
|
dtype=torch.uint8,
|
|
quant_min=0,
|
|
quant_max=255,
|
|
qscheme=torch.per_tensor_affine,
|
|
is_dynamic=False,
|
|
observer_or_fake_quant_ctr=observer.default_observer,
|
|
)
|
|
node.meta["quantization_annotation"] = QuantizationAnnotation(
|
|
input_qspec_map={
|
|
input_act1: act_qspec,
|
|
# this is supposed to error out
|
|
input_act2: act_qspec,
|
|
},
|
|
output_qspec=act_qspec,
|
|
_annotated=True,
|
|
)
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
m = M().eval()
|
|
example_inputs = torch.randn(1, 2, 3, 3)
|
|
m = export_for_training(m, (example_inputs,), strict=True).module()
|
|
with self.assertRaises(Exception):
|
|
m = prepare_pt2e(m, BackendAQuantizer())
|
|
|
|
def test_fold_quantize(self):
|
|
"""Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
|
|
m = self._get_pt2e_quantized_linear()
|
|
node_occurrence = {
|
|
# quantize op for weight node is folded
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default
|
|
): 2,
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default
|
|
): 3,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_fold_quantize_per_channel(self):
|
|
"""Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)"""
|
|
m = self._get_pt2e_quantized_linear(is_per_channel=True)
|
|
node_occurrence = {
|
|
# quantize op for weight node is folded
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default
|
|
): 2,
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default
|
|
): 1,
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default
|
|
): 2,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_dont_fold_other_constant(self):
|
|
"""Make sure the constant propagation does not apply to things unrelated to
|
|
quantization
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
t = self.dont_fold_me.t()
|
|
return self.linear(x) + t
|
|
|
|
quantizer = XNNPACKQuantizer()
|
|
operator_config = get_symmetric_quantization_config(is_per_channel=False)
|
|
# only quantize linear, so add is not quantized and the constant Tensor
|
|
# should not be folded
|
|
quantizer.set_module_type(torch.nn.Linear, operator_config)
|
|
example_inputs = (torch.randn(2, 2),)
|
|
m = M().eval()
|
|
m = self._quantize(m, quantizer, example_inputs)
|
|
node_occurrence = {
|
|
# quantize op for weight node is folded
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default
|
|
): 2,
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default
|
|
): 3,
|
|
# transpose op not folded
|
|
ns.call_function(torch.ops.aten.t.default): 1,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_fold_all_ops_before_quantize(self):
|
|
"""Test folding all ops that's before quantized operator:
|
|
Before:
|
|
get_attr(weight) -> transpose -> quantize -> dequantize
|
|
After:
|
|
get_attr(folded_weight) -> dequantize
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.randn(2, 2)
|
|
|
|
def forward(self, x):
|
|
t = self.weight.t()
|
|
return torch.nn.functional.linear(x, t)
|
|
|
|
quantizer = XNNPACKQuantizer()
|
|
operator_config = get_symmetric_quantization_config(is_per_channel=False)
|
|
quantizer.set_global(operator_config)
|
|
example_inputs = (torch.randn(2, 2),)
|
|
m = M().eval()
|
|
m = self._quantize(m, quantizer, example_inputs)
|
|
node_occurrence = {
|
|
# quantize op for weight node is folded
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default
|
|
): 2,
|
|
ns.call_function(
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default
|
|
): 3,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_constant_prop_preserve_metadata(self):
|
|
"""Test to make sure the get_attr node for const propagated weight Tensor gets the correct
|
|
metadata (from original get_attr node from weight)
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
quantizer = XNNPACKQuantizer()
|
|
operator_config = get_symmetric_quantization_config()
|
|
quantizer.set_global(operator_config)
|
|
example_inputs = (torch.randn(2, 2),)
|
|
m = M().eval()
|
|
m = export_for_training(m, example_inputs, strict=True).module()
|
|
weight_meta = None
|
|
for n in m.graph.nodes:
|
|
if (
|
|
n.op == "get_attr"
|
|
and next(iter(n.users)).target == torch.ops.aten.linear.default
|
|
):
|
|
weight_meta = n.meta
|
|
break
|
|
assert weight_meta is not None, "Expect to find metadata for weight node"
|
|
|
|
m = prepare_pt2e(m, quantizer)
|
|
m(*example_inputs)
|
|
m = convert_pt2e(m)
|
|
|
|
for n in m.graph.nodes:
|
|
if n.op == "get_attr" and "frozen_param" in n.target:
|
|
for key in n.meta:
|
|
self.assertEqual(n.meta[key], weight_meta[key])
|
|
|
|
def test_save_load(self):
|
|
"""Test save/load a quantized model"""
|
|
m = self._get_pt2e_quantized_linear()
|
|
example_inputs = (torch.randn(2, 2),)
|
|
ref_res = m(*example_inputs)
|
|
|
|
with TemporaryFileName() as fname:
|
|
# serialization
|
|
quantized_ep = torch.export.export(m, example_inputs, strict=True)
|
|
torch.export.save(quantized_ep, fname)
|
|
# deserialization
|
|
loaded_ep = torch.export.load(fname)
|
|
loaded_quantized_model = loaded_ep.module()
|
|
res = loaded_quantized_model(*example_inputs)
|
|
self.assertEqual(ref_res, res)
|
|
|
|
def test_composable_quantizer_throw(self):
|
|
class BadQuantizer(Quantizer):
|
|
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
for n in gm.graph.nodes:
|
|
n.meta["quantization_annotation"] = None
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
quantizer = XNNPACKQuantizer()
|
|
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
|
|
quantizer.set_global(quantization_config)
|
|
bad_quantizer = BadQuantizer()
|
|
composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer])
|
|
m_eager = TestHelperModules.ConvLinearWPermute().eval()
|
|
example_inputs = (torch.randn(2, 3, 4, 4),)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: self._test_quantizer(
|
|
m_eager, example_inputs, composable_quantizer, {}
|
|
),
|
|
)
|
|
|
|
def test_transform_for_annotation(self):
|
|
class TestQuantizer(Quantizer):
|
|
def transform_for_annotation(
|
|
self, model: torch.fx.GraphModule
|
|
) -> torch.fx.GraphModule:
|
|
# Make a copy of the graph to ensure that we are using the
|
|
# return value of this function.
|
|
graph = torch.fx.Graph()
|
|
graph.graph_copy(model.graph, {})
|
|
for n in graph.nodes:
|
|
if n.target == torch.ops.aten.add.Tensor:
|
|
n.target = torch.ops.aten.mul.Tensor
|
|
model = torch.fx.GraphModule(model, graph)
|
|
return model
|
|
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
return model
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + 3
|
|
|
|
m = M().eval()
|
|
quantizer = TestQuantizer()
|
|
example_inputs = (torch.randn(1, 2, 3, 3),)
|
|
m = export_for_training(m, example_inputs, strict=True).module()
|
|
m = prepare_pt2e(m, quantizer)
|
|
m(*example_inputs)
|
|
node_occurrence = {
|
|
ns.call_function(torch.ops.aten.add.Tensor): 0,
|
|
ns.call_function(torch.ops.aten.mul.Tensor): 1,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_composable_quantizer_transform_for_annotation(self):
|
|
class TestQuantizer1(Quantizer):
|
|
def transform_for_annotation(
|
|
self, model: torch.fx.GraphModule
|
|
) -> torch.fx.GraphModule:
|
|
for n in model.graph.nodes:
|
|
if n.target == torch.ops.aten.add.Tensor:
|
|
n.target = torch.ops.aten.mul.Tensor
|
|
return model
|
|
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
return model
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
class TestQuantizer2(Quantizer):
|
|
def transform_for_annotation(
|
|
self, model: torch.fx.GraphModule
|
|
) -> torch.fx.GraphModule:
|
|
for n in model.graph.nodes:
|
|
if n.target == torch.ops.aten.sub.Tensor:
|
|
n.target = torch.ops.aten.div.Tensor
|
|
return model
|
|
|
|
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
|
return model
|
|
|
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
|
pass
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
return x + y - z
|
|
|
|
m = M().eval()
|
|
quantizer = ComposableQuantizer([TestQuantizer1(), TestQuantizer2()])
|
|
example_inputs = (
|
|
torch.randn(1, 2, 3, 3),
|
|
torch.randn(1, 2, 3, 3),
|
|
torch.randn(1, 2, 3, 3),
|
|
)
|
|
m = export_for_training(m, example_inputs, strict=True).module()
|
|
m = prepare_pt2e(m, quantizer)
|
|
m(*example_inputs)
|
|
node_occurrence = {
|
|
ns.call_function(torch.ops.aten.add.Tensor): 0,
|
|
ns.call_function(torch.ops.aten.sub.Tensor): 0,
|
|
ns.call_function(torch.ops.aten.mul.Tensor): 1,
|
|
ns.call_function(torch.ops.aten.div.Tensor): 1,
|
|
}
|
|
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
|
|
|
def test_embedding_quantizer(self):
|
|
m_eager = TestHelperModules.EmbeddingModule().eval()
|
|
indices = torch.tensor(
|
|
[
|
|
9,
|
|
6,
|
|
5,
|
|
7,
|
|
8,
|
|
8,
|
|
9,
|
|
2,
|
|
8,
|
|
6,
|
|
6,
|
|
9,
|
|
1,
|
|
6,
|
|
8,
|
|
8,
|
|
3,
|
|
2,
|
|
3,
|
|
6,
|
|
3,
|
|
6,
|
|
5,
|
|
7,
|
|
0,
|
|
8,
|
|
4,
|
|
6,
|
|
5,
|
|
8,
|
|
2,
|
|
3,
|
|
]
|
|
)
|
|
example_inputs = (indices,)
|
|
|
|
quantizer = EmbeddingQuantizer()
|
|
node_occurrence = {
|
|
# note: quantize op for weights are const propagated
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
|
}
|
|
node_list = [
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
|
torch.ops.aten.embedding.default,
|
|
]
|
|
# Compare against short term workflow
|
|
# cannot compare against fx quant because of the numerical differences coming
|
|
# from quantize and dequantize ops
|
|
qconfig = default_per_channel_symmetric_qnnpack_qconfig
|
|
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
|
qconfig_mapping = qconfig_mapping.set_object_type(
|
|
torch.nn.Embedding, float_qparams_weight_only_qconfig
|
|
)
|
|
self._test_quantizer(
|
|
m_eager,
|
|
example_inputs,
|
|
quantizer,
|
|
node_occurrence,
|
|
node_list,
|
|
True,
|
|
qconfig_mapping,
|
|
)
|
|
|
|
def test_composable_quantizer_linear_conv(self):
|
|
dynamic_quantizer = XNNPACKQuantizer()
|
|
quantization_config_dynamic = get_symmetric_quantization_config(
|
|
is_per_channel=False, is_dynamic=True
|
|
)
|
|
dynamic_quantizer.set_global(quantization_config_dynamic)
|
|
static_quantizer = XNNPACKQuantizer()
|
|
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
|
|
static_quantizer.set_global(quantization_config)
|
|
# Note that dynamic quantization must be applied first here.
|
|
# this is because static quantizer also quantizes linear with static qspec
|
|
# and if we apply static_quantizer first then dynamic_quantizer cannot be applied
|
|
composable_quantizer = ComposableQuantizer(
|
|
[dynamic_quantizer, static_quantizer]
|
|
)
|
|
m_eager = TestHelperModules.ConvLinearWPermute().eval()
|
|
|
|
node_occurrence = {
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
|
|
# note: quantize op for weights are const propagated
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
|
# note: quantize op for weights are const propagated
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
|
}
|
|
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
|
|
dtype=torch.qint8,
|
|
qscheme=torch.per_tensor_affine,
|
|
quant_min=-128,
|
|
quant_max=127,
|
|
eps=2**-12,
|
|
is_dynamic=True,
|
|
)
|
|
dynamic_qconfig = QConfig(
|
|
activation=act_affine_quant_obs,
|
|
weight=weight_observer_range_neg_127_to_127,
|
|
)
|
|
# Test with 2d inputs
|
|
example_inputs = (torch.randn(2, 3, 4, 4),)
|
|
qconfig = default_per_channel_symmetric_qnnpack_qconfig
|
|
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
|
qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
|
|
# Had to turn off check against fx because fx quant workflow does not seem
|
|
# to propagate observers for permute node for this model.
|
|
        # Surprisingly it does propagate it for EmbeddingConvLinearModule
|
|
# TODO: Figure out the right behavior for propagation
|
|
self._test_quantizer(
|
|
m_eager,
|
|
example_inputs,
|
|
composable_quantizer,
|
|
node_occurrence,
|
|
[],
|
|
False,
|
|
qconfig_mapping,
|
|
)
|
|
|
|
def test_embedding_conv_linear_quantization(self):
|
|
m_eager = TestHelperModules.EmbeddingConvLinearModule().eval()
|
|
indices = torch.tensor(
|
|
[
|
|
9,
|
|
6,
|
|
5,
|
|
7,
|
|
8,
|
|
8,
|
|
9,
|
|
2,
|
|
8,
|
|
6,
|
|
6,
|
|
9,
|
|
1,
|
|
6,
|
|
8,
|
|
8,
|
|
3,
|
|
2,
|
|
3,
|
|
6,
|
|
3,
|
|
6,
|
|
5,
|
|
7,
|
|
0,
|
|
8,
|
|
4,
|
|
6,
|
|
5,
|
|
8,
|
|
2,
|
|
3,
|
|
]
|
|
)
|
|
indices = torch.unsqueeze(indices, 0)
|
|
example_inputs = (indices,)
|
|
|
|
embedding_quantizer = EmbeddingQuantizer()
|
|
dynamic_quantizer = XNNPACKQuantizer()
|
|
quantization_config_dynamic = get_symmetric_quantization_config(
|
|
is_per_channel=True, is_dynamic=True
|
|
)
|
|
dynamic_quantizer.set_global(quantization_config_dynamic)
|
|
static_quantizer = XNNPACKQuantizer()
|
|
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
|
|
static_quantizer.set_global(quantization_config)
|
|
composed_quantizer = ComposableQuantizer(
|
|
[embedding_quantizer, dynamic_quantizer, static_quantizer]
|
|
)
|
|
|
|
act_affine_quant_obs = observer.PlaceholderObserver.with_args(
|
|
dtype=torch.qint8,
|
|
qscheme=torch.per_tensor_affine,
|
|
quant_min=-128,
|
|
quant_max=127,
|
|
eps=2**-12,
|
|
is_dynamic=True,
|
|
)
|
|
dynamic_qconfig = QConfig(
|
|
activation=act_affine_quant_obs,
|
|
weight=per_channel_weight_observer_range_neg_127_to_127,
|
|
)
|
|
qconfig = default_per_channel_symmetric_qnnpack_qconfig
|
|
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
|
qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
|
|
qconfig_mapping = qconfig_mapping.set_object_type(
|
|
torch.nn.Embedding, float_qparams_weight_only_qconfig
|
|
)
|
|
|
|
node_occurrence = {
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
|
|
# note: quantize op for weights are const propagated
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
|
|
}
|
|
self._test_quantizer(
|
|
m_eager,
|
|
example_inputs,
|
|
composed_quantizer,
|
|
node_occurrence,
|
|
[],
|
|
True,
|
|
qconfig_mapping,
|
|
)
|
|
|
|
def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload):
|
|
"""
|
|
        Return the first node matching the specified target, throwing an exception
        if no such node is found.
|
|
"""
|
|
for n in m.graph.nodes:
|
|
if n.target == target:
|
|
return n
|
|
raise ValueError("Did not find node with target ", target)
|
|
|
|
    def _test_move_exported_model_dropout(self, inplace: bool):
        """
        Test switching dropout behavior between train and eval modes using
        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5, inplace=inplace)

            def forward(self, x):
                return self.dropout(x)

        example_inputs = (torch.randn(1),)
        m = M().train()
        m = export_for_training(m, example_inputs, strict=True).module()
        if inplace:
            target = torch.ops.aten.dropout_.default
        else:
            target = torch.ops.aten.dropout.default

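        # `aten.dropout(input, p, train)` carries the train/eval flag as its third
        # positional argument, so the assertions below inspect `args[2]`.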
        # Assert that dropout op exists and is in train mode
        dropout_node = self._get_node(m, target)
        self.assertTrue(dropout_node is not None)
        self.assertTrue(dropout_node.args[2])

        # Move to eval
        torch.ao.quantization.move_exported_model_to_eval(m)

        # Assert that dropout op is now in eval mode
        dropout_node = self._get_node(m, target)
        self.assertTrue(dropout_node is not None)
        self.assertTrue(not dropout_node.args[2])

        # Move back to train
        torch.ao.quantization.move_exported_model_to_train(m)

        # Assert that dropout op is now in train mode again
        dropout_node = self._get_node(m, target)
        self.assertTrue(dropout_node is not None)
        self.assertTrue(dropout_node.args[2])

    def test_move_exported_model_dropout(self):
        self._test_move_exported_model_dropout(inplace=False)

    def test_move_exported_model_dropout_inplace(self):
        self._test_move_exported_model_dropout(inplace=True)

    def _get_bn_train_eval_ops(self):
        return (
            torch.ops.aten.batch_norm.default,
            torch.ops.aten.batch_norm.default,
        )

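    # Note: after export_for_training, both train and eval mode batch norm lower to
    # `aten.batch_norm.default` (hence the identical ops returned above); the modes
    # are distinguished by the boolean `training` argument at positional index 5,
    # which the test below asserts on.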
    @parametrize(
        "device",
        ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []),
    )
    def test_move_exported_model_bn(self, device):
        """
        Test switching batch_norm behavior between train and eval modes using
        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(x)

        if TEST_CUDA or TEST_HPU:
            m = M().train().to(device)
            example_inputs = (torch.randn((1, 3, 3, 3), device=device),)
        else:
            m = M().train()
            example_inputs = (torch.randn(1, 3, 3, 3),)
        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
        m = export_for_training(m, example_inputs, strict=True).module()

        # Assert that batch norm op exists and is in train mode
        bn_node = self._get_node(m, bn_train_op)
        self.assertTrue(bn_node is not None)
        self.assertTrue(bn_node.args[5])

        # Move to eval
        torch.ao.quantization.move_exported_model_to_eval(m)

        # Assert that batch norm op is now in eval mode
        bn_node = self._get_node(m, bn_eval_op)
        self.assertTrue(bn_node is not None)

        # Move to train
        torch.ao.quantization.move_exported_model_to_train(m)

        # Assert that batch norm op is now in train mode again
        bn_node = self._get_node(m, bn_train_op)
        self.assertTrue(bn_node is not None)
        self.assertTrue(bn_node.args[5])

    def test_disallow_eval_train(self):
        m = TestHelperModules.ConvWithBNRelu(relu=True)
        example_inputs = (torch.rand(3, 3, 5, 5),)

        # Before export: this is OK
        m.eval()
        m.train()

        # After export: this is not OK
        m = export_for_training(m, example_inputs, strict=True).module()
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After prepare: still not OK
        quantizer = XNNPACKQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After convert: still not OK
        m = convert_pt2e(m)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

    @skipIfHpu
    def test_allow_exported_model_train_eval(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                x = self.bn(x)
                x = self.dropout(x)
                return x

        if TEST_CUDA:
            m = M().train().cuda()
            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
        else:
            m = M().train()
            example_inputs = (torch.randn(1, 3, 3, 3),)
        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
        m = export_for_training(m, example_inputs, strict=True).module()

        def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
            targets = [n.target for n in m.graph.nodes]
            bn_op = bn_train_op if train else bn_eval_op
            bn_node = self._get_node(m, bn_op)
            self.assertTrue(bn_node is not None)
            if TEST_CUDA:
                self.assertEqual(bn_node.args[5], train)
            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
            self.assertEqual(dropout_node.args[2], train)

        # Before wrapping: this is not OK
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

        # After prepare but before wrapping: this is not OK
        quantizer = XNNPACKQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After prepare and after wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

        # After convert but before wrapping: this is not OK
        m = convert_pt2e(m, fold_quantize=True)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After convert and after wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

    def test_allow_exported_model_train_eval_idempotent(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.bn(x)
                return x

        m = M().train()
        example_inputs = (torch.randn(1, 3, 3, 3),)
        m = export_for_training(m, example_inputs, strict=True).module()
        torch.ao.quantization.allow_exported_model_train_eval(m)

        # Mock m.recompile() to count how many times it's been called
        m._recompile_count = 0

        def _fake_recompile():
            m._recompile_count += 1

        m.recompile = _fake_recompile

        # First train after export should always recompile
        m.train()
        self.assertNotEqual(m._recompile_count, 0)
        count1 = m._recompile_count

        # Train -> train should not recompile
        m.train()
        self.assertEqual(m._recompile_count, count1)

        # Train -> eval should recompile
        m.eval()
        self.assertNotEqual(m._recompile_count, count1)
        count2 = m._recompile_count

        # Eval -> eval should not recompile
        m.eval()
        self.assertEqual(m._recompile_count, count2)

    def test_model_is_exported(self):
        m = TestHelperModules.ConvWithBNRelu(relu=True)
        example_inputs = (torch.rand(3, 3, 5, 5),)
        exported_gm = export_for_training(m, example_inputs, strict=True).module()
        fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
        self.assertTrue(
            torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
        )
        self.assertFalse(
            torch.ao.quantization.pt2e.export_utils.model_is_exported(fx_traced_gm)
        )
        self.assertFalse(torch.ao.quantization.pt2e.export_utils.model_is_exported(m))

    def test_reentrant(self):
        """Test we can safely call quantization apis multiple times"""
        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
        example_inputs = (torch.randn(3, 3, 10, 10),)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
        )
        m.conv_bn_relu = export_for_training(
            m.conv_bn_relu, example_inputs, strict=True
        ).module()
        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
        m(*example_inputs)
        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)

        quantizer = XNNPACKQuantizer().set_module_type(
            torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
        )
        m = export_for_training(m, example_inputs, strict=True).module()
        m = prepare_pt2e(m, quantizer)
        m = convert_pt2e(m)

        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 4,
            # one for weight
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 5,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 1,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
            ns.call_function(torch.ops.aten.relu.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.linear.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

    def test_groupwise_per_channel_quant(self):
        m = TestHelperModules.GroupwiseConv2d()
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        example_inputs = m.example_inputs()
        m = self._quantize(m, quantizer, example_inputs)
        # make sure it runs
        m(*example_inputs)

    def test_observer_callback(self):
        from torch.library import impl, Library

        test_lib = Library("test_int4", "DEF")  # noqa: TOR901
        test_lib.define(
            "quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
        )

        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
            return (
                torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15)
                .to(torch.uint8)
                .view(torch.bits8)
            )

        test_lib.define(
            "dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
        )

        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale

        from torch.ao.quantization.observer import ObserverBase

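        # An observer that implements `convert(model, observer_node)` gets the chance
        # to lower itself during convert_pt2e; here the hook rewrites the observed
        # value to use the custom int4 quantize/dequantize ops registered above.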
        class Int4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.test_int4.quantize_per_tensor_int4,
                        (observer_node.args[0], 1.0, 0),
                        {},
                    )
                    dq_node = model.graph.call_function(
                        torch.ops.test_int4.dequantize_per_tensor_int4,
                        (q_node, 1.0, 0),
                        {},
                    )
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.add.Tensor
                    ):
                        input_act0 = node.args[0]
                        assert isinstance(input_act0, Node)
                        input_act1 = node.args[1]
                        assert isinstance(input_act1, Node)

                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=Int4Observer,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act0: act_qspec,
                                input_act1: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def forward(self, x1, x2):
                return x1 + x2

        example_inputs = (
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 3, 5, 5),
        )
        node_occurrence = {
            # one quantize/dequantize pair for each of the two add inputs
            # and one pair for the add output
            torch.ops.test_int4.quantize_per_tensor_int4: 3,
            torch.ops.test_int4.dequantize_per_tensor_int4: 3,
        }
        node_list = [
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.aten.add.Tensor,
            torch.ops.test_int4.quantize_per_tensor_int4,
        ]
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
        )

    def test_speed(self):
        import time

        def dynamic_quantize_pt2e(model, example_inputs):
            torch._dynamo.reset()
            model = export_for_training(model, example_inputs, strict=True).module()
            # Per channel quantization for weight
            # Dynamic quantization for activation
            # Please read a detail: https://fburl.com/code/30zds51q
            embedding_quantizer = EmbeddingQuantizer()
            dynamic_quantizer = XNNPACKQuantizer()
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True
            )
            dynamic_quantizer.set_global(operator_config_dynamic)
            composed_quantizer = ComposableQuantizer(
                [embedding_quantizer, dynamic_quantizer]
            )
            prev = time.time()
            model = prepare_qat_pt2e(model, composed_quantizer)
            cur = time.time()
            # print("prepare time:", cur - prev)
            # Without calibration, scale/zero value will have an initialized value of 1.0
            # Per channel quantization needs a proper scale/zero shape/value to work properly.
            # So we need to run calibration before converting to quantized model.
            model(*example_inputs)
            prev = time.time()
            model = convert_pt2e(model)
            cur = time.time()
            # uncomment to see the time
            # print("convert time:", cur - prev)
            return model

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        m = M().eval()
        example_inputs = (torch.randn(5, 5),)
        _ = dynamic_quantize_pt2e(m, example_inputs)

    def test_conv_transpose_bn_relu(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int8_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                # conv_transpose + bn is fused automatically in PTQ (not configurable),
                # so for the conv_transpose + bn + relu pattern we only need to
                # annotate conv_transpose + relu
                OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_occurrence = {
            # one quantize/dequantize pair each for the conv_transpose input and output;
            # the weight only has a dequantize since its quantize is const propagated
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv_transpose2d.input,
            torch.ops.aten.relu.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        self._test_quantizer(
            TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
        )

    def test_conv_padding_bn_relu(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                weight_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                bias_qspec = QuantizationSpec(
                    dtype=torch.float32,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                )

                for n in model.graph.nodes:
                    if (
                        n.op != "call_function"
                        or n.target != torch.ops.aten.relu.default
                    ):
                        continue
                    relu_node = n
                    n = n.args[0]

                    # Check for any of the conv operations
                    conv_ops = [
                        torch.ops.aten.conv1d.padding,
                        torch.ops.aten.conv2d.padding,
                        torch.ops.aten.conv3d.padding,
                    ]
                    if n.op != "call_function" or n.target not in conv_ops:
                        continue

                    conv_node = n
                    input_act = conv_node.args[0]
                    weight = conv_node.args[1]
                    bias = conv_node.args[2]
                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={
                            input_act: act_qspec,
                            weight: weight_qspec,
                            bias: bias_qspec,
                        },
                        _annotated=True,
                    )
                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        output_qspec=act_qspec,
                        _annotated=True,
                    )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        # Test cases for Conv1d, Conv2d, Conv3d
        test_cases = [
            {
                "dim": 1,
                "example_input": (torch.randn(1, 3, 5),),
                "conv_op": torch.ops.aten.conv1d.padding,
            },
            {
                "dim": 2,
                "example_input": (torch.randn(1, 3, 5, 5),),
                "conv_op": torch.ops.aten.conv2d.padding,
            },
            {
                "dim": 3,
                "example_input": (torch.randn(1, 3, 5, 5, 5),),
                "conv_op": torch.ops.aten.conv3d.padding,
            },
        ]

        for test_case in test_cases:
            with self.subTest(dim=test_case["dim"]):
                model = TestHelperModules.ConvWithBNRelu(
                    relu=True,
                    dim=test_case["dim"],
                    bn=True,
                    bias=True,
                    padding="same",  # This will trigger the .padding variants
                ).eval()

                node_occurrence = {
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
                }
                node_list = [
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    test_case["conv_op"],
                    torch.ops.aten.relu.default,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                ]

                self._test_quantizer(
                    model,
                    test_case["example_input"],
                    BackendAQuantizer(),
                    node_occurrence,
                    node_list,
                )

    def test_multi_users_without_output_observer(self):
        """
        Test the case in which a node is used by multiple users
        and has its output observer removed.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                x = self.conv(x)
                return x, x + 1

        example_inputs = (torch.randn(1, 3, 5, 5),)
        m = M()
        m = export_for_training(m, example_inputs, strict=True).module()
        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(),
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)

        # Remove output observer
        observer_to_remove = None
        for n in m.graph.nodes:
            if n.op == "output":
                observer_to_remove = n.args[0][0]
                assert observer_to_remove.op == "call_module"
                assert observer_to_remove.target.startswith("activation_post_process_")
                break
        assert observer_to_remove is not None
        observer_to_remove.replace_all_uses_with(observer_to_remove.args[0])
        m.graph.erase_node(observer_to_remove)
        m.recompile()

        # Convert should succeed
        m = convert_pt2e(m)
        m(*example_inputs)

    def test_prepare_obs_or_fq_callback(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = torch.nn.functional.max_pool2d(x, 2, 2)
                x = torch.nn.functional.pixel_shuffle(x, 2)
                return x.permute(0, 2, 3, 1)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                act_qspec = QuantizationSpec(
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                for node in model.graph.nodes:
                    if node.op == "call_function" and node.target in (
                        torch.ops.aten.max_pool2d.default,
                        torch.ops.aten.permute.default,
                        torch.ops.aten.pixel_shuffle.default,
                    ):
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                node.args[0]: act_qspec,
                            },
                            output_qspec=SharedQuantizationSpec((node.args[0], node)),
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

            def prepare_obs_or_fq_callback(
                self,
                model: torch.fx.GraphModule,
                edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize],
            ) -> None:
                # hard code output quant by updating entire sharing group
                output_node = next(n for n in model.graph.nodes if n.op == "output")
                output_value = output_node.args[0][0]
                old_observer = edge_or_node_to_obs_or_fq[output_value]
                sharing_group = [
                    k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer
                ]
                new_observer = observer.FixedQParamsObserver(
                    scale=0.125,
                    zero_point=42,
                    dtype=torch.uint8,
                    quant_min=0,
                    quant_max=255,
                    qscheme=torch.per_tensor_affine,
                )
                for x in sharing_group:
                    edge_or_node_to_obs_or_fq[x] = new_observer

        example_inputs = (torch.rand(1, 32, 16, 16),)
        gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
        gm = prepare_pt2e(gm, BackendAQuantizer())
        gm = convert_pt2e(gm)
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in (
                torch.ops.quantized_decomposed.quantize_per_tensor.default,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            ):
                # The entire graph shares the same qspec, which was overridden by FixedQParamsObserver
                self.assertEqual(n.args[1], 0.125)
                self.assertEqual(n.args[2], 42)

    def test_preserve_nn_module_stack(self):
        """Test we can preserve nn_module_stack on replaced pattern's nodes"""
        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
        example_inputs = (torch.randn(3, 3, 10, 10),)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
        )

        def check_nn_module(node):
            self.assertTrue("nn_module_stack" in node.meta)
            self.assertTrue(
                "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
            )

        m.conv_bn_relu = export_for_training(
            m.conv_bn_relu, example_inputs, strict=True
        ).module()
        for node in m.conv_bn_relu.graph.nodes:
            if node.op not in ["placeholder", "output", "get_attr"]:
                check_nn_module(node)
        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
        for node in m.conv_bn_relu.graph.nodes:
            if node.name == "mul":
                check_nn_module(node)


@skipIfNoQNNPACK
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
    def test_channel_group_quantization(self):
        from torch.ao.quantization.observer import MappingType, PerGroup, PerToken
        from torch.ao.quantization.pt2e._affine_quantization import (
            AffineQuantizedMinMaxObserver,
        )

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.linear.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)

                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                # TODO: maybe align the arg name here
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerToken(),
                            ),
                        )

                        weight_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=None,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                target_dtype=torch.uint8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerGroup(group_size=128),
                            ),
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 20)

            def forward(self, x):
                return self.linear(x)

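        # Expected counts: the activation gets a quantize_affine/dequantize_affine
        # pair, while the weight's quantize_affine is expected to be const propagated,
        # leaving only its dequantize_affine (hence 1 quantize vs 2 dequantize below).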
        node_occurrence = {
            torch.ops.pt2e_quant.quantize_affine: 1,
            torch.ops.pt2e_quant.dequantize_affine: 2,
        }
        node_list = [
            torch.ops.pt2e_quant.quantize_affine,
            torch.ops.pt2e_quant.dequantize_affine,
        ]
        example_inputs = (torch.randn(5, 128),)
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
            is_debug_mode=True,
        )

    def test_dynamic_affine_act_per_channel_weights(self):
        import operator

        from torch.ao.quantization.observer import (
            MappingType,
            PerChannelMinMaxObserver,
            PerToken,
        )
        from torch.ao.quantization.pt2e._affine_quantization import (
            AffineQuantizedMovingAverageMinMaxObserver,
        )

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.linear.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)

                        activation_dtype = torch.int8
                        act_qspec = QuantizationSpec(
                            dtype=activation_dtype,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=None,
                            is_dynamic=True,
                            observer_or_fake_quant_ctr=AffineQuantizedMovingAverageMinMaxObserver.with_args(
                                # TODO: maybe align the arg name here
                                target_dtype=activation_dtype,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerToken(),
                                averaging_constant=1,
                            ),
                        )

                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-127,
                            quant_max=127,
                            qscheme=torch.per_channel_symmetric,
                            ch_axis=0,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(),
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 20)

            def forward(self, x):
                return self.linear(x)

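        # choose_qparams_affine returns (scale, zero_point) for the dynamic activation
        # path; the two operator.getitem calls expected below unpack that pair.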
        node_occurrence = {
            torch.ops.pt2e_quant.choose_qparams_affine: 1,
            operator.getitem: 2,
            torch.ops.pt2e_quant.quantize_affine: 1,
            torch.ops.pt2e_quant.dequantize_affine: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.pt2e_quant.choose_qparams_affine,
            operator.getitem,
            torch.ops.pt2e_quant.quantize_affine,
            torch.ops.pt2e_quant.dequantize_affine,
        ]
        example_inputs = (torch.randn(5, 128),)
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
            is_debug_mode=True,
        )

    def test_dynamic_per_tok_act_per_group_weights(self):
        import operator

        from torch.ao.quantization.observer import MappingType, PerGroup, PerToken
        from torch.ao.quantization.pt2e._affine_quantization import (
            AffineQuantizedMinMaxObserver,
            AffineQuantizedPlaceholderObserver,
        )

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.linear.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)

                        activation_dtype = torch.int8
                        act_qspec = QuantizationSpec(
                            dtype=activation_dtype,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=None,
                            is_dynamic=True,
                            observer_or_fake_quant_ctr=AffineQuantizedPlaceholderObserver.with_args(
                                # TODO: maybe align the arg name here
                                target_dtype=activation_dtype,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerToken(),
                            ),
                        )

                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-127,
                            quant_max=127,
                            qscheme=torch.per_channel_symmetric,
                            ch_axis=0,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
                                target_dtype=torch.int8,
                                mapping_type=MappingType.SYMMETRIC,
                                granularity=PerGroup(group_size=128),
                            ),
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                            },
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 20)

            def forward(self, x):
                return self.linear(x)

        node_occurrence = {
            torch.ops.pt2e_quant.choose_qparams_affine: 1,
            operator.getitem: 2,
            torch.ops.pt2e_quant.quantize_affine: 1,
            torch.ops.pt2e_quant.dequantize_affine: 2,
        }
        node_list = [
            torch.ops.pt2e_quant.dequantize_affine,
            torch.ops.pt2e_quant.choose_qparams_affine,
            operator.getitem,
            torch.ops.pt2e_quant.quantize_affine,
            torch.ops.pt2e_quant.dequantize_affine,
        ]
        example_inputs = (torch.randn(5, 128),)
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
            is_debug_mode=True,
        )


instantiate_parametrized_tests(TestQuantizePT2E)