[Quant] Move to BFS instead of DFS to check for connectedness (#108572)
Summary: Using DFS to check whether two nodes are connected is very slow; switching to a cheaper check makes it much faster.

Test Plan: https://gist.github.com/leslie-fang-intel/9cd828623f567a3afbf41564d3546398

Differential Revision: [D48971710](https://our.internmc.facebook.com/intern/diff/D48971710)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108572
Approved by: https://github.com/jerryzh168, https://github.com/osalpekar
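The old recursive helper explores every user edge and revisits shared nodes, which blows up on large graphs. As a point of reference only (this is not the code in the diff below, and _is_connected_bfs is a hypothetical name), a breadth-first connectivity check over torch.fx user edges with a visited set could look like:

    from collections import deque

    import torch

    def _is_connected_bfs(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
        # Explore forward along user edges, visiting each node at most once.
        seen = {source}
        queue = deque([source])
        while queue:
            node = queue.popleft()
            if node is dest:
                return True
            for user in node.users:
                if user not in seen:
                    seen.add(user)
                    queue.append(user)
        return False

The visited set, rather than the BFS order itself, is what removes the repeated re-exploration of the old helper; the diff below goes further and avoids graph search entirely.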
This commit is contained in:
parent
2a40fe2dbf
commit
c1877e99c5
@@ -22,7 +22,6 @@ logger.setLevel(logging.WARNING)
 __all__ = ["PortNodeMetaForQDQ"]
 
 _METADATA_TO_PORT = [
-    "nn_module_stack",
     "stack_trace",
     "quantization_tag",
 ]
@@ -167,6 +166,12 @@ class PortNodeMetaForQDQ(_ExportPassBase):
     - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
       - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
       - Note that the first Q does not inherit metadata from any node
+    NB:
+    - The best place for porting metadata is during observer conversion to q/dq. That step knows precisely
+      which quantization spec is converted to q/dq, and thus from where the metadata should be ported.
+      However, since the FX and PT2E quant workflows share a common code base, doing it there hurts
+      readability quite a bit; a separate pass keeps the code readable. Once we are able to refactor the
+      PT2E quant code, this pass should be integrated into the refactored variant of the "convert" step.
     """
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
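The docstring above describes metadata flowing from compute nodes onto the Q/DQ nodes that the quant workflow inserts around them. A minimal sketch of that porting step, using the fields from _METADATA_TO_PORT (the helper name _port_meta is illustrative, not the pass's actual code):

    import torch

    def _port_meta(from_node: torch.fx.Node, to_node: torch.fx.Node) -> None:
        # Copy the whitelisted metadata fields onto the inserted Q/DQ node.
        for field in ("stack_trace", "quantization_tag"):
            if field in from_node.meta:
                to_node.meta[field] = from_node.meta[field]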
@@ -14,6 +14,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torch.ao.quantization.quantizer import QuantizationAnnotation
 
 
 __all__ = [
     "fold_bn_weights_into_conv_node",
     "get_aten_graph_module",
@@ -34,15 +35,19 @@ _DEQUANTIZE_OPS = [
 ]
 
 
-def _is_connected(next_node: torch.fx.Node, target: torch.fx.Node) -> bool:
-    if target.op == "output":
-        return False
-    if next_node == target:
-        return True
-    for n in next_node.users.keys():
-        if _is_connected(n, target):
-            return True
-    return False
+def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
+    """
+    Assuming dest is one of the ops inserted by the quant workflow, check
+    whether source and dest are connected, under the assumption that only
+    quant-workflow-inserted ops exist between source and dest.
+    """
+    quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
+    quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
+    while dest.target in quant_workflow_ops:
+        if not isinstance(dest.args[0], torch.fx.Node):
+            raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}")
+        dest = dest.args[0]
+    return dest == source
 
 
 def _find_q_dq_node_for_user(
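For intuition about the replacement: instead of searching forward through every user edge, the new helper walks backward from dest along args[0] for as long as it sees quant-workflow ops, so the check is linear in the length of the Q/DQ chain. A runnable toy version of that backward walk on a plain fx graph (torch.sigmoid stands in for a quant-workflow op; everything here is illustrative, not part of the diff):

    import torch
    import torch.fx as fx

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.sigmoid(torch.relu(x))

    gm = fx.symbolic_trace(M())
    relu, sigmoid = [n for n in gm.graph.nodes if n.op == "call_function"]

    # Mimic the while-loop in _is_connected: step to the producer (args[0])
    # while the current node is one of the "workflow" ops.
    dest = sigmoid
    while dest.target in (torch.sigmoid,):
        dest = dest.args[0]
    print(dest is relu)  # True: the walk lands on the node feeding the chain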