[Quant] Add DQ duplication pass (#107900)

Summary:
During the convert step, observers are first replaced by a Q-DQ pair. In some
scenarios, like the following, the output DQ has a fan-out:

                 ---> OP2 -> Q -> DQ
                /
OP -> Q -> DQ -
                \
                 ---> OP3 -> Q -> DQ

If either OP2 or OP3 is configured to be quantized, then its input is expected
to be quantized as well. In that case, the quantized equivalent of a pattern
that the quantizer asked to be quantized should look like
[DQ -> {pattern} -> Q]. However, in a scenario like the one above, where the DQ
node is shared between multiple "quantized" patterns, the boundary of each
"quantized" pattern is unclear because the DQ now belongs to multiple quantized
patterns.

This poses a challenge for:
- Porting metadata: it is unclear which "quantized" partition the DQ node
  belongs to.
- The quantized representation, which likewise needs to identify a
  self-contained quantized pattern so it can be replaced by an equivalent
  pattern that performs the compute in the quantized precision.
The pass added here resolves this by duplicating the shared DQ node so that
each quantized user gets its own copy (see the sketch below).
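A minimal sketch of the idea, for illustration only: the real pass added by this PR additionally checks that the user carries a valid quantization annotation and skips sym_size-only users, and the helper name `duplicate_dq_for_users` below is made up.

import torch
from torch.fx.node import map_arg

def duplicate_dq_for_users(gm: torch.fx.GraphModule, dq_node: torch.fx.Node) -> None:
    # Give every user of the shared DQ its own copy, so that each
    # [DQ -> {pattern} -> Q] partition is self-contained.
    for user in list(dq_node.users):
        with gm.graph.inserting_after(dq_node):
            new_dq = gm.graph.node_copy(dq_node)
        user.args = map_arg(user.args, lambda n: new_dq if n is dq_node else n)
        user.kwargs = map_arg(user.kwargs, lambda n: new_dq if n is dq_node else n)
    # The original DQ is now unused and gets cleaned up.
    gm.graph.eliminate_dead_code()
    gm.recompile()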

Test Plan:
test_duplicate_dq_pass

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D48663147](https://our.internmc.facebook.com/intern/diff/D48663147)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107900
Approved by: https://github.com/jerryzh168, https://github.com/andrewor14, https://github.com/leslie-fang-intel
ghstack dependencies: #107105, #107106, #107899
Kimish Patel authored on 2023-09-01 08:38:16 -07:00; committed by PyTorch MergeBot
parent f8d1ca9835
commit eb67c452c8
12 changed files with 563 additions and 75 deletions


@ -513,6 +513,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
@ -525,8 +532,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
8,
39,
7,
37,
check_quantization=True,
)
@ -573,6 +580,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
@ -585,8 +599,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
8,
40,
7,
38,
check_quantization=True,
)
@ -629,6 +643,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 3
@ -641,8 +662,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
11,
54,
10,
52,
check_quantization=True,
)
@ -764,6 +785,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
# 1. Pair of to_int8 and to_fp32 at linear input * 2, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-linear pattern matched in quantization weight prepack * 3
@ -773,8 +801,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
self._test_common(
mod,
(v,),
11,
50,
10,
48,
check_quantization=True,
)


@ -0,0 +1,313 @@
# Owner(s): ["oncall: quantization"]
import copy
import torch
import torch._export as export
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torch.testing._internal.common_quantization import QuantizationTestCase
class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)
            y = self.conv2(x)
            add_output = x + y
            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)
            y = self.conv2(x)
            add_output = x + y
            extra_output = x + 2
            return w, add_output, extra_output
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)
        pt2_quant_output = m(*example_inputs)
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                input_qspec_map = annotation.input_qspec_map
                for input_node, qspec in input_qspec_map.items():
                    if (
                        qspec is not None
                        and hasattr(qspec, "dtype")
                        and qspec.dtype != torch.float
                    ):
                        q_node, dq_node = _find_q_dq_node_for_user(input_node, n)
                        if dq_node is None:
                            raise ValueError(
                                f"No dq node found for {n}, even though {n} annotated for quantization."
                            )
                        self.assertEqual(len(dq_node.users.keys()), 1)
    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )
    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
             -----------
             |
             -----> view_copy --> linear
             |
             -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d, add, and view_copy + linear.
        All three are quantized, so the DQ node is duplicated and each of
        these uses ends up with its own DQ node.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )
    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
             -----------
             |
             -----> view_copy --> linear
             |
             -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d and to view_copy + linear.
        Both of those are quantized.
        However, the skip connections to add and mul are not quantized,
        so the DQ node is not duplicated for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )
    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
             -----------
             |
             -----> adaptive_avgpool2d (different qconfig)
             |
             -----> add

        output
        conv2d -> dq -> conv2d -> add
             |                 |
             --------> dq -----
             |
             --> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
             --> dq -----> add
        """

        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
                MinMaxObserver
            )
            extra_args: Dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )
            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv2d"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )


@ -1285,7 +1285,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
# two input and one output for first add, and output for second add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
@ -1312,7 +1312,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
node_occurrence = {
# two input and one output for first add, and output for second add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,


@ -372,7 +372,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
@ -426,7 +426,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
# 2 conv will share same input quant/dequant
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
@ -457,7 +457,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
quantizer = X86InductorQuantizer().set_global(xiq.get_default_x86_inductor_quantization_config())
node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
torch.ops.quantized_decomposed.quantize_per_channel.default: 4,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
}


@ -137,12 +137,13 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
# Note: These patterns are comprised of torch ops and for internal use only.
# They are exported to aten graphs before being passed to the FX subgraph rewriter.
# TODO: find a better way to express this path without having to import
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
# `torch.ao.quantization.pt2e`, which interferes with memory profiling
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
_module_dir(torch) + "ao/quantization/pt2e/eval_utils.py",
}
FILENAME_ALLOWLIST |= {


@ -12,7 +12,7 @@ from .quantization_mappings import * # type: ignore[no-redef]
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
from .pt2e.utils import move_model_to_eval
from .pt2e.eval_utils import _move_model_to_eval as move_model_to_eval
from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
import torch


@ -0,0 +1,59 @@
import logging

import torch
from torch._export.pass_base import _ExportPassBase
from torch.ao.quantization.pt2e.utils import (
    _filter_sym_size_users,
    _is_valid_annotation,
)
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

__all__ = ["DuplicateDQPass"]

_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


def _maybe_duplicate_dq(
    gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
):
    annotation = user.meta.get("quantization_annotation", None)
    if not _is_valid_annotation(annotation):
        return
    with gm.graph.inserting_after(dq_node):
        new_node = gm.graph.node_copy(dq_node)

        def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
            if n == dq_node:
                return new_node
            else:
                return n

        new_args = map_arg(user.args, maybe_replace_node)
        new_kwargs = map_arg(user.kwargs, maybe_replace_node)
        user.args = new_args
        user.kwargs = new_kwargs


class DuplicateDQPass(_ExportPassBase):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
                dq_users = _filter_sym_size_users(node)
                if len(dq_users) <= 1:
                    continue
                for user in dq_users:
                    _maybe_duplicate_dq(graph_module, node, user)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
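Convert-time usage mirrors the quantize_pt2e.py change later in this diff; a minimal sketch, assuming the input is the exported GraphModule that convert_pt2e has already processed (the wrapper name run_duplicate_dq is made up):

import torch
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.fx.passes.infra.pass_manager import PassManager

def run_duplicate_dq(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Wrap the pass in a PassManager and return the rewritten graph module,
    # the same way convert_pt2e applies it.
    pm = PassManager([DuplicateDQPass()])
    return pm(model).graph_module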


@ -0,0 +1,52 @@
import torch
import torch.nn.functional as F


def _replace_dropout_for_eval(m: torch.fx.GraphModule):
    """
    Replace the aten training dropout pattern with a noop, intended for eval.

    For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
    effectively turns these dropout ops into noops. For exported models, however,
    this is not done automatically, since the aten dropout patterns previously generated
    for training remain in the graph. Here we rewrite these dropout patterns with noops
    to avoid incorrectly applying further dropout during eval.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import get_aten_graph_module

    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        return F.dropout(x, p=0.5, training=False)

    example_inputs = (torch.randn(1),)
    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()


# TODO: also support move_model_to_train
# TODO: also support standalone batchnorm
def _move_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    This is equivalent to model.eval() but only for certain special ops like dropout.
    QAT users should call this before performing inference on the model.
    """
    _replace_dropout_for_eval(model)
    return model
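A hedged usage sketch for the re-exported helper: the torch/ao/quantization/__init__.py change above exposes _move_model_to_eval as move_model_to_eval, and the wrapper name prepare_for_inference below is made up; quantized_model is assumed to be an exported, quantized GraphModule.

import torch
from torch.ao.quantization import move_model_to_eval

def prepare_for_inference(quantized_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Only special ops such as dropout are rewritten; other nodes are untouched.
    return move_model_to_eval(quantized_model)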


@ -1,22 +1,105 @@
import torch
from torch._export import capture_pre_autograd_graph
from torch.fx import (
    GraphModule,
    Node,
)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights

import operator
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec

# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

from torch.ao.quantization.quantizer import QuantizationAnnotation

__all__ = [
    "fold_bn_weights_into_conv_node",
    "get_aten_graph_module",
    "move_model_to_eval",
    "remove_tensor_overload_for_qdq_ops",
]

_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _is_connected(next_node: torch.fx.Node, target: torch.fx.Node) -> bool:
    if target.op == "output":
        return False
    if next_node == target:
        return True
    for n in next_node.users.keys():
        if _is_connected(n, target):
            return True
    return False


def _find_q_dq_node_for_user(
    produer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
    """
    Find the q, dq pair corresponding to [producer ... -> q -> dq -> user].
    This util works by finding the dq arg of user and ensuring that it is
    connected to producer.
    """
    dq_node = None
    for n in user.args:
        if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
            if _is_connected(produer, n):
                dq_node = n
                break
    if dq_node is None:
        for n in user.kwargs:
            if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
                if _is_connected(produer, n):
                    dq_node = n
                    break
    if dq_node is None:
        return (None, None)

    q_node = None
    if dq_node.args[0].op == "call_function" and dq_node.args[0].target in _QUANTIZE_OPS:
        q_node = dq_node.args[0]
    return (q_node, dq_node)


def _is_sym_size_node(node: Node):
    return (
        node.op == "call_function"
        and node.target == torch.ops.aten.sym_size.default
        or node.target == torch.ops.aten.sym_numel.default
        or node.target == torch.ops.aten.sym_numel
        or node.target == torch.ops.aten.sym_size
    )


def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
    node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users))
    return node_users


def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
    if annotation is None:
        return False
    input_qspec_map = annotation.input_qspec_map
    output_qspec = annotation.output_qspec
    if len(input_qspec_map) == 0 and output_qspec is None:
        return False
    return True


def _get_tensor_constant_from_node(node, m):
    if node is None:
        return None
@ -143,8 +226,6 @@ def get_aten_graph_module(
"""
Convert the pattern to an FX graph with decomposed aten ops.
"""
# Avoid circular dependencies
from torch._export import capture_pre_autograd_graph
aten_pattern = capture_pre_autograd_graph(
pattern,
example_inputs,
@ -175,37 +256,6 @@ def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
if n.target in _MAP:
n.target = _MAP[n.target]
def _replace_dropout_for_eval(m: GraphModule):
"""
Replace the aten training dropout pattern with a noop, intended for eval.
For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
effectively turns these dropout ops into noops. For exported models, however,
this is not done automatically, since the aten dropout patterns previously generated
for training remain in the graph. Here we rewrite these dropout patterns with noops
to avoid incorrectly applying further dropout during eval.
See https://github.com/pytorch/pytorch/issues/103681.
"""
def dropout_train(x):
return F.dropout(x, p=0.5, training=True)
def dropout_eval(x):
return F.dropout(x, p=0.5, training=False)
example_inputs = (torch.randn(1),)
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
replace_pattern_with_filters(
m,
match_pattern,
replacement_pattern,
match_filters=[],
ignore_literals=True,
)
m.recompile()
def _is_literal(arg):
if isinstance(arg, (int, float)):
return True
@ -376,15 +426,3 @@ def _replace_literals_with_existing_placeholders(
new_args = tuple(new_args)
node.args = new_args
return gm
# TODO: also support move_model_to_train
# TODO: also support standalone batchnorm
def move_model_to_eval(m: GraphModule):
"""
Move an exported GraphModule to eval mode.
This is equivalent to model.eval() but only for certain special ops like dropout.
QAT users should call this before performing inference on the model.
"""
_replace_dropout_for_eval(m)
return m
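For reference, a hedged sketch of how the new _find_q_dq_node_for_user helper can verify the pass's invariant; it mirrors the assertion in _test_duplicate_dq in the new test file above, and the function name assert_each_quantized_input_has_own_dq is made up:

import torch
from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user

def assert_each_quantized_input_has_own_dq(m: torch.fx.GraphModule) -> None:
    # After DuplicateDQPass, every quantized input of an annotated node should
    # reach it through a DQ node that has exactly one user.
    for n in m.graph.nodes:
        annotation = n.meta.get("quantization_annotation", None)
        if annotation is None:
            continue
        for input_node, qspec in annotation.input_qspec_map.items():
            if qspec is None or getattr(qspec, "dtype", torch.float) == torch.float:
                continue
            _, dq_node = _find_q_dq_node_for_user(input_node, n)
            if dq_node is not None:
                assert len(dq_node.users) == 1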


@ -26,6 +26,9 @@ from torch.ao.quantization.backend_config import BackendConfig
from typing import Any, Tuple
from torch.fx.passes.infra.pass_manager import PassManager
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
__all__ = [
"prepare_pt2e",
"prepare_qat_pt2e",
@ -93,6 +96,9 @@ def convert_pt2e(
original_graph_meta = model.meta
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
pm = PassManager([DuplicateDQPass()])
model = pm(model).graph_module
if use_reference_representation:
model = reference_representation_rewrite(model)


@ -1,6 +1,7 @@
from typing import List
import torch
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.fx import Node
@ -23,16 +24,6 @@ def _annotate_output_qspec(node: Node, qspec):
node.meta["quantization_annotation"] = quantization_annotation
def _is_sym_size_node(node: Node):
return (
node.op == "call_function"
and node.target == torch.ops.aten.sym_size.default
or node.target == torch.ops.aten.sym_numel.default
or node.target == torch.ops.aten.sym_numel
or node.target == torch.ops.aten.sym_size
)
def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
"""
This utility is used to handle cases when dynami_shape=True tracing leads


@ -6,6 +6,7 @@ from typing import Callable, Dict, List, NamedTuple, Optional
import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
@ -16,7 +17,6 @@ from torch.ao.quantization.quantizer import (
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
_is_sym_size_node,
_node_only_used_for_sym_size,
)
from torch.fx import Node