mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quant] Add DQ duplication pass (#107900)
Summary:
During convert step observers are first replaced by Q-DQ pair. In some
scenarios like following output DQ has a fan out.
---> OP2 -> Q -> DQ
/
OP -> Q -> DQ -
\
---> OP3 -> Q -> DQ
If either op OP2 or OP3 are configured to be quantized, then the input
is expected to quantized. In this case quantized equivalent of some
pattern, that quantizer asked to be quantized, should look like:
[DQ -> {pattern} -> Q]. However, in scenario like above where DQ node
is shared between multiple "quantized" patterns, boundary of "quantized"
pattern is not clear because DQ now belongs to multiple quantized
patterns.
This poses challenge for:
- Porting metadata: which "quantized" partition this DQ node belongs
- Quantized representation, equivalently, needs to identify
self-contained quantized pattern that is replaced by its equivalent pattern
that captures compute in the quantized precision.
Test Plan:
test_duplicate_dq_pass
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
This commit is contained in:
parent
f8d1ca9835
commit
eb67c452c8
|
|
@ -513,6 +513,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
|
||||
# matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
|
||||
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
|
||||
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# However when dq has multiple users we will have
|
||||
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# \-> to_float -> sub -> mul]
|
||||
# So for now we will discount one pattern here
|
||||
# 2. Dequant pattern matcher for dequant promotion * 1
|
||||
# [convert_element_type_3, sub_1, mul_3]
|
||||
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
|
||||
|
|
@ -525,8 +532,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
8,
|
||||
39,
|
||||
7,
|
||||
37,
|
||||
check_quantization=True,
|
||||
)
|
||||
|
||||
|
|
@ -573,6 +580,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
|
||||
# matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
|
||||
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
|
||||
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# However when dq has multiple users we will have
|
||||
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# \-> to_float -> sub -> mul]
|
||||
# So for now we will discount one pattern here
|
||||
# 2. Dequant pattern matcher for dequant promotion * 1
|
||||
# [convert_element_type_3, sub_1, mul_3]
|
||||
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
|
||||
|
|
@ -585,8 +599,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
8,
|
||||
40,
|
||||
7,
|
||||
38,
|
||||
check_quantization=True,
|
||||
)
|
||||
|
||||
|
|
@ -629,6 +643,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
# 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
|
||||
# matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
|
||||
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
|
||||
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# However when dq has multiple users we will have
|
||||
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# \-> to_float -> sub -> mul]
|
||||
# So for now we will discount one pattern here
|
||||
# 2. Dequant pattern matcher for dequant promotion * 1
|
||||
# [convert_element_type_3, sub_1, mul_3]
|
||||
# 3. Dequant-conv pattern matched in quantization weight prepack * 3
|
||||
|
|
@ -641,8 +662,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
11,
|
||||
54,
|
||||
10,
|
||||
52,
|
||||
check_quantization=True,
|
||||
)
|
||||
|
||||
|
|
@ -764,6 +785,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
# 1. Pair of to_int8 and to_fp32 at linear input * 2, extra input of add * 1, and graph output * 1
|
||||
# matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
|
||||
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
|
||||
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# However when dq has multiple users we will have
|
||||
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
|
||||
# \-> to_float -> sub -> mul]
|
||||
# So for now we will discount one pattern here
|
||||
# 2. Dequant pattern matcher for dequant promotion * 1
|
||||
# [convert_element_type_3, sub_1, mul_3]
|
||||
# 3. Dequant-linear pattern matched in quantization weight prepack * 3
|
||||
|
|
@ -773,8 +801,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
11,
|
||||
50,
|
||||
10,
|
||||
48,
|
||||
check_quantization=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
313
test/quantization/pt2e/test_duplicate_dq.py
Normal file
313
test/quantization/pt2e/test_duplicate_dq.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch._export as export
|
||||
|
||||
from torch.ao.quantization.observer import (
|
||||
HistogramObserver,
|
||||
MinMaxObserver,
|
||||
PlaceholderObserver,
|
||||
)
|
||||
from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user
|
||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||
from torch.ao.quantization.quantizer import (
|
||||
QuantizationAnnotation,
|
||||
QuantizationSpec,
|
||||
Quantizer,
|
||||
SharedQuantizationSpec,
|
||||
)
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||
get_symmetric_quantization_config,
|
||||
)
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
|
||||
OP_TO_ANNOTATOR,
|
||||
QuantizationConfig,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_quantization import QuantizationTestCase
|
||||
|
||||
|
||||
class TestHelperModules:
|
||||
class Conv2dWithObsSharingOps(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
self.hardtanh = torch.nn.Hardtanh()
|
||||
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.linear = torch.nn.Linear(3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.adaptive_avg_pool2d(x)
|
||||
x = self.hardtanh(x)
|
||||
x = x.view(-1, 3)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class Conv2dWithSharedDQ(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(3, 3, 3)
|
||||
self.conv2 = torch.nn.Conv2d(3, 3, 1)
|
||||
self.linear = torch.nn.Linear(3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
z = x.view(-1, 3)
|
||||
w = self.linear(z)
|
||||
|
||||
y = self.conv2(x)
|
||||
add_output = x + y
|
||||
|
||||
extra_output = x * 2
|
||||
return w, add_output, extra_output
|
||||
|
||||
class ModuleForDifferentQconfig(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(3, 3, 3)
|
||||
self.conv2 = torch.nn.Conv2d(3, 3, 1)
|
||||
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
w = self.adaptive_avg_pool2d(x)
|
||||
|
||||
y = self.conv2(x)
|
||||
add_output = x + y
|
||||
|
||||
extra_output = x + 2
|
||||
return w, add_output, extra_output
|
||||
|
||||
|
||||
class TestDuplicateDQPass(QuantizationTestCase):
|
||||
def _test_duplicate_dq(
|
||||
self,
|
||||
model,
|
||||
example_inputs,
|
||||
quantizer,
|
||||
):
|
||||
m_eager = model.eval()
|
||||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export.capture_pre_autograd_graph(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
pt2_quant_output = m(*example_inputs)
|
||||
for n in m.graph.nodes:
|
||||
annotation = n.meta.get("quantization_annotation", None)
|
||||
if annotation is not None:
|
||||
input_qspec_map = annotation.input_qspec_map
|
||||
for input_node, qspec in input_qspec_map.items():
|
||||
if (
|
||||
qspec is not None
|
||||
and hasattr(qspec, "dtype")
|
||||
and qspec.dtype != torch.float
|
||||
):
|
||||
q_node, dq_node = _find_q_dq_node_for_user(input_node, n)
|
||||
if dq_node is None:
|
||||
raise ValueError(
|
||||
f"No dq node found for {n}, even though {n} annotated for quantization."
|
||||
)
|
||||
self.assertEqual(len(dq_node.users.keys()), 1)
|
||||
|
||||
def test_no_need_for_duplicate_dq(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> avgpool -> hardtanh -> linear
|
||||
Check quantization tags on conv2d, avgpool and linear are correctly set
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 7),)
|
||||
self._test_duplicate_dq(
|
||||
TestHelperModules.Conv2dWithObsSharingOps(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
)
|
||||
|
||||
def test_simple_duplicate_dq(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> conv2d -> add
|
||||
| |
|
||||
--------->
|
||||
|
|
||||
-----> view_copy --> linear
|
||||
|
|
||||
-----> mul
|
||||
There should be three dq nodes because output for the
|
||||
first conv2d is fed to next conv2d, add, and view_copy + linear.
|
||||
All three are quantized.
|
||||
Thus DQ node is not duplicated for those three uses
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["add"](gm, quantization_config)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 7),)
|
||||
self._test_duplicate_dq(
|
||||
TestHelperModules.Conv2dWithSharedDQ(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
)
|
||||
|
||||
def test_no_add_quant_duplicate_dq(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> conv2d -> add
|
||||
| |
|
||||
--------->
|
||||
|
|
||||
-----> view_copy --> linear
|
||||
|
|
||||
-----> mul
|
||||
There should be three dq nodes because output for the
|
||||
first conv2d is fed to next conv2d, and view_copy + linear.
|
||||
Both are quantized.
|
||||
However the skip connection to add and mul are not quantized.
|
||||
Thus DQ node is not duplicated for those two uses
|
||||
"""
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 7),)
|
||||
self._test_duplicate_dq(
|
||||
TestHelperModules.Conv2dWithSharedDQ(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
)
|
||||
|
||||
def test_avgpool_use_different_qconfig(self):
|
||||
"""
|
||||
Model under test
|
||||
conv2d -> conv2d -> add
|
||||
| |
|
||||
--------->
|
||||
|
|
||||
-----> adaptive_avgpool2d (different qconfig)
|
||||
|
|
||||
-----> add
|
||||
output
|
||||
conv2d -> dq -> conv2d -> add
|
||||
| |
|
||||
-------> dq ----->
|
||||
|
|
||||
-> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
|
||||
|
|
||||
-> dq -----> add
|
||||
"""
|
||||
|
||||
def _get_uint8_quantization_config():
|
||||
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
|
||||
act_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.uint8,
|
||||
quant_min=0,
|
||||
quant_max=255,
|
||||
qscheme=torch.per_tensor_affine,
|
||||
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
|
||||
eps=2**-12
|
||||
),
|
||||
)
|
||||
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
|
||||
MinMaxObserver
|
||||
)
|
||||
|
||||
extra_args: Dict[str, Any] = {"eps": 2**-12}
|
||||
weight_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.uint8,
|
||||
quant_min=0,
|
||||
quant_max=255,
|
||||
qscheme=torch.per_tensor_affine,
|
||||
ch_axis=0,
|
||||
is_dynamic=False,
|
||||
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
|
||||
**extra_args
|
||||
),
|
||||
)
|
||||
|
||||
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
|
||||
PlaceholderObserver
|
||||
)
|
||||
bias_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.float,
|
||||
observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
|
||||
)
|
||||
quantization_config = QuantizationConfig(
|
||||
act_quantization_spec,
|
||||
act_quantization_spec,
|
||||
weight_quantization_spec,
|
||||
bias_quantization_spec,
|
||||
)
|
||||
return quantization_config
|
||||
|
||||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
backend_string = "BackendA"
|
||||
quantization_config = get_symmetric_quantization_config(
|
||||
is_per_channel=True
|
||||
)
|
||||
avgpool_qconfig = _get_uint8_quantization_config()
|
||||
OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
|
||||
OP_TO_ANNOTATOR["add"](gm, quantization_config)
|
||||
for n in gm.graph.nodes:
|
||||
if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
|
||||
qspec = avgpool_qconfig.input_activation
|
||||
input_act = n.args[0]
|
||||
output_qspec = SharedQuantizationSpec((input_act, n))
|
||||
n.meta["quantization_annotation"] = QuantizationAnnotation(
|
||||
input_qspec_map={input_act: qspec},
|
||||
output_qspec=output_qspec,
|
||||
_annotated=True,
|
||||
)
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 5, 7),)
|
||||
self._test_duplicate_dq(
|
||||
TestHelperModules.ModuleForDifferentQconfig(),
|
||||
example_inputs,
|
||||
BackendAQuantizer(),
|
||||
)
|
||||
|
|
@ -1285,7 +1285,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
node_occurrence = {
|
||||
# two input and one output for first add, and output for second add
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
|
|
@ -1312,7 +1312,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
node_occurrence = {
|
||||
# two input and one output for first add, and output for second add
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
|
|
|
|||
|
|
@ -372,7 +372,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
# 2 conv will share same input quant/dequant
|
||||
# one for extra input node of add
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
|
||||
}
|
||||
|
|
@ -426,7 +426,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
# 2 conv will share same input quant/dequant
|
||||
# one for extra input node of add
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
|
||||
}
|
||||
|
|
@ -457,7 +457,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
|
||||
node_occurrence = {
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,12 +137,13 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
|
|||
# Note: These patterns are comprised of torch ops and for internal use only.
|
||||
# They are exported to aten graphs before being passed to the FX subgraph rewriter.
|
||||
# TODO: find a better way to express this path without having to import
|
||||
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
|
||||
# `torch.ao.quantization.pt2e`, which interferes with memory profiling
|
||||
FILENAME_ALLOWLIST |= {
|
||||
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
|
||||
_module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
|
||||
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
|
||||
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
|
||||
_module_dir(torch) + "ao/quantization/pt2e/eval_utils.py",
|
||||
}
|
||||
|
||||
FILENAME_ALLOWLIST |= {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from .quantization_mappings import * # type: ignore[no-redef]
|
|||
from .quantize import * # noqa: F403
|
||||
from .quantize_jit import * # noqa: F403
|
||||
from .stubs import * # noqa: F403
|
||||
from .pt2e.utils import move_model_to_eval
|
||||
from .pt2e.eval_utils import _move_model_to_eval as move_model_to_eval
|
||||
from typing import Union, List, Callable, Tuple, Optional
|
||||
from torch import Tensor
|
||||
import torch
|
||||
|
|
|
|||
59
torch/ao/quantization/pt2e/duplicate_dq_pass.py
Normal file
59
torch/ao/quantization/pt2e/duplicate_dq_pass.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from torch._export.pass_base import _ExportPassBase
|
||||
|
||||
from torch.ao.quantization.pt2e.utils import (
|
||||
_filter_sym_size_users,
|
||||
_is_valid_annotation,
|
||||
)
|
||||
|
||||
from torch.fx.node import map_arg
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
__all__ = ["DuplicateDQPass"]
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _maybe_duplicate_dq(
|
||||
gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
|
||||
):
|
||||
annotation = user.meta.get("quantization_annotation", None)
|
||||
if not _is_valid_annotation(annotation):
|
||||
return
|
||||
with gm.graph.inserting_after(dq_node):
|
||||
new_node = gm.graph.node_copy(dq_node)
|
||||
|
||||
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
|
||||
if n == dq_node:
|
||||
return new_node
|
||||
else:
|
||||
return n
|
||||
|
||||
new_args = map_arg(user.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
|
||||
user.args = new_args
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
|
||||
class DuplicateDQPass(_ExportPassBase):
|
||||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
|
||||
dq_users = _filter_sym_size_users(node)
|
||||
if len(dq_users) <= 1:
|
||||
continue
|
||||
for user in dq_users:
|
||||
_maybe_duplicate_dq(graph_module, node, user)
|
||||
graph_module.graph.eliminate_dead_code()
|
||||
graph_module.recompile()
|
||||
return PassResult(graph_module, True)
|
||||
52
torch/ao/quantization/pt2e/eval_utils.py
Normal file
52
torch/ao/quantization/pt2e/eval_utils.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _replace_dropout_for_eval(m: torch.fx.GraphModule):
|
||||
"""
|
||||
Replace the aten training dropout pattern with a noop, intended for eval.
|
||||
|
||||
For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
|
||||
effectively turns these dropout ops into noops. For exported models, however,
|
||||
this is not done automatically, since the aten dropout patterns previously generated
|
||||
for training remain in the graph. Here we rewrite these dropout patterns with noops
|
||||
to avoid incorrectly applying further dropout during eval.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/103681.
|
||||
"""
|
||||
# Avoid circular dependencies
|
||||
from .utils import get_aten_graph_module
|
||||
|
||||
def dropout_train(x):
|
||||
return F.dropout(x, p=0.5, training=True)
|
||||
|
||||
def dropout_eval(x):
|
||||
return F.dropout(x, p=0.5, training=False)
|
||||
|
||||
example_inputs = (torch.randn(1),)
|
||||
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
|
||||
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
|
||||
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
|
||||
# TODO: also support move_model_to_train
|
||||
# TODO: also support standalone batchnorm
|
||||
def _move_model_to_eval(model: torch.fx.GraphModule):
|
||||
"""
|
||||
Move an exported GraphModule to eval mode.
|
||||
|
||||
This is equivalent to model.eval() but only for certain special ops like dropout.
|
||||
QAT users should call this before performing inference on the model.
|
||||
"""
|
||||
_replace_dropout_for_eval(model)
|
||||
return model
|
||||
|
|
@ -1,22 +1,105 @@
|
|||
import torch
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
from torch.fx import (
|
||||
GraphModule,
|
||||
Node,
|
||||
)
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
||||
import operator
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
|
||||
from torch.utils._pytree import LeafSpec
|
||||
|
||||
# Makes sure that quantized_decomposed ops are registered
|
||||
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
||||
|
||||
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
||||
|
||||
__all__ = [
|
||||
"fold_bn_weights_into_conv_node",
|
||||
"get_aten_graph_module",
|
||||
"move_model_to_eval",
|
||||
"remove_tensor_overload_for_qdq_ops",
|
||||
]
|
||||
|
||||
_QUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
_DEQUANTIZE_OPS = [
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
||||
]
|
||||
|
||||
|
||||
def _is_connected(next_node: torch.fx.Node, target: torch.fx.Node) -> bool:
|
||||
if target.op == "output":
|
||||
return False
|
||||
if next_node == target:
|
||||
return True
|
||||
for n in next_node.users.keys():
|
||||
if _is_connected(n, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _find_q_dq_node_for_user(
|
||||
produer: torch.fx.Node, user: torch.fx.Node
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Find d, dq pair corresponding to [producer ... -> q -> dq -> user]
|
||||
Utils works by finding dq arg of user and ensuring it is connected to
|
||||
producer
|
||||
"""
|
||||
dq_node = None
|
||||
for n in user.args:
|
||||
if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
|
||||
if _is_connected(produer, n):
|
||||
dq_node = n
|
||||
break
|
||||
if dq_node is None:
|
||||
for n in user.kwargs:
|
||||
if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
|
||||
if _is_connected(produer, n):
|
||||
dq_node = n
|
||||
break
|
||||
if dq_node is None:
|
||||
return (None, None)
|
||||
|
||||
q_node = None
|
||||
if dq_node.args[0].op == "call_function" and dq_node.args[0].target in _QUANTIZE_OPS:
|
||||
q_node = dq_node.args[0]
|
||||
return (q_node, dq_node)
|
||||
|
||||
|
||||
|
||||
def _is_sym_size_node(node: Node):
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops.aten.sym_size.default
|
||||
or node.target == torch.ops.aten.sym_numel.default
|
||||
or node.target == torch.ops.aten.sym_numel
|
||||
or node.target == torch.ops.aten.sym_size
|
||||
)
|
||||
|
||||
|
||||
def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
|
||||
node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users))
|
||||
return node_users
|
||||
|
||||
|
||||
def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
|
||||
if annotation is None:
|
||||
return False
|
||||
input_qspec_map = annotation.input_qspec_map
|
||||
output_qspec = annotation.output_qspec
|
||||
if len(input_qspec_map) == 0 and output_qspec is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_tensor_constant_from_node(node, m):
|
||||
if node is None:
|
||||
return None
|
||||
|
|
@ -143,8 +226,6 @@ def get_aten_graph_module(
|
|||
"""
|
||||
Convert the pattern to an FX graph with decomposed aten ops.
|
||||
"""
|
||||
# Avoid circular dependencies
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
aten_pattern = capture_pre_autograd_graph(
|
||||
pattern,
|
||||
example_inputs,
|
||||
|
|
@ -175,37 +256,6 @@ def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
|
|||
if n.target in _MAP:
|
||||
n.target = _MAP[n.target]
|
||||
|
||||
def _replace_dropout_for_eval(m: GraphModule):
|
||||
"""
|
||||
Replace the aten training dropout pattern with a noop, intended for eval.
|
||||
|
||||
For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
|
||||
effectively turns these dropout ops into noops. For exported models, however,
|
||||
this is not done automatically, since the aten dropout patterns previously generated
|
||||
for training remain in the graph. Here we rewrite these dropout patterns with noops
|
||||
to avoid incorrectly applying further dropout during eval.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/103681.
|
||||
"""
|
||||
def dropout_train(x):
|
||||
return F.dropout(x, p=0.5, training=True)
|
||||
|
||||
def dropout_eval(x):
|
||||
return F.dropout(x, p=0.5, training=False)
|
||||
|
||||
example_inputs = (torch.randn(1),)
|
||||
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
|
||||
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
def _is_literal(arg):
|
||||
if isinstance(arg, (int, float)):
|
||||
return True
|
||||
|
|
@ -376,15 +426,3 @@ def _replace_literals_with_existing_placeholders(
|
|||
new_args = tuple(new_args)
|
||||
node.args = new_args
|
||||
return gm
|
||||
|
||||
# TODO: also support move_model_to_train
|
||||
# TODO: also support standalone batchnorm
|
||||
def move_model_to_eval(m: GraphModule):
|
||||
"""
|
||||
Move an exported GraphModule to eval mode.
|
||||
|
||||
This is equivalent to model.eval() but only for certain special ops like dropout.
|
||||
QAT users should call this before performing inference on the model.
|
||||
"""
|
||||
_replace_dropout_for_eval(m)
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ from torch.ao.quantization.backend_config import BackendConfig
|
|||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
|
||||
|
||||
__all__ = [
|
||||
"prepare_pt2e",
|
||||
"prepare_qat_pt2e",
|
||||
|
|
@ -93,6 +96,9 @@ def convert_pt2e(
|
|||
original_graph_meta = model.meta
|
||||
model = _convert_to_reference_decomposed_fx(model)
|
||||
model = _fold_conv_bn_qat(model)
|
||||
pm = PassManager([DuplicateDQPass()])
|
||||
model = pm(model).graph_module
|
||||
|
||||
if use_reference_representation:
|
||||
model = reference_representation_rewrite(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
|
||||
|
||||
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
|
||||
from torch.fx import Node
|
||||
|
||||
|
|
@ -23,16 +24,6 @@ def _annotate_output_qspec(node: Node, qspec):
|
|||
node.meta["quantization_annotation"] = quantization_annotation
|
||||
|
||||
|
||||
def _is_sym_size_node(node: Node):
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops.aten.sym_size.default
|
||||
or node.target == torch.ops.aten.sym_numel.default
|
||||
or node.target == torch.ops.aten.sym_numel
|
||||
or node.target == torch.ops.aten.sym_size
|
||||
)
|
||||
|
||||
|
||||
def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
|
||||
"""
|
||||
This utility is used to handle cases when dynami_shape=True tracing leads
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Callable, Dict, List, NamedTuple, Optional
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
|
||||
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
|
||||
from torch.ao.quantization.quantizer import (
|
||||
QuantizationAnnotation,
|
||||
QuantizationSpec,
|
||||
|
|
@ -16,7 +17,6 @@ from torch.ao.quantization.quantizer import (
|
|||
from torch.ao.quantization.quantizer.utils import (
|
||||
_annotate_input_qspec_map,
|
||||
_annotate_output_qspec,
|
||||
_is_sym_size_node,
|
||||
_node_only_used_for_sym_size,
|
||||
)
|
||||
from torch.fx import Node
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user