[Quant] Move to BFS instead of DFS to check for connectedness (#108572)

Summary:
Using DFS to check whether two nodes are connected is very slow.
Using BFS makes it much faster.

Test Plan:
https://gist.github.com/leslie-fang-intel/9cd828623f567a3afbf41564d3546398

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D48971710](https://our.internmc.facebook.com/intern/diff/D48971710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108572
Approved by: https://github.com/jerryzh168, https://github.com/osalpekar
This commit is contained in:
Kimish Patel 2023-09-06 11:04:31 -07:00 committed by PyTorch MergeBot
parent 2a40fe2dbf
commit c1877e99c5
2 changed files with 20 additions and 10 deletions

View File

@ -22,7 +22,6 @@ logger.setLevel(logging.WARNING)
__all__ = ["PortNodeMetaForQDQ"]
_METADATA_TO_PORT = [
"nn_module_stack",
"stack_trace",
"quantization_tag",
]
@ -167,6 +166,12 @@ class PortNodeMetaForQDQ(_ExportPassBase):
- Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
- Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
- Note first Q does not inherit metadata from any nodes
NB:
- The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit.
Doing it via a separate pass helps readability of the code. Once we are able to refactor PT2E quant
code, this pass should likely be integrated in the refactored variant of the "convert" step.
"""
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

View File

@ -14,6 +14,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noq
from torch.ao.quantization.quantizer import QuantizationAnnotation
__all__ = [
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
@ -34,15 +35,19 @@ _DEQUANTIZE_OPS = [
]
def _is_connected(next_node: torch.fx.Node, target: torch.fx.Node) -> bool:
if target.op == "output":
return False
if next_node == target:
return True
for n in next_node.users.keys():
if _is_connected(n, target):
return True
return False
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
    """Check whether ``source`` feeds ``dest`` through quant-workflow ops only.

    Assuming ``dest`` is one of the ops inserted by the quant workflow
    (quantize/dequantize/choose_qparams), walk backwards from ``dest``
    along the first argument of each node for as long as the node is a
    quant-workflow op. ``source`` and ``dest`` are connected iff this
    walk terminates exactly at ``source``. Assumption is that only quant
    workflow inserted ops exist between ``source`` and ``dest``.

    Raises:
        ValueError: if a quant-workflow op's first argument is not a Node.
    """
    # Build the op set once; set membership is O(1) per loop iteration.
    quant_workflow_ops = set(_QUANTIZE_OPS + _DEQUANTIZE_OPS)
    quant_workflow_ops.add(torch.ops.quantized_decomposed.choose_qparams.tensor)
    while dest.target in quant_workflow_ops:
        if not isinstance(dest.args[0], torch.fx.Node):
            raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}")
        dest = dest.args[0]
    return dest == source
def _find_q_dq_node_for_user(