[quant][pt2e] Support conv bn fusion in convert step for QAT flow (#100442)
Summary: This PR adds support for folding bn weights into conv during the convert step of the QAT flow. This is equivalent to the QAT branch of `from_float` in the eager mode quantized conv module: https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/modules/conv.py#L223

Items that need follow-up:
* There are some workarounds here because quantize_per_tensor currently takes float/int args and dynamo does not support such args; this needs to be fixed after we change the quantized model representation and also change these args to Tensors.

Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_convert_qat_conv_bn_fusion (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'

Reviewed By: andrewor14

Differential Revision: D45344281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100442
Approved by: https://github.com/kimishpatel
This commit is contained in:
parent
f92b3e1477
commit
c3f3cb5b0f
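For reference, the bn-to-conv folding performed at convert time is the standard fusion computed by `torch.nn.utils.fusion.fuse_conv_bn_weights`, which the new utils code in this diff calls. A minimal eager-mode sketch (written for this page, not part of the diff) checking that the folded conv reproduces conv followed by bn when bn normalizes with its running statistics:

import torch
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 3, 3)
bn = torch.nn.BatchNorm2d(3).eval()  # eval mode: bn uses running_mean / running_var

# W_fused = W * gamma / sqrt(var + eps); b_fused = (b - mean) * gamma / sqrt(var + eps) + beta
fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias,
)

x = torch.randn(1, 3, 5, 5)
assert torch.allclose(bn(conv(x)), F.conv2d(x, fused_w, fused_b), atol=1e-5)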
test/quantization/pt2e/test_quantize_pt2e.py

@@ -21,6 +21,7 @@ from torch.ao.quantization._pt2e.quantizer import (
)
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    _convert_to_reference_decomposed_fx,
    prepare_pt2e_quantizer,
    prepare_qat_pt2e_quantizer,
)

@@ -32,9 +33,9 @@ from torch.ao.quantization.qconfig import (
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import (
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
    convert_to_reference_fx,
)
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,

@@ -44,9 +45,6 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_quantized import override_quantized_engine


from torch.ao.quantization.quantize_fx import _convert_to_reference_decomposed_fx

@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
    def test_simple_quantizer(self):

@@ -599,6 +597,7 @@ class TestQuantizePT2E(QuantizationTestCase):
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        is_per_channel: bool,
        verify_convert: bool = False,
    ):
        """
        Helper method to verify that the QAT numerics for PT2E quantization match those of

@@ -615,7 +614,7 @@ class TestQuantizePT2E(QuantizationTestCase):
            aten_graph=True,
        )
        model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
        result_pt2e = model_pt2e(*example_inputs)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)

        # FX
        # Note: In order to match the PT2E numerics exactly, we need to feed the

@@ -632,11 +631,36 @@ class TestQuantizePT2E(QuantizationTestCase):
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        backend_config = get_qnnpack_backend_config()
        model_fx = prepare_qat_fx(model_fx, qconfig_mapping, example_inputs, backend_config=backend_config)
        result_fx = model_fx(*example_inputs)
        after_prepare_result_fx = model_fx(*example_inputs)

        # Verify that numerics match
        self.assertEqual(result_pt2e, result_fx)
        self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

        if verify_convert:
            model_pt2e = convert_pt2e(model_pt2e)
            quant_result_pt2e = model_pt2e(*example_inputs)

            model_fx = _convert_to_reference_decomposed_fx(model_fx, backend_config=backend_config)
            quant_result_fx = model_fx(*example_inputs)
            self.assertEqual(quant_result_pt2e, quant_result_fx)

    def test_convert_qat_conv_bn_numerics(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False)
        # TODO: enable in a separate PR
        # self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)

class TestQuantizePT2EModels(QuantizationTestCase):
    @skip_if_no_torchvision
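The helper above checks the PT2E QAT numerics against FX graph mode quantization as the reference. A standalone sketch of that FX reference path, assembled from the calls visible in the hunk above (simplified, not the exact test code; the real test also runs under override_quantized_engine("qnnpack") and skipIfNoQNNPACK):

import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import default_symmetric_qnnpack_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, _convert_to_reference_decomposed_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

example_inputs = (torch.randn(1, 3, 5, 5),)
qconfig_mapping = QConfigMapping().set_global(default_symmetric_qnnpack_qat_qconfig)
backend_config = get_qnnpack_backend_config()

model_fx = prepare_qat_fx(M().train(), qconfig_mapping, example_inputs, backend_config=backend_config)
model_fx(*example_inputs)  # QAT-style forward pass with fake quantization
model_fx = _convert_to_reference_decomposed_fx(model_fx, backend_config=backend_config)
ref_out = model_fx(*example_inputs)  # reference quantized model; bn has been folded into conv

The PT2E side of the comparison follows the prepare_qat_pt2e_quantizer / convert_pt2e calls shown in the hunks above; this sketch only reproduces the FX baseline it is compared against.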
torch/_dynamo/utils.py

@@ -1284,6 +1284,7 @@ def get_fake_value(node, tx):
            unimplemented("guard on data-dependent symbolic int/float")
        elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
            raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e
        # why don't we print the exception here?
        raise TorchRuntimeError() from e
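The qat_utils patterns below express fake quantization as explicit dequantize -> float compute -> quantize chains built from the decomposed ops registered by torch.ao.quantization.fx._decomposed. A quick sketch (mine, not from the diff) of the round trip such a quantize/dequantize pair performs:

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401 (registers the ops)

x = torch.randn(8)
scale, zero_point = 0.1, 0
# Default overload: scale/zero_point are Python float/int scalars (the args the summary
# above says dynamo cannot handle yet); a .tensor overload taking Tensor scale/zero_point
# also exists, which is what the traced match pattern ends up calling.
xq = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, -128, 127, torch.int8)
xdq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    xq, scale, zero_point, -128, 127, torch.int8)
# For values inside the representable range, the round-trip error is at most scale / 2.
assert torch.all((x - xdq).abs() <= scale / 2 + 1e-6)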
torch/ao/quantization/_pt2e/qat_utils.py

@@ -6,9 +6,10 @@ import torch
from torch.fx import GraphModule, Node
from torch.fx.subgraph_rewriter import _replace_pattern
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from .utils import _fold_bn_weights_into_conv_node


# Example inputs for both `_conv2d_bn_pattern` and `_fused_qat_conv2d_bn_pattern`
# Example inputs for both `_conv2d_bn_pattern` and `_qat_conv2d_bn_pattern`
_conv2d_bn_pattern_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight

@@ -19,6 +20,23 @@ _conv2d_bn_pattern_example_inputs = (
    torch.randn(1),  # bn_running_var
)

# Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
_quantized_conv2d_bn_pattern_example_inputs = (
    torch.randn(1, 1, 3, 3).to(torch.int8),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),  # conv_bias
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
    torch.tensor([1], dtype=torch.float),  # input_scale
    torch.tensor([0], dtype=torch.int),  # input_zero_point
    torch.tensor([1], dtype=torch.float),  # weight_scale
    torch.tensor([0], dtype=torch.int),  # weight_zero_point
    torch.tensor([1], dtype=torch.float),  # output_scale
    torch.tensor([0], dtype=torch.int),  # output_zero_point
)

def _conv2d_bn_pattern(
    x: torch.Tensor,
    conv_weight: torch.Tensor,

@@ -32,7 +50,7 @@ def _conv2d_bn_pattern(
    x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
    return x

def _fused_qat_conv2d_bn_pattern(
def _qat_conv2d_bn_pattern(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,

@@ -62,6 +80,96 @@ def _fused_qat_conv2d_bn_pattern(
    x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
    return x

def _quantized_qat_conv2d_bn_pattern(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
    bn_weight: torch.Tensor,
    bn_bias: torch.Tensor,
    bn_running_mean: torch.Tensor,
    bn_running_var: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    output_scale: torch.Tensor,
    output_zero_point: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized version of the QAT conv - bn pattern,
    based on `nniqat.ConvBn2d._forward_approximate`.
    Used in the QAT convert step: we first match this pattern, replace it with
    the normal conv - bn pattern, and then fold the bn weights into conv.
    """
    # TODO: allow setting eps
    bn_eps = 1e-5
    weight_quant_min = -127
    weight_quant_max = 127
    input_quant_min = -128
    input_quant_max = 127
    output_quant_min = -128
    output_quant_max = 127

    running_std = torch.sqrt(bn_running_var + bn_eps)
    scale_factor = bn_weight / running_std
    weight_shape = [1] * len(conv_weight.shape)
    weight_shape[0] = -1
    bias_shape = [1] * len(conv_weight.shape)
    bias_shape[1] = -1
    scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
    zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
    scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
        scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
        scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    x = F.conv2d(x, scaled_weight, zero_bias)
    x = x / scale_factor.reshape(bias_shape)
    x = x + conv_bias.reshape(bias_shape)
    x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
    return x

def _folded_quantized_qat_conv2d_bn_pattern(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    conv_bias: torch.Tensor,
    bn_weight: torch.Tensor,
    bn_bias: torch.Tensor,
    bn_running_mean: torch.Tensor,
    bn_running_var: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    output_scale: torch.Tensor,
    output_zero_point: torch.Tensor,
) -> torch.Tensor:
    """ Quantized QAT conv - bn pattern with bn weights being folded into conv
    """
    # TODO: allow setting eps
    bn_eps = 1e-5
    weight_quant_min = -127
    weight_quant_max = 127
    input_quant_min = -128
    input_quant_max = 127
    output_quant_min = -128
    output_quant_max = 127

    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
    conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
        conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
        conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    x = F.conv2d(x, conv_weight, conv_bias)
    x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
    return x

def _get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],

@@ -94,7 +202,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
    m.recompile()
    example_inputs = _conv2d_bn_pattern_example_inputs
    match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
    replacement_pattern = _get_aten_graph_module(_fused_qat_conv2d_bn_pattern, example_inputs)
    replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
    # TODO: use the public replace_pattern API once it also returns replacement nodes
    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
    m.recompile()

@@ -127,3 +235,60 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
        if original_node.target == operator.getitem:
            replacement_getitem_node.meta = original_node.meta
    return m

def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
    """
    Replace the quantized (conv + bn) pattern with a conv whose weights have the bn weights folded in.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    example_inputs = _quantized_conv2d_bn_pattern_example_inputs
    match_pattern = _get_aten_graph_module(_quantized_qat_conv2d_bn_pattern, example_inputs)

    # Workaround: current convert does not produce q/dq ops with a specific overload
    # we'll remove the overload from the pattern here as a workaround since we do not want to break BC
    for n in match_pattern.graph.nodes:
        if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor:
            n.target = torch.ops.quantized_decomposed.quantize_per_tensor
        if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor:
            n.target = torch.ops.quantized_decomposed.dequantize_per_tensor

    replacement_pattern = _get_aten_graph_module(_folded_quantized_qat_conv2d_bn_pattern, example_inputs)

    # TODO: use the public replace_pattern API once it also returns replacement nodes
    match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
    m.recompile()

    for mr in match_and_replacement:
        # Find replacement conv and bn nodes by climbing upwards from anchor node
        assert len(mr.replacements) == 1, "expected only one replacement node"

        # find conv, bn, weight, bias nodes in the graph
        replacement_quantize_node = mr.replacements[0]
        assert replacement_quantize_node.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        n = replacement_quantize_node
        conv_node = None
        bn_node = None
        while conv_node is None or bn_node is None:
            if n.target == torch.ops.aten.convolution.default:
                conv_node = n
            if n.target == torch.ops.aten._native_batch_norm_legit.default:
                bn_node = n
            assert isinstance(n.args[0], Node)
            n = n.args[0]
        assert conv_node is not None and bn_node is not None

        conv_weight_dq = conv_node.args[1]
        assert conv_weight_dq.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
        conv_weight_q = conv_weight_dq.args[0]
        assert conv_weight_q.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        conv_weight = conv_weight_q.args[0]
        assert conv_weight.op == "get_attr"
        conv_bias = conv_node.args[2]

        # fold bn weights into conv
        _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

    m.graph.eliminate_dead_code()
    m.recompile()
    return m
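Both _fuse_conv_bn_qat and the new _fold_conv_bn_qat above rely on FX subgraph rewriting: a pattern callable and a replacement callable are traced to graphs, and every match of the pattern in the model is swapped for the replacement. A toy sketch of the same mechanism through the public API (the PR uses the private _replace_pattern so it can inspect the replacement nodes afterwards and, for the fold step, pass ignore_literals=True):

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x):
        return F.relu(x) + F.relu(x)

def pattern(x):
    return F.relu(x) + F.relu(x)

def replacement(x):
    # An equivalent subgraph, standing in for the "folded" form of the pattern.
    return 2 * F.relu(x)

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
gm.recompile()
assert len(matches) == 1
x = torch.randn(4)
assert torch.allclose(gm(x), 2 * F.relu(x))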
torch/ao/quantization/_pt2e/utils.py

@@ -1,5 +1,8 @@
import torch
from torch.fx import GraphModule
from torch.fx import (
    GraphModule,
    Node,
)
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (

@@ -15,68 +18,94 @@ def _get_tensor_constant_from_node(node, m):
    assert node.op == "get_attr"
    return getattr(m, node.target)

def _get_all_arguments(orig_args, orig_kwargs, args_schema):
    all_args = []
    for i, schema in enumerate(args_schema):
        if schema.name in orig_kwargs:
            all_args.append(orig_kwargs[schema.name])
        elif not schema.kwarg_only and i < len(orig_args):
            all_args.append(orig_args[i])
        else:
            all_args.append(schema.default_value)
    return all_args

def _fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Node,
    bn_node: Node,
    m: GraphModule
) -> None:
    # conv weight
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    # conv bias
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = conv_node.args[6]

    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)

    # bn weight
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    # bn bias
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    # bn running mean
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    # bn running variance
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    bn_eps = bn_args[6]

    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # calling data since the fused_weight and fused_bias are nn.Parameter
    weight_attr_name = conv_weight_node.target
    setattr(m, weight_attr_name, fused_weight)  # type: ignore[arg-type]
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
    else:
        bias_attr_name = weight_attr_name + "_bias"
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
            # NOTE: here we assume the bias of conv is not quantized!
            conv_args[2] = get_bias_node
    setattr(m, bias_attr_name, fused_bias)  # type: ignore[arg-type]
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #       \ - (second output) -> users2
    #       \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    # bn -
    #       \ - (second output) -> users2
    #       \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination

    for user in bn_node.users:
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)

# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
            continue
        bn_op = n
        n = bn_op.args[0]
        bn_node = n
        n = bn_node.args[0]
        if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
            continue
        conv_op = n
        conv_node = n
        conv_weight_node = conv_node.args[1]
        conv_bias_node = conv_node.args[2]
        _fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

        # conv weight
        conv_w = _get_tensor_constant_from_node(conv_op.args[1], m)
        # conv bias
        conv_b = _get_tensor_constant_from_node(conv_op.args[2], m)
        transpose = conv_op.args[6]

        # bn weight
        bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
        # bn bias
        bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
        # bn running mean
        bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
        # bn running variance
        bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
        bn_eps = bn_op.args[6]

        fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

        # update the weight and bias for conv
        conv_args = list(conv_op.args)
        # calling data since the fused_weight and fused_bias are nn.Parameter
        weight_attr_name = conv_args[1].target
        setattr(m, weight_attr_name, fused_weight)
        if conv_args[2] is not None:
            bias_attr_name = conv_args[2].target
        else:
            bias_attr_name = weight_attr_name + "_bias"
            with m.graph.inserting_before(conv_op):
                get_bias_node = m.graph.get_attr(bias_attr_name)
                conv_args[2] = get_bias_node
        setattr(m, bias_attr_name, fused_bias)
        conv_op.args = tuple(conv_args)

        # native_batch_norm has 3 outputs, we expect getitem calls on the output
        # and we want to replace the uses of getitem 0 with the output of conv
        #
        # Before:
        # conv -> bn - (first output) -> users1
        #       \ - (second output) -> users2
        #       \ - (third output) -> users3
        # After:
        # conv -> (first output) -> users1
        # bn -
        #       \ - (second output) -> users2
        #       \ - (third output) -> users3
        # if users2 and users3 are empty then bn will be removed through dead code elimination

        for user in bn_op.users:
            if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
                continue
            user.replace_all_uses_with(conv_op)
    m.graph.eliminate_dead_code()
    m.recompile()
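_get_all_arguments exists because an FX node may carry any given operator argument positionally or as a keyword; normalizing against the op schema lets _fold_bn_weights_into_conv_node read the bn weights, running stats, and eps by index. A small sketch of that normalization against the batch norm overload handled by _fuse_conv_bn_ (the argument order assumed here is an assumption about that overload's schema, not taken from the diff):

import torch

op = torch.ops.aten._native_batch_norm_legit_no_training.default
args_schema = op._schema.arguments
print([a.name for a in args_schema])  # inspect the schema rather than hard-coding it

def get_all_arguments(orig_args, orig_kwargs):
    # Same merging rule as _get_all_arguments above: kwargs win, then positionals, then defaults.
    out = []
    for i, schema in enumerate(args_schema):
        if schema.name in orig_kwargs:
            out.append(orig_kwargs[schema.name])
        elif not schema.kwarg_only and i < len(orig_args):
            out.append(orig_args[i])
        else:
            out.append(schema.default_value)
    return out

x, w, b = torch.randn(1, 3, 5, 5), torch.ones(3), torch.zeros(3)
rm, rv = torch.zeros(3), torch.ones(3)
# eps passed positionally or as a keyword lands in the same slot after normalization
# (index 6, which is how the code above reads bn_eps = bn_args[6]).
a1 = get_all_arguments((x, w, b, rm, rv, 0.1, 1e-5), {})
a2 = get_all_arguments((x, w, b, rm, rv, 0.1), {"eps": 1e-5})
assert a1[6] == a2[6] == 1e-5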
torch/ao/quantization/_quantize_pt2e.py

@@ -2,7 +2,10 @@ from torch.fx import GraphModule

from ._pt2e.prepare import prepare
from ._pt2e._propagate_annotation import propagate_annotation
from ._pt2e.qat_utils import _fuse_conv_bn_qat
from ._pt2e.qat_utils import (
    _fuse_conv_bn_qat,
    _fold_conv_bn_qat,
)
from ._pt2e.utils import (
    _get_node_name_to_scope,
    _fuse_conv_bn_,

@@ -81,4 +84,6 @@ def prepare_qat_pt2e_quantizer(
def convert_pt2e(
    model: GraphModule
):
    return _convert_to_reference_decomposed_fx(model)
    model = _convert_to_reference_decomposed_fx(model)  # type: ignore[assignment]
    model = _fold_conv_bn_qat(model)
    return model
torch/fx/subgraph_rewriter.py

@@ -222,6 +222,7 @@ def _replace_pattern(
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule],
    match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None,  # type: ignore[name-defined]
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:

    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

@@ -243,7 +244,7 @@ def _replace_pattern(
    replacement_graph = symbolic_trace(replacement).graph

    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True)
                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    # Filter out matches that don't match the filter