[quant][pt2e] Actually support transitive sharing for SharedQuantizationSpec (#111172)

Summary: Previously we did not really support transitive sharing; this PR adds the support.

Next:
* clean up the insert observer logic
* add an allow_transitive_sharing boolean flag to allow people to turn this off for certain edges

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_shared_qspec_transitivity

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D50250789](https://our.internmc.facebook.com/intern/diff/D50250789)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111172
Approved by: https://github.com/kimishpatel
This commit is contained in:
parent 1ad0f0b308
commit 43c211facb
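Below is a minimal sketch (not part of this commit) of the annotation pattern that this change makes work transitively. The helper name `annotate_cat` is hypothetical and only illustrates how a quantizer ties the inputs and the output of a cat node into one sharing group with SharedQuantizationSpec; the quantizers in the test diff below follow the same pattern.

import torch
from torch.ao.quantization import observer
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)

def annotate_cat(cat_node: torch.fx.Node) -> None:
    # hypothetical helper, shown for illustration only
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer.default_observer,
    )
    inputs = cat_node.args[0]
    first_input = inputs[0]
    # every other input edge and the output share qparams with the (first_input, cat_node) edge
    shared = SharedQuantizationSpec((first_input, cat_node))
    input_qspec_map = {first_input: act_qspec}
    for other_input in inputs[1:]:
        input_qspec_map[other_input] = shared
    cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=shared,
        _annotated=True,
    )

When two cats are chained (cat1 -> cat2) and both are annotated this way, the output of cat1 and the (cat1, cat2) input edge refer to the same tensor, so with this PR all edges and nodes around both cats collapse into a single sharing group and use one shared observer.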
@@ -811,6 +811,255 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def _test_transitive_sharing_with_cat_helper(self, quantizer):
        m = TestHelperModules.Conv2dWithTwoCat().eval()
        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 6, 3, 3), torch.randn(1, 6, 3, 3))

        # program capture
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        # make sure the two input observers and the output observer are shared
        conv_output_obs = []
        for n in m.graph.nodes:
            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
                conv_output_obs.append(getattr(m, list(n.users)[0].target))
            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
                inputs = n.args[0]
                input0 = inputs[0]
                input1 = inputs[1]
                assert input0.op == "call_module"
                assert input1.op == "call_module"
                obs_ins0 = getattr(m, input0.target)
                obs_ins1 = getattr(m, input1.target)
                assert obs_ins0 == obs_ins1

                output_obs = list(n.users)[0]
                assert output_obs.op == "call_module"
                obs_ins2 = getattr(m, output_obs.target)
                assert obs_ins0 == obs_ins2, "input observer does not match output"

        assert len(conv_output_obs) == 2, "expecting two observers that follow the conv2d ops"
        # check that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]

        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        node_occurrence = {
            # two for the inputs of the first conv, one for the output of the first conv
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 7,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_shared_qspec_transitivity(self):
        """This tests the transitivity of SharedQuantizationSpec, that is:
        if A is shared with B and B is shared with C, then C should be shared with A as well.

            x1 -> conv1 -> cat1 -----> cat2
            x2 -> conv2 -/            /
            x3 -> add --------------/
            x4 ---/

        Both cats have shared inputs and output, and because the output of cat1 and the
        (cat1 -> cat2) input edge refer to the same Tensor, there is an implicit sharing
        here as well, so all tensors connected to cat1 and cat2 end up in the same
        sharing group after transitive sharing.
        """
        # TODO: refactor this to a common util
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )
                        bias_qspec = QuantizationSpec(
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif node.target is torch.ops.aten.cat.default:
                        cat_node = node
                        input_nodes = cat_node.args[0]
                        first_input_node = input_nodes[0]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[first_input_node] = act_qspec
                        share_qparams_with_input_act0_qspec = SharedQuantizationSpec((first_input_node, cat_node))
                        for input_node in input_nodes[1:]:
                            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

                        cat_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act0_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())

    def test_shared_qspec_transitivity_case_2(self):
        """This tests the transitivity of SharedQuantizationSpec, that is:
        if A is shared with B and B is shared with C, then C should be shared with A as well.

            x1 -> conv1 -> cat1 -----> cat2
            x2 -> conv2 -/            /
            x3 -> add --------------/
            x4 ---/

        Both cats have shared inputs and output, and because the output of cat1 and the
        (cat1 -> cat2) input edge refer to the same Tensor, there is an implicit sharing
        here as well, so all tensors connected to cat1 and cat2 end up in the same
        sharing group after transitive sharing.

        The difference is that in this case all edges and nodes are shared with the second
        input edge of cat instead of the first input edge of cat as in the previous example.
        """
        # TODO: refactor this to a common util
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        weight_qspec = QuantizationSpec(
                            dtype=torch.int8,
                            quant_min=-128,
                            quant_max=127,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_weight_observer,
                        )
                        bias_qspec = QuantizationSpec(
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif node.target is torch.ops.aten.cat.default:
                        cat_node = node
                        input_nodes = cat_node.args[0]
                        first_input_node = input_nodes[0]
                        second_input_node = input_nodes[1]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[second_input_node] = act_qspec
                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, cat_node))
                        input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec

                        cat_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act1_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())

    def test_int16(self):
        class Int16ActQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -5,13 +5,13 @@ from torch.ao.quantization.fx.prepare import (
    _get_output_act_obs_or_fq,
    _get_dtype_and_is_dynamic,
    _insert_obs_or_fq,
    _maybe_insert_output_observer_for_node,
    _save_state,
    _is_activation_post_process_node,
    _get_qspec_for_arg,
    _create_obs_or_fq_from_qspec,
)
from torch.fx import (
    GraphModule,
    Graph,
    Node,
)
from torch.fx.node import Argument
@@ -19,14 +19,217 @@ from torch.fx.node import Argument
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
)
from torch.ao.quantization import ObserverOrFakeQuantize

# TODO: make pt2e folder private?
__all__ = [
    "prepare",
]

def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
    """Find the root node for the sharing tree
    Args:
        edge_or_node: edge/node that we want to find the root for
        shared_with_map: each edge/node points to its parent, the root node points to itself

    Returns:
        root edge/node
    """
    parent = shared_with_map[edge_or_node]
    if parent == edge_or_node:
        return edge_or_node
    root = _find_root(parent, shared_with_map)
    # path compression
    shared_with_map[edge_or_node] = root
    return root

def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge the subtree for `child` with `parent`, the order is important here
    """
    root_parent = _find_root(parent, shared_with_map)
    root_child = _find_root(child, shared_with_map)
    # union the two trees by pointing the root of child to the root of parent
    shared_with_map[root_child] = root_parent

def _update_shared_with(edge_or_node: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
    configuration and establishes the relationship between `edge_or_node` and the edge/node that it
    is pointing to, we'll use this information in the end to get the group id
    """
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        # we point from edge_or_node to the node that it is sharing_with, e.g.
        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
        _union(sharing_with, edge_or_node, shared_with_map)

def _find_root_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec).
    If qspec is a SharedQuantizationSpec:
        (1). find the root node for the node that the qspec points to
        (2). recursively find the root qspec based on the qspec for the root node
    """
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        root = _find_root(sharing_with, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec

def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "dtype") and
        hasattr(qspec_b, "dtype") and
        qspec_a.dtype == qspec_b.dtype
    )

def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "is_dynamic") and
        hasattr(qspec_b, "is_dynamic") and
        qspec_a.is_dynamic == qspec_b.is_dynamic
    )

def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
    """
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
    for n in model.graph.nodes:
        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
            qa = n.meta["quantization_annotation"]
            for input_to_n, qspec in qa.input_qspec_map.items():
                input_edge = (input_to_n, n)
                edge_or_node_to_qspec[input_edge] = qspec
            if qa.output_qspec is not None:
                output_node = n
                qspec = qa.output_qspec
                edge_or_node_to_qspec[output_node] = qspec
    return edge_or_node_to_qspec

def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations,
    edges/nodes with the same group ID should use the same observer/fake_quant instance

    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
    There is another implicit sharing that's built into quantization, when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and the input of (op1 -> op2), since these are the same Tensor.

    Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edges or nodes that
        belong to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
        op1 ---/      /
        op3 ---------/

        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything ends up in the same group because cat1 and (cat1, cat2) are implicitly shared, which
        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
    """
    # the observer for a key should be shared with the observer for its value; by default each
    # edge/node is shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            input_edge = edge_or_node
            input_edge_root = _find_root(input_edge, shared_with_map)
            input_edge_root_qspec = edge_or_node_to_qspec[input_edge_root]
            input_edge_root_qspec = _find_root_qspec(input_edge_root_qspec, edge_or_node_to_qspec, shared_with_map)

            # find root_qspec for the `arg` Node (the output of the previous node)
            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            arg_as_output_root_qspec = None
            if arg in edge_or_node_to_qspec:
                arg_as_output_qspec = edge_or_node_to_qspec[arg]
                arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
            # TODO: add assertions for types of root qspecs
            if (
                arg_as_output_root_qspec is not None and
                _has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
                _has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
            ):
                # the input arg to the node should reuse the existing output observer for arg
                # since the dtype is the same (we may want to extend this to be a more strict check
                # in the future)
                # so we point from `input_edge` to `arg` (the output of the argument)
                _union(arg, input_edge, shared_with_map)
            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we have the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id

def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the map from EdgeOrNode to observer/fake_quant instances.
    Makes sure that EdgeOrNodes with the same group_id use the same observer or fake_quant
    instance.
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id not in group_id_to_obs_or_fq:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
    return obs_or_fq_map

def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
@@ -72,21 +275,11 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
        observed_arg = arg.args[0]
        assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}"
        assert observed_arg in obs_or_fq_map, \
            f"can't refer to a node that does not have observer/fake_quant inserted yet: {observed_arg}"
        input_qspec_map = quantization_annotation.input_qspec_map
        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
        if isinstance(input_arg_qspec, SharedQuantizationSpec):
            # if the argument is set to use SharedQuantizationSpec, we will
            # reset the observer instance to align with the configured edge/node
            obs_or_fq_name = arg.target
            setattr(model, obs_or_fq_name, arg_as_input_act_obs_or_fq)
            named_modules[obs_or_fq_name] = arg_as_input_act_obs_or_fq
        else:
            # otherwise reuse the existing obs/fq
            arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
            f"can't find a sharing group for node: {observed_arg}"
        # reuse the existing obs/fq
        arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
        # we don't need to insert new observer node
        new_arg = arg
        obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
    else:
        # skip inserting new observers if there is an observer inserted for the arg before
        # that has the same dtype that we want to insert here
@@ -113,23 +306,24 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(

    assert arg_as_input_act_obs_or_fq is not None
    if existing_obs_node is None:
        maybe_observed_arg = arg
        # When quantizing two layers with different configs we can have
        # conv2d (int8) -> avgpool(uint8)
        # In this case observer insertion for avgpool will come here but the input
        # to avgpool will be output observer of conv2d
        # Now the obs map that we update must correspond to the original input of
        # avgpool and not the output obs of conv2d
        # This is because when referring to the edge, quantizer would refer to
        # original input and not the observed one.
        while _is_activation_post_process_node(arg, named_modules):
            arg = arg.args[0]  # type: ignore[assignment]
        arg_as_input_act_obs_or_fq = obs_or_fq_map[(arg, node)]
        new_obs_node = _insert_obs_or_fq(
            arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
            maybe_observed_arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)
        # override this arg to be the observed arg
        new_arg = new_obs_node
    else:
        new_arg = existing_obs_node
        # When quantizing two layers with different configs we can have
        # conv2d (int8) -> avgpool(uint8)
        # In this case observer insertion for avgpool will come here but the input
        # to avgpool will be output observer of conv2d
        # Now the obs map that we update must correspond to the original input of
        # avgpool and not the output obs of conv2d
        # This is because when referring to the edge, quantizer would refer to
        # original input and not the observed one.
        while _is_activation_post_process_node(arg, named_modules):
            arg = arg.args[0]  # type: ignore[assignment]
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq

    return new_arg
@@ -172,6 +366,19 @@ def _maybe_insert_input_observers_for_node(
    # assign the new args to the node, inplace
    node.args = tuple(new_args)

def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    if node in obs_or_fq_map:
        output_act_obs_or_fq = obs_or_fq_map[node]
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    return None

def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
@@ -213,7 +420,8 @@ def _maybe_insert_input_and_output_observers_for_node(
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

    if maybe_output_obs_node is None:
        return
@@ -246,9 +454,17 @@ def prepare(
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}

    # At a high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
    # all edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
    # instance
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    model = GraphModule(model, model.graph)
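The union-find bookkeeping above (_find_root, _union, _update_shared_with, _get_edge_or_node_to_group_id) is easiest to see on a toy input. The following self-contained sketch (not part of the diff) mimics the docstring example, with plain strings standing in for fx nodes and tuples standing in for edges; the real code operates on EdgeOrNode keys and QuantizationSpec objects and derives the implicit unions from dtype/is_dynamic checks.

# toy stand-ins: strings are "nodes", tuples are "edges"
def find_root(x, shared_with_map):
    parent = shared_with_map[x]
    if parent == x:
        return x
    root = find_root(parent, shared_with_map)
    shared_with_map[x] = root  # path compression
    return root

def union(parent, child, shared_with_map):
    # point the root of child to the root of parent
    shared_with_map[find_root(child, shared_with_map)] = find_root(parent, shared_with_map)

# op1 and op2 feed cat1; cat1 and op3 feed cat2
keys = ["op1", "op2", "op3", "cat1", "cat2",
        ("op1", "cat1"), ("op2", "cat1"), ("op3", "cat2"), ("cat1", "cat2")]
shared_with_map = {k: k for k in keys}

# explicit sharing coming from SharedQuantizationSpec annotations
union(("op1", "cat1"), ("op2", "cat1"), shared_with_map)
union(("op1", "cat1"), "cat1", shared_with_map)
union(("op3", "cat2"), ("cat1", "cat2"), shared_with_map)
union(("op3", "cat2"), "cat2", shared_with_map)
# implicit sharing: an input edge reuses its producer's output observer when dtypes match
union("op1", ("op1", "cat1"), shared_with_map)
union("op2", ("op2", "cat1"), shared_with_map)
union("cat1", ("cat1", "cat2"), shared_with_map)
union("op3", ("op3", "cat2"), shared_with_map)

groups = {}
group_ids = {}
for k in keys:
    root = find_root(k, shared_with_map)
    group_ids[k] = groups.setdefault(root, len(groups))
print(group_ids)  # every key maps to group 0: one shared observer for the whole cluster

Because the implicit union between cat1 (an output node) and the (cat1, cat2) input edge bridges the two explicitly annotated sharing groups, every edge and node in the example collapses into a single group, which is exactly the transitive behavior the tests above check for.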
@@ -231,8 +231,8 @@ def convert_pt2e(
    model = _convert_to_reference_decomposed_fx(model)
    model = _fold_conv_bn_qat(model)
    pm = PassManager([DuplicateDQPass()])
    model = pm(model).graph_module

    model = pm(model).graph_module
    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

@@ -2571,6 +2571,20 @@ class TestHelperModules:
            z = torch.cat([x, y], dim=1)
            return z

    class Conv2dWithTwoCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x1, x2, x3, x4):
            x1 = self.conv1(x1)
            x2 = self.conv2(x2)
            y = torch.cat([x1, x2], dim=1)
            z = x3 + x4
            w = torch.cat([z, y])
            return w

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
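As a quick sanity check of the new helper module (a usage sketch, not part of the diff), note that the second cat in Conv2dWithTwoCat runs on the default dim=0, so with the example inputs used in the tests above the two concatenated tensors are stacked along the batch dimension:

import torch

m = Conv2dWithTwoCat().eval()  # assumes the class defined above is in scope
out = m(
    torch.randn(1, 3, 5, 5),
    torch.randn(1, 3, 5, 5),
    torch.randn(1, 6, 3, 3),
    torch.randn(1, 6, 3, 3),
)
print(out.shape)  # torch.Size([2, 6, 3, 3]): torch.cat([z, y]) concatenates on dim 0 by default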