diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index f781d68d47d..f53821a9981 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -1,7 +1,7 @@
 import dataclasses
 import itertools
 import operator
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 from torch.fx import Graph, GraphModule, Node
@@ -52,7 +52,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
     is_per_channel: bool,
     has_bias: bool,
     is_cuda: bool,
-    has_add: bool = False,
 ) -> Dict[str, Any]:
     """
     Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
@@ -67,9 +66,6 @@ def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
         kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
     if has_bias:
         kwargs["conv_bias"] = torch.randn(1)
-    if has_add:
-        # extra_input for add; use the same shape as x here, since the conv weight is torch.randn(1, 1, 1, 1)
-        kwargs["extra_input"] = torch.randn(1, 1, 3, 3)
     if is_cuda:
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
@@ -147,53 +143,6 @@ def _qat_conv2d_bn_pattern_no_conv_bias(
     x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
     return x
 
-def _get_input_output_quantized_filter():
-    def _input_output_quantized_filter(
-        match: "InternalMatch",  # type: ignore[name-defined]
-        original_graph: Graph,
-        pattern_graph: Graph,
-    ) -> bool:
-        """
-        Make sure that the matched pattern's input comes from a dq node
-        and its output feeds a q node. This is used to filter matches for the
-        conv-bn pattern.
-        We need to replace QAT's conv-bn pattern with just conv-bn nodes.
-        QAT's conv-bn pattern has q-dq nodes inserted after the convert step.
-        In order to replace the QAT pattern (see _get_quantized_qat_conv2d_bn_pattern)
-        with a simpler pattern (see _get_folded_quantized_qat_conv2d_bn_pattern),
-        we need to port the quantization parameters from the weight's q/dq nodes.
-        This porting becomes easier if there is only one q/dq node, because we don't
-        have to reason about finding the right q/dq node in the original graph.
-        To facilitate that, the matched pattern and the replacement pattern cannot have
-        a q node for the input activation or a dq node for the output of the fusion.
-        Those nodes are removed from the pattern to be matched; however, we still want
-        to make sure that the input activation of the pattern is actually quantized and
-        the output is dequantized. Hence this filter.
-        """
-        input_dq_node = None
-        output_q_node = None
-        for pattern_node, original_node in match.nodes_map.items():
-            if pattern_node.op == "placeholder":
-                if (
-                    original_node.target
-                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
-                ):
-                    input_dq_node = original_node
-            # The output node is not a separate node in the list of nodes seen in the match;
-            # it is a node in the node.users list of the last node.
-            if (
-                len(pattern_node.users) == 1
-                and next(iter(pattern_node.users.keys())).op == "output"
-            ):
-                output_node = next(iter(original_node.users.keys()))
-                if (
-                    output_node.target
-                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
-                ):
-                    output_q_node = original_node
-        return (input_dq_node is not None) and (output_q_node is not None)
-
-    return _input_output_quantized_filter
-
 def _append_qdq(x, is_per_channel, kwargs):
     """
     Helper function to append q-dq ops after `x`, using dummy values for the qparams
@@ -221,11 +170,7 @@ def _append_qdq(x, is_per_channel, kwargs):
 
 def _get_quantized_qat_conv2d_bn_pattern(
     is_per_channel: bool,
-    has_relu: bool,
     has_bias: bool,
-    relu_is_inplace: bool,
-    has_add: bool,
-    add_is_inplace: bool,
     bias_is_quantized: bool,
 ) -> Callable:
     """
@@ -266,28 +211,12 @@ def _get_quantized_qat_conv2d_bn_pattern(
         if has_bias:
             x = x + kwargs["conv_bias"].reshape(bias_shape)
         x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
-
-        if has_add:
-            if add_is_inplace:
-                x += kwargs["extra_input"]
-            else:
-                x = x + kwargs["extra_input"]
-
-        if has_relu:
-            if relu_is_inplace:
-                x = F.relu_(x)
-            else:
-                x = F.relu(x)
         return x
     return _quantized_qat_conv2d_bn_pattern
 
 def _get_folded_quantized_qat_conv2d_bn_pattern(
     is_per_channel: bool,
-    has_relu: bool,
     has_bias: bool,
-    relu_is_inplace: bool,
-    has_add: bool,
-    add_is_inplace: bool,
     bias_is_quantized: bool,
 ) -> Callable:
     """
@@ -314,18 +243,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(
         bias = None
         x = F.conv2d(x, conv_weight, bias)
         x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
-
-        if has_add:
-            if add_is_inplace:
-                x += kwargs["extra_input"]
-            else:
-                x = x + kwargs["extra_input"]
-
-        if has_relu:
-            if relu_is_inplace:
-                x = F.relu_(x)
-            else:
-                x = F.relu(x)
         return x
     return _folded_quantized_qat_conv2d_bn_pattern
 
@@ -382,15 +299,14 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
     The following names may exist in the map:
         "conv_weight_q", "conv_weight_dq", "conv_bias",
-        "conv_bias_q", "conv_bias_dq", "add", "relu"
+        "conv_bias_q", "conv_bias_dq"
     """
 
-    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node, Optional[Node], Optional[Node]]:
+    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
         """
-        Return a 5-tuple of (conv_node, bn_node, getitem_node, add_node, relu_node).
-        This asserts that the match contains exactly one conv, bn, and getitem,
-        and at most one add and one relu.
+        Return a 3-tuple of (conv_node, bn_node, getitem_node).
+        This asserts that the match contains exactly one of each node.
""" - conv_node, bn_node, getitem_node, add_node, relu_node = None, None, None, None, None + conv_node, bn_node, getitem_node = None, None, None for n in nodes: if n.op != "call_function": continue @@ -403,21 +319,10 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod if n.target == operator.getitem: assert getitem_node is None getitem_node = n - if n.target == torch.ops.aten.relu.default: - assert relu_node is None - relu_node = n - if (n.target in [torch.ops.aten.add_.Tensor, torch.ops.aten.add.Tensor]) and ( - (isinstance(n.args[0], torch.fx.Node) and n.args[0].target == operator.getitem) - or (isinstance(n.args[1], torch.fx.Node) and n.args[1].target == operator.getitem) - ): - # One of Add's input should be BN's getitem node - assert add_node is None - add_node = n - assert conv_node is not None assert bn_node is not None assert getitem_node is not None - return (conv_node, bn_node, getitem_node, add_node, relu_node) + return (conv_node, bn_node, getitem_node) def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]: """ @@ -432,8 +337,8 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod return (orig_node, q_node, n) original_nodes = list(_filter_nodes_map(r.nodes_map).values()) - o_conv, o_bn, o_getitem, o_add, o_relu = _get_nodes(original_nodes) - r_conv, r_bn, r_getitem, r_add, r_relu = _get_nodes(r.replacements) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) # Create the mapping from original node to replacement node mapping = { @@ -441,17 +346,11 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod "bn": (o_bn, r_bn), "getitem": (o_getitem, r_getitem), } - if o_relu is not None: - assert r_relu is not None - mapping["relu"] = (o_relu, r_relu) - if o_add is not None: - assert r_add is not None - mapping["add"] = (o_add, r_add) # Extract conv input and weight # Note: here we extract the original nodes indirectly through the pattern nodes # because the args of the original nodes are no longer available after replacement - (p_conv, _, _, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) (p_conv_input, p_conv_weight, *_) = p_conv.args (r_conv_input, r_conv_weight, *_) = r_conv.args assert isinstance(p_conv_input, Node) @@ -681,7 +580,7 @@ def _fuse_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule: all_original_to_replacement_nodes = {} for r in replacements_with_conv_bias + replacements_no_conv_bias: for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values(): - # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem (- add) (- relu)] + # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem] replacement_node.meta = original_node.meta if original_node.target == torch.ops.aten.conv2d.default: # Step (3b): Copy over conv literal args @@ -789,34 +688,20 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule: replacements = [] replacement_options = itertools.product( [True, False], # is_per_channel - [True, False], # has_relu [True, False], # has_bias - [True, False], # relu_is_inplace - [True, False], # has_add - [True, False], # add_is_inplace [True, False], # bias_is_quantized ) - for ( - is_per_channel, has_relu, has_bias, relu_is_inplace, - has_add, add_is_inplace, bias_is_quantized, - ) in replacement_options: - # For the cases without relu, `relu_is_inplace` is irrelevant, so here we 
+    for is_per_channel, has_bias, bias_is_quantized in replacement_options:
+        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
         # filter out one of the values for this flag to avoid having duplicate patterns
-        # Same for `add_is_inplace` and `bias_is_quantized`
-        if (
-            (not has_relu and relu_is_inplace) or
-            (not has_add and add_is_inplace) or
-            (not has_bias and bias_is_quantized)
-        ):
+        if not has_bias and bias_is_quantized:
             continue
         example_inputs = _quantized_conv2d_bn_pattern_example_inputs
-        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda, has_add)
-        match_pattern = _get_quantized_qat_conv2d_bn_pattern(
-            is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
-        )
+        kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
+        match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_bias, bias_is_quantized)
         match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
-            is_per_channel, has_relu, has_bias, relu_is_inplace, has_add, add_is_inplace, bias_is_quantized
+            is_per_channel, has_bias, bias_is_quantized,
         )
         replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
         replacements.extend(
@@ -824,7 +709,6 @@ def _fold_conv_bn_qat_helper(m: GraphModule, is_cuda: bool) -> GraphModule:
                 m,
                 match_pattern,
                 replacement_pattern,
-                match_filters=[_get_input_output_quantized_filter()],
                 ignore_literals=True,
             )
         )
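
Reviewer note: the change in `_fold_conv_bn_qat_helper` shrinks the option product from 2^7 combinations (seven flags) down to 2^3. As a sanity check, here is a minimal standalone sketch, not part of the patch, of how the reduced `itertools.product` plus the dedup filter behaves; the variable names simply mirror the flags above:

```python
import itertools

# Enumerate the 2**3 = 8 combinations of the three remaining flags.
replacement_options = itertools.product(
    [True, False],  # is_per_channel
    [True, False],  # has_bias
    [True, False],  # bias_is_quantized
)

kept = []
for is_per_channel, has_bias, bias_is_quantized in replacement_options:
    # Without a bias, `bias_is_quantized` cannot matter, so keep only one of
    # its two values to avoid generating duplicate match/replacement patterns.
    if not has_bias and bias_is_quantized:
        continue
    kept.append((is_per_channel, has_bias, bias_is_quantized))

assert len(kept) == 6  # 8 combinations minus the 2 duplicates
```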
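For context, the core computation that both the match and replacement patterns trace is an `F.conv2d` followed by a training-mode `F.batch_norm`; the removed `has_add`/`has_relu` branches only appended an optional add and relu after it. A self-contained sketch of that shared core, with a hypothetical helper name and shapes taken from the example inputs above:

```python
import torch
import torch.nn.functional as F

def conv2d_bn_core(x, conv_weight, conv_bias,
                   bn_weight, bn_bias, bn_running_mean, bn_running_var,
                   bn_eps=1e-5):
    # Shared core of the match and replacement patterns:
    # conv2d followed by batch_norm with training=True.
    x = F.conv2d(x, conv_weight, conv_bias)
    return F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias,
                        training=True, eps=bn_eps)

# Shapes mirror the example inputs: a 1x1 conv applied to a (1, 1, 3, 3) input.
x = torch.randn(1, 1, 3, 3)
out = conv2d_bn_core(
    x,
    conv_weight=torch.randn(1, 1, 1, 1),
    conv_bias=torch.randn(1),
    bn_weight=torch.randn(1),
    bn_bias=torch.randn(1),
    bn_running_mean=torch.zeros(1),
    bn_running_var=torch.ones(1),
)
print(out.shape)  # torch.Size([1, 1, 3, 3])
```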