[quant][pt2e] Actually support transitive sharing for SharedQuantizationSpec (#111172)

Summary:
Previously we did not actually support this; this PR adds the support.
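
For context, a minimal sketch of the annotation pattern this change is meant to resolve. It condenses the cat-node branch of the BackendAQuantizer used in the test added below (same names and qspec construction as in that test; illustrative only, not a new API):

act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=observer.default_observer,
)
# every other input edge of cat and the cat output share with the first input edge
share_with_input0 = SharedQuantizationSpec((first_input_node, cat_node))
cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={
        first_input_node: act_qspec,
        second_input_node: share_with_input0,
    },
    output_qspec=share_with_input0,
    _annotated=True,
)
# With two cats chained (cat1 -> cat2), the output of cat1 and the input edge
# (cat1, cat2) are the same Tensor, so the sharing groups around cat1 and cat2
# have to be merged; prepare now resolves this transitively via union-find.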

Next
* clean up the observer insertion logic
* add an allow_transitive_sharing boolean flag to allow people to turn this off for certain edges

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_shared_qspec_transitivity

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D50250789](https://our.internmc.facebook.com/intern/diff/D50250789)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111172
Approved by: https://github.com/kimishpatel
Jerry Zhang 2023-10-19 21:59:55 -07:00 committed by PyTorch MergeBot
parent 1ad0f0b308
commit 43c211facb
4 changed files with 510 additions and 31 deletions


@@ -811,6 +811,255 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
def _test_transitive_sharing_with_cat_helper(self, quantizer):
m = TestHelperModules.Conv2dWithTwoCat().eval()
example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 6, 3, 3), torch.randn(1, 6, 3, 3))
# program capture
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
# make sure the two input observers and the output observer are shared
conv_output_obs = []
for n in m.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
conv_output_obs.append(getattr(m, list(n.users)[0].target))
if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
inputs = n.args[0]
input0 = inputs[0]
input1 = inputs[1]
assert input0.op == "call_module"
assert input1.op == "call_module"
obs_ins0 = getattr(m, input0.target)
obs_ins1 = getattr(m, input1.target)
assert obs_ins0 == obs_ins1
output_obs = list(n.users)[0]
assert output_obs.op == "call_module"
obs_ins2 = getattr(m, output_obs.target)
assert obs_ins0 == obs_ins2, "input observer does not match output"
assert len(conv_output_obs) == 2, "expecting two observers that follow the conv2d ops"
# checking that the output observers for the two convs are shared as well
assert conv_output_obs[0] == conv_output_obs[1]
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)
node_occurrence = {
# two for the inputs of the first conv, one for the output of the first conv
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 7,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 9,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(torch.ops.aten.cat.default),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(torch.ops.aten.cat.default),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
]
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
def test_shared_qspec_transitivity(self):
"""This tests the transitivity of SharedQuantizationSpec, that is
if A is shared with B and B is shared with C, then C should be shared with A as well

    x1 -> conv1 -> cat1 -----> cat2
    x2 -> conv2 ---/          /
             x3 -> add ------/
             x4 ---/

both cats have shared inputs and output, and because the output of cat1 and the input edge (cat1 -> cat2)
refer to the same Tensor, there is an implicit sharing here; all tensors connected to cat1 and cat2 end up
in the same sharing group after transitive sharing
"""
# TODO: refactor this to a common util
class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.conv2d.default
):
input_act = node.args[0]
assert isinstance(input_act, Node)
weight = node.args[1]
assert isinstance(weight, Node)
bias = node.args[2]
assert isinstance(bias, Node)
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
weight_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
bias_qspec = QuantizationSpec(
dtype=torch.float32,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
)
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
weight: weight_qspec,
bias: bias_qspec,
},
output_qspec=act_qspec,
_annotated=True,
)
elif node.target is torch.ops.aten.cat.default:
cat_node = node
input_nodes = cat_node.args[0]
first_input_node = input_nodes[0]
input_qspec_map = {}
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
input_qspec_map[first_input_node] = act_qspec
share_qparams_with_input_act0_qspec = SharedQuantizationSpec((first_input_node, cat_node))
for input_node in input_nodes[1:]:
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
cat_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
def test_shared_qspec_transitivity_case_2(self):
"""This tests the transitivity of SharedQuantizationSpec, that is
if A is shared with B and B is shared with C, then C should be shared with A as well

    x1 -> conv1 -> cat1 -----> cat2
    x2 -> conv2 ---/          /
             x3 -> add ------/
             x4 ---/

both cats have shared inputs and output, and because the output of cat1 and the input edge (cat1 -> cat2)
refer to the same Tensor, there is an implicit sharing here; all tensors connected to cat1 and cat2 end up
in the same sharing group after transitive sharing
the difference is that for this one, all edges and nodes share with the second input edge of cat
instead of the first input edge of cat as in the previous example
"""
# TODO: refactor this to a common util
class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.conv2d.default
):
input_act = node.args[0]
assert isinstance(input_act, Node)
weight = node.args[1]
assert isinstance(weight, Node)
bias = node.args[2]
assert isinstance(bias, Node)
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
weight_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
bias_qspec = QuantizationSpec(
dtype=torch.float32,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
)
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
weight: weight_qspec,
bias: bias_qspec,
},
output_qspec=act_qspec,
_annotated=True,
)
elif node.target is torch.ops.aten.cat.default:
cat_node = node
input_nodes = cat_node.args[0]
first_input_node = input_nodes[0]
second_input_node = input_nodes[1]
input_qspec_map = {}
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
input_qspec_map[second_input_node] = act_qspec
share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, cat_node))
input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec
cat_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act1_qspec,
_annotated=True,
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
def test_int16(self):
class Int16ActQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:


@@ -5,13 +5,13 @@ from torch.ao.quantization.fx.prepare import (
_get_output_act_obs_or_fq,
_get_dtype_and_is_dynamic,
_insert_obs_or_fq,
_maybe_insert_output_observer_for_node,
_save_state,
_is_activation_post_process_node,
_get_qspec_for_arg,
_create_obs_or_fq_from_qspec,
)
from torch.fx import (
GraphModule,
Graph,
Node,
)
from torch.fx.node import Argument
@@ -19,14 +19,217 @@ from torch.fx.node import Argument
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
EdgeOrNode,
SharedQuantizationSpec,
QuantizationSpecBase,
)
from torch.ao.quantization import ObserverOrFakeQuantize
# TODO: make pt2e folder private?
__all__ = [
"prepare",
]
def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
"""Find the root node for the sharing tree
Args:
edge_or_node: the edge/node whose root we want to find
shared_with_map: each edge/node points to its parent; the root node points to itself
Returns:
root edge/node
"""
parent = shared_with_map[edge_or_node]
if parent == edge_or_node:
return edge_or_node
root = _find_root(parent, shared_with_map)
# path compression
shared_with_map[edge_or_node] = root
return root
def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
"""Merge the subtree for `child` with `parent`, the order is important here
"""
root_parent = _find_root(parent, shared_with_map)
root_child = _find_root(child, shared_with_map)
# union the two trees by pointing the root of child to root of parent
shared_with_map[root_child] = root_parent
def _update_shared_with(edge_or_node: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
"""Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
configuration and established the relationship between `edge_or_node` with the edge/node that it
is pointing to, we'll use this information in the end to get the group id
"""
if isinstance(qspec, SharedQuantizationSpec):
sharing_with = qspec.edge_or_node
# we point from edge_or_node to the node that it is sharing_with, e.g.
# qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
_union(sharing_with, edge_or_node, shared_with_map)
def _find_root_qspec(
qspec: QuantizationSpecBase,
edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
"""Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
if qspec is SharedQuantizationSpec
(1). tries to find the root node for the node that the qspec points to
(2). recursively find the root qspec based on the qspec for the root node
"""
if isinstance(qspec, SharedQuantizationSpec):
sharing_with = qspec.edge_or_node
root = _find_root(sharing_with, shared_with_map)
qspec = edge_or_node_to_qspec[root]
return _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
return qspec
def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
return (
hasattr(qspec_a, "dtype") and
hasattr(qspec_b, "dtype") and
qspec_a.dtype == qspec_b.dtype
)
def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
return (
hasattr(qspec_a, "is_dynamic") and
hasattr(qspec_b, "is_dynamic") and
qspec_a.is_dynamic == qspec_b.is_dynamic
)
def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
"""Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
"""
edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
for n in model.graph.nodes:
if hasattr(n, "meta") and "quantization_annotation" in n.meta:
qa = n.meta["quantization_annotation"]
for input_to_n, qspec in qa.input_qspec_map.items():
input_edge = (input_to_n, n)
edge_or_node_to_qspec[input_edge] = qspec
if qa.output_qspec is not None:
output_node = n
qspec = qa.output_qspec
edge_or_node_to_qspec[output_node] = qspec
return edge_or_node_to_qspec
def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
"""Map from edge/node to the group ID, generated from quantization annotations,
edge/node with the same group ID should use the same observer/fake_quant instance
This is applying SharedQuantizationSpec configuration and map each edge/node to a group
There is another implicit sharing that's built in the quantization, when we have the following:
* op1 -> op2
* output of op1: int8_qspec
* (op1 -> op2) input edge: int8_qspec
we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
Args:
edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
Returns:
edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edges or nodes that
belong to the same group should have the same id
Example:
    op2 -> cat1 -> cat2
    op1 ---/      /
         op3 ----/
edge_or_node_to_qspec: {
op1: int8_qspec,
op2: int8_qspec,
(op1, cat1): int8_qspec,
(op2, cat1): SharedQuantizationSpec((op1, cat1)),
cat1: SharedQuantizationSpec((op1, cat1)),
(op3, cat2): int8_qspec,
(cat1, cat2): SharedQuantizationSpec((op3, cat2)),
cat2: SharedQuantizationSpec((op3, cat2)),
}
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
edge_or_node_to_group_id: {
op1: 1,
op2: 1,
(op1, cat1): 1,
(op2, cat1): 1,
cat1: 1,
(op3, cat2): 1,
(cat1, cat2): 1,
cat2: 1,
}
# everything is in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
# connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
"""
# means the observer of the key should be shared with the observer of the value; by default it will
# be shared with itself
shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
for edge_or_node, qspec in edge_or_node_to_qspec.items():
if isinstance(edge_or_node, torch.fx.Node):
output_node = edge_or_node
_update_shared_with(output_node, qspec, shared_with_map)
else:
input_edge = edge_or_node
input_edge_root = _find_root(input_edge, shared_with_map)
input_edge_root_qspec = edge_or_node_to_qspec[input_edge_root]
input_edge_root_qspec = _find_root_qspec(input_edge_root_qspec, edge_or_node_to_qspec, shared_with_map)
# find root_qspec for `arg` Node (the output of previous node)
assert isinstance(input_edge, tuple)
arg, n = input_edge
arg_as_output_root_qspec = None
if arg in edge_or_node_to_qspec:
arg_as_output_qspec = edge_or_node_to_qspec[arg]
arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
# TODO: add assertions for types of root qspecs
if (
arg_as_output_root_qspec is not None and
_has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
_has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
):
# the input arg to the node should reuse the existing output observer for arg
# since dtype is the same (we may want to extend this to be a more strict check
# in the future)
# so we point from `input_edge` to `arg` (output of the argument)
_union(arg, input_edge, shared_with_map)
_update_shared_with(input_edge, qspec, shared_with_map)
# now that we have the sharing relations between all edges and nodes, we can assign group ids
cur_group_id = 0
edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
for edge_or_node in shared_with_map.keys():
root = _find_root(edge_or_node, shared_with_map)
if root not in edge_or_node_to_group_id:
edge_or_node_to_group_id[root] = cur_group_id
cur_group_id += 1
edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
return edge_or_node_to_group_id
def _get_obs_or_fq_map(
edge_or_node_to_group_id: Dict[EdgeOrNode, int],
edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
"""Generates the EdgeOrNode to observer/fake_quant instances
Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
instances
"""
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
for edge_or_node, qspec in edge_or_node_to_qspec.items():
group_id = edge_or_node_to_group_id[edge_or_node]
if group_id not in group_id_to_obs_or_fq:
# TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
# the implementation for _create_obs_or_fq_from_qspec
group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
return obs_or_fq_map
def _maybe_insert_input_observer_for_arg_or_kwarg(
node: Union[Node, Any],
arg: Argument,
@@ -72,21 +275,11 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
observed_arg = arg.args[0]
assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}"
assert observed_arg in obs_or_fq_map, \
f"can't refer to a node that does not have observer/fake_quant inserted yet: {observed_arg}"
input_qspec_map = quantization_annotation.input_qspec_map
input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
if isinstance(input_arg_qspec, SharedQuantizationSpec):
# if the argument is set to use SharedQuantizationSpec, we will
# reset the observer instance to align with the configured edge/node
obs_or_fq_name = arg.target
setattr(model, obs_or_fq_name, arg_as_input_act_obs_or_fq)
named_modules[obs_or_fq_name] = arg_as_input_act_obs_or_fq
else:
# otherwise reuse the existing obs/fq
arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
f"can't find a sharing group for node: {observed_arg}"
# reuse the existing obs/fq
arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
# we don't need to insert new observer node
new_arg = arg
obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
else:
# skip inserting new observers if there is an observer inserted for the arg before
# that has the same dtype that we want to insert here
@@ -113,23 +306,24 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
assert arg_as_input_act_obs_or_fq is not None
if existing_obs_node is None:
maybe_observed_arg = arg
# When quantizing two layers with different configs we can have
# conv2d (int8) -> avgpool(uint8)
# In this case observer insertion for avgpool will come here but the input
# to avgpool will be output observer of conv2d
# Now the obs map that we update must correspond to the original input of
# avgpool and not the output obs of conv2d
# This is because when referring to the edge, quantizer would refer to
# original input and not the observed one.
while _is_activation_post_process_node(arg, named_modules):
arg = arg.args[0] # type: ignore[assignment]
arg_as_input_act_obs_or_fq = obs_or_fq_map[(arg, node)]
new_obs_node = _insert_obs_or_fq(
arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
maybe_observed_arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
# override this arg to be the observed arg
new_arg = new_obs_node
else:
new_arg = existing_obs_node
# When quantizing two layers with different configs we can have
# conv2d (int8) -> avgpool(uint8)
# In this case observer insertion for avgpool will come here but the input
# to avgpool will be output observer of conv2d
# Now the obs map that we update must correspond to the original input of
# avgpool and not the output obs of conv2d
# This is because when referring to the edge, quantizer would refer to
# original input and not the observed one.
while _is_activation_post_process_node(arg, named_modules):
arg = arg.args[0] # type: ignore[assignment]
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
return new_arg
@@ -172,6 +366,19 @@ def _maybe_insert_input_observers_for_node(
# assign the new args to the node, inplace
node.args = tuple(new_args)
def _maybe_insert_output_observer_for_node(
node: Node,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
graph: Graph,
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> Optional[Node]:
if node in obs_or_fq_map:
output_act_obs_or_fq = obs_or_fq_map[node]
return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
return None
def _maybe_insert_input_and_output_observers_for_node(
node: Node,
model: torch.fx.GraphModule,
@@ -213,7 +420,8 @@ def _maybe_insert_input_and_output_observers_for_node(
return
# this returns the new observer node if it was needed
maybe_output_obs_node = _maybe_insert_output_observer_for_node(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
maybe_output_obs_node = _maybe_insert_output_observer_for_node(
node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
if maybe_output_obs_node is None:
return
@@ -246,9 +454,17 @@ def prepare(
# Since we are mutating the graph as we go, we iterate over the original
# nodes before observer insertion, instead of model.graph.nodes.
nodes_before_observation = list(model.graph.nodes)
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
# At the high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
# all edges/nodes that belong to the same group will use the same instance
# and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
# instance
edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)
for node in nodes_before_observation:
# TODO: simplify logic for inserting observers
_maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)
model = GraphModule(model, model.graph)


@@ -231,8 +231,8 @@ def convert_pt2e(
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
pm = PassManager([DuplicateDQPass()])
model = pm(model).graph_module
model = pm(model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
model = pm(model).graph_module


@@ -2571,6 +2571,20 @@ class TestHelperModules:
z = torch.cat([x, y], dim=1)
return z
class Conv2dWithTwoCat(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
def forward(self, x1, x2, x3, x4):
x1 = self.conv1(x1)
x2 = self.conv2(x2)
y = torch.cat([x1, x2], dim=1)
z = x3 + x4
w = torch.cat([z, y])
return w
class EmbeddingModule(torch.nn.Module):
def __init__(self):
super().__init__()