[reland][bc-breaking][quant][be] Refactor fuser_method to include is_qat argument" (#71956)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71956

Pull Request resolved: https://github.com/facebookresearch/mobile-vision/pull/59

Original commit changeset: f3912e210e8c

Original Phabricator Diff: D33178977 (ef501e8fed)

Test Plan:
Please see original diff for test plans

**Static Docs Preview: classyvision**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D33833203/V3/classyvision/)|

|**Modified Pages**|

Reviewed By: andrewor14

Differential Revision: D33833203

fbshipit-source-id: 74a8f22730b00aafa6a173b208e635c1d696959e
(cherry picked from commit fb88772b18)
This commit is contained in:
Jerry Zhang 2022-01-31 14:57:02 -08:00 committed by PyTorch MergeBot
parent 847dbb8684
commit 082ff25f37
13 changed files with 213 additions and 107 deletions

View File

@ -820,7 +820,7 @@ class TestDistributed(QuantizationTestCase):
torch.ao.quantization.DeQuantStub(),
)
torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
torch.ao.quantization.prepare_qat(model, inplace=True)
@ -861,7 +861,7 @@ class TestDistributed(QuantizationTestCase):
model = Model()
# fuse it
fused_model = torch.ao.quantization.fuse_modules(
fused_model = torch.ao.quantization.fuse_modules_qat(
model,
[['conv', 'bn']],
)

View File

@ -15,6 +15,7 @@ from torch.ao.quantization import (
prepare_qat,
quantize_qat,
fuse_modules,
fuse_modules_qat,
QConfig,
default_qconfig,
default_qat_qconfig,
@ -43,8 +44,8 @@ class TestFuseEager(QuantizationTestCase):
def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train()
# Test step by step fusion
model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu first layer")
self.assertEqual(type(model.bn1), torch.nn.Identity,
@ -91,7 +92,9 @@ class TestFuseEager(QuantizationTestCase):
checkQuantized(model)
model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
model = fuse_modules_qat(
model,
[['conv1', 'bn1', 'relu1'],
['sub1.conv', 'sub1.bn']])
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
@ -101,7 +104,9 @@ class TestFuseEager(QuantizationTestCase):
def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig)
model.eval()
model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
model = fuse_modules(
model,
[['conv3', 'bn3', 'relu4'],
['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
@ -168,7 +173,9 @@ class TestFuseEager(QuantizationTestCase):
checkQuantized(model)
model = ModelForFusion(default_qconfig).eval()
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
model = fuse_modules(
model,
[['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn'],
@ -181,11 +188,13 @@ class TestFuseEager(QuantizationTestCase):
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().train()
model.to(torch.float)
fuse_modules(model, [['conv1', 'relu1'] ,
fuse_modules_qat(
model, [['conv1', 'relu1'] ,
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']], inplace=True)
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@ -233,11 +242,14 @@ class TestFuseEager(QuantizationTestCase):
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval()
model.to(torch.float)
fuse_modules(model, [['conv1', 'relu1'] ,
fuse_modules(
model,
[['conv1', 'relu1'],
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']], inplace=True)
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@ -286,7 +298,9 @@ class TestFuseEager(QuantizationTestCase):
# fused model
model_orig.qconfig = QConfig(activation=torch.nn.Identity,
weight=torch.nn.Identity)
model = fuse_modules(model_orig, [["conv1", "bn1", "relu1"],
model = fuse_modules_qat(
model_orig,
[["conv1", "bn1", "relu1"],
["conv2", "bn2"]])
prep_model = prepare_qat(model, inplace=False)
# output with fusion but no observers.
@ -385,8 +399,8 @@ class TestFuseEager(QuantizationTestCase):
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
fused = True
before_fusion_pre_count = counter['pre_forwards']

View File

@ -69,7 +69,7 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)
@ -105,7 +105,7 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = qconfig
torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)

View File

@ -48,6 +48,7 @@ from torch.ao.quantization import (
get_default_qat_qconfig,
get_default_qconfig_dict,
fuse_modules,
fuse_modules_qat,
prepare,
prepare_qat,
convert,
@ -363,6 +364,8 @@ class TestFuseFx(QuantizationTestCase):
@skipIfNoFBGEMM
def test_qconfig_fused_module(self):
""" TODO: add test for all fused modules
"""
qconfig_dict = {
"": None,
"object_type": [(nn.Linear, default_qconfig),
@ -890,14 +893,19 @@ class TestQuantizeFx(QuantizationTestCase):
m_eager.eval()
qconfig = get_default_qconfig(qengine)
prepare_fn = prepare
is_qat = False
else:
m_eager.train()
qconfig = get_default_qat_qconfig(qengine)
prepare_fn = prepare_qat
is_qat = True
fuse_list = ["conv", "bn"]
if has_relu:
fuse_list.append("relu")
if is_qat:
fuse_modules_qat(m_eager, fuse_list, inplace=True)
else:
fuse_modules(m_eager, fuse_list, inplace=True)
m_eager.qconfig = qconfig
m_eager = prepare_fn(m_eager)
@ -5847,6 +5855,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
graph.eval()
calibrate_or_train = test_only_eval_fn
data = self.img_data_2d
is_qat = False
else:
assert quant_type == QuantType.QAT
qconfig = default_qat_qconfig
@ -5856,6 +5865,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
graph.train()
calibrate_or_train = test_only_train_fn
data = self.img_data_2d_train
is_qat = True
if hasattr(eager, "fuse_model"):
eager.fuse_model()

View File

@ -2,6 +2,7 @@
from .fake_quantize import * # noqa: F403
from .fuse_modules import fuse_modules # noqa: F403
from .fuse_modules import fuse_modules_qat # noqa: F403
from .fuser_method_mappings import * # noqa: F403
from .observer import * # noqa: F403
from .qconfig import * # noqa: F403

View File

@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):
setattr(cur_mod, tokens[-1], module)
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
r"""Returns a list of modules that fuses the operations specified
in the input module list.
@ -46,7 +46,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
fused = fuser_method(*mod_list)
fused = fuser_method(is_qat, *mod_list)
# NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
# Move pre forward hooks of the base module to resulting fused module
for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
@ -65,7 +65,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
return new_mod
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
@ -74,12 +74,25 @@ def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_cu
mod_list.append(_get_module(model, item))
# Fuse list of modules
new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
# Replace original module list with fused module list
for i, item in enumerate(modules_to_fuse):
_set_module(model, item, new_mod_list[i])
def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
if not inplace:
model = copy.deepcopy(model)
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
# Handle case of modules_to_fuse being a list
_fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
else:
# Handle case of modules_to_fuse being a list of lists
for module_list in modules_to_fuse:
_fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
return model
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
r"""Fuses a list of modules into a single module
@ -121,27 +134,34 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
Examples::
>>> m = myModel()
>>> m = M().eval()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
>>> m = myModel()
>>> m = M().eval()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
"""
if not inplace:
model = copy.deepcopy(model)
return _fuse_modules(
model,
modules_to_fuse,
is_qat=False,
inplace=inplace,
fuser_func=fuse_known_modules,
fuse_custom_config_dict=None)
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
# Handle case of modules_to_fuse being a list
_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
else:
# Handle case of modules_to_fuse being a list of lists
for module_list in modules_to_fuse:
_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
return model
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
""" QAT version for `fuse_modules`
"""
return _fuse_modules(
model,
modules_to_fuse,
is_qat=True,
inplace=inplace,
fuser_func=fuse_known_modules,
fuse_custom_config_dict=None)

View File

@ -7,10 +7,12 @@ from torch.ao.quantization.utils import Pattern
from torch.ao.quantization.utils import get_combined_dict
def fuse_conv_bn(conv, bn):
def fuse_conv_bn(is_qat, conv, bn):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
@ -29,7 +31,9 @@ def fuse_conv_bn(conv, bn):
nn.Conv3d: nni.ConvBn3d,
}
if conv.training:
if is_qat:
# TODO: remove the assert later
assert conv.training, "qat is only supported when conv.training is True currently"
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
@ -41,10 +45,12 @@ def fuse_conv_bn(conv, bn):
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
r"""Given the conv and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
conv: Module instance of type conv2d/conv3d
bn: Spatial BN instance that needs to be fused with the conv
@ -58,7 +64,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
assert(conv.training == bn.training == relu.training),\
"Conv and BN both must be in the same mode (train or eval)."
fused_module : Optional[Type[nn.Sequential]] = None
if conv.training:
if is_qat:
# TODO: remove the assert later
assert conv.training, "qat is only supported when conv.training is True currently"
map_to_fused_module_train = {
nn.Conv1d: nni.ConvBnReLU1d,
nn.Conv2d: nni.ConvBnReLU2d,
@ -85,10 +93,12 @@ def fuse_conv_bn_relu(conv, bn, relu):
else:
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
def fuse_linear_bn(linear, bn):
def fuse_linear_bn(is_qat, linear, bn):
r"""Given the linear and bn modules, fuses them and returns the fused module
Args:
is_qat: a flag for whether we are using quantization aware training fusion
or post training quantization fusion
linear: Module instance of type Linear
bn: BatchNorm1d instance that needs to be fused with the linear layer
@ -101,13 +111,14 @@ def fuse_linear_bn(linear, bn):
assert(linear.training == bn.training),\
"Linear and BN both must be in the same mode (train or eval)."
if linear.training:
if is_qat:
# TODO: remove the assert later
assert linear.training, "qat is only supported when linear.training is True currently"
raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
else:
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
def fuse_convtranspose_bn(convt, bn):
def fuse_convtranspose_bn(is_qat, convt, bn):
r"""Given ConvTranspose and bn modules, fuses them and returns the fused module
Args:
@ -124,11 +135,20 @@ def fuse_convtranspose_bn(convt, bn):
assert(convt.training == bn.training),\
"ConvTranspose and BN both must be in the same mode (train or eval)."
if convt.training:
if is_qat:
assert convt.training, "qat is only supported when convt.training is True currently"
raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in training.")
else:
return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
def sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules
"""
def fuser_method(is_qat, m1, m2):
return sequential(m1, m2)
return fuser_method
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
@ -137,13 +157,13 @@ DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
(nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
(nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
(nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
(nn.Linear, nn.ReLU): nni.LinearReLU,
(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
(nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
(nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
(nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
(nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
(nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
(nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
@ -161,13 +181,25 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
return fuser_method
def reverse_sequential_wrapper2(sequential):
""" Given a sequential class for two modules, return a function that takes
is_qat, and then two modules as argument, that ignores the is_qat flag
and always returns the sequential that combines the two input modules, with
the order of two inputs reversed
"""
def fuser_method(is_qat, m1, m2):
return sequential(m2, m1)
return fuser_method
def reverse2(f):
return lambda x, y: f(y, x)
def reversed(is_qat, x, y):
return f(is_qat, y, x)
return reversed
def reverse3(f):
def reversed(x, w):
def reversed(is_qat, x, w):
y, z = w
return f(z, y, x)
return f(is_qat, z, y, x)
return reversed
DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
@ -177,13 +209,13 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
(nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
(nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
(nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
(nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
(nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
(nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
(nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
(nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
(nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
(nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
(nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
(nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
(nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
(nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
(nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
(nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
(nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),

View File

@ -28,6 +28,7 @@ class Fuser:
def fuse(
self,
model: GraphModule,
is_qat: bool,
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
@ -72,7 +73,7 @@ class Fuser:
root_node = get_root_node(matched_node_pattern) # type: ignore[index]
env[node.name] = obj.fuse(
self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
fuse_custom_config_dict, fuser_method_mapping)
fuse_custom_config_dict, fuser_method_mapping, is_qat)
elif maybe_last_node is None:
env[node.name] = self.fused_graph.node_copy(node, load_arg)
# node matched in patterns and is not root is removed here

View File

@ -28,7 +28,8 @@ class FuseHandler(ABC):
root_node: Node,
matched_node_pattern: NodePattern,
fuse_custom_config_dict: Dict[str, Any],
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
is_qat: bool) -> Node:
pass
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@ -69,7 +70,8 @@ class DefaultFuseHandler(FuseHandler):
root_node: Node,
matched_node_pattern: NodePattern,
fuse_custom_config_dict: Dict[str, Any],
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
is_qat: bool) -> Node:
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
root_module = quantizer.modules[root_node.target]
@ -113,7 +115,7 @@ class DefaultFuseHandler(FuseHandler):
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(*matched_modules)
fused_module = fuser_method(is_qat, *matched_modules)
# TODO: maybe add a pass to cleanup bn modules?
setattr(quantizer.modules[module_parent_name], module_name, fused_module)
return quantizer.fused_graph.node_copy(root_node, load_arg)

View File

@ -1198,6 +1198,7 @@ def insert_observers_for_model(
def run_prepare_fx_on_standalone_modules(
model: torch.nn.Module,
is_qat: bool,
modules: Dict[str, torch.nn.Module],
matches: Any,
prepare_custom_config_dict: Dict[str, Any],
@ -1228,6 +1229,7 @@ def run_prepare_fx_on_standalone_modules(
prepare(
standalone_module,
sm_qconfig_dict,
is_qat,
sm_prepare_config_dict,
backend_config_dict=sm_backend_config_dict)
preserved_attributes = \
@ -1264,12 +1266,12 @@ def save_state(
def prepare(
model: GraphModule,
qconfig_dict: Any,
is_qat: bool,
node_name_to_scope: Dict[str, Tuple[str, type]],
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False,
is_qat: bool = False) -> ObservedGraphModule:
is_standalone_module: bool = False) -> ObservedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
@ -1388,7 +1390,7 @@ def prepare(
"output_quantized_idxs", [])
run_prepare_fx_on_standalone_modules(
model, modules, matches, prepare_custom_config_dict, backend_config_dict)
model, is_qat, modules, matches, prepare_custom_config_dict, backend_config_dict)
# record names for the set of observed node, so that in convert step
# we know whether we need to convert a floating point module to reference

View File

@ -11,9 +11,9 @@ from torch.fx import (
from torch.fx.graph import (
Graph,
)
from torch.nn.intrinsic import _FusedModule
from ..utils import _parent_name
from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
from ..qconfig_dict_utils import (
get_object_type_qconfig,
maybe_adjust_qconfig_for_module_type_or_name,
@ -56,23 +56,29 @@ def update_qconfig_for_fusion(
for node in model.graph.nodes:
if node.op == 'call_module' and node.target in modules:
module_type = type(modules[str(node.target)])
if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
maybe_fused_module = modules[str(node.target)]
if not isinstance(maybe_fused_module, _FusedModule):
continue
for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
if module_type == fuser:
fused_qconfig = object_type_dict.get(ops[0], None)
ops = list(maybe_fused_module._modules.values())
fused_qconfig = object_type_dict.get(type(ops[0]), None)
# Raise an error if the modules in the fused module have
# different qconfigs specified in the qconfig_dict
for op in ops:
if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
raise LookupError("During fusion, we need to specify the same " +
f"qconfigs for both modules in {module_type}.")
# TODO: currently it only works for modules,
# need to make this work for torch.nn.functional.relu
# TODO: currently it only works for object_type configurations,
# ideally it should work for different types of configurations,
# maybe we want to redesign this part
for op in ops[1:]:
if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
raise LookupError(
"During fusion, we need to specify the same " +
f"qconfigs for all module types in {type(maybe_fused_module)} " +
f"offending type: {type(op)}")
if fused_qconfig is not None:
object_type_dict[module_type] = fused_qconfig
object_type_dict[type(maybe_fused_module)] = fused_qconfig
return qconfig_dict

View File

@ -47,6 +47,7 @@ def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
def _fuse_fx(
graph_module: GraphModule,
is_qat: bool,
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
@ -57,7 +58,8 @@ def _fuse_fx(
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(graph_module, fuse_custom_config_dict, backend_config_dict)
return fuser.fuse(
graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
class Scope(object):
@ -175,11 +177,11 @@ class QuantizationTracer(Tracer):
def _prepare_fx(
model: torch.nn.Module,
qconfig_dict: Any,
is_qat: bool,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False,
is_qat: bool = False,
) -> ObservedGraphModule:
r""" Internal helper function for prepare_fx
Args:
@ -235,16 +237,20 @@ forward graph of the parent module,
graph_module = GraphModule(model, tracer.trace(model))
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict, backend_config_dict)
graph_module = _fuse_fx(
graph_module,
is_qat,
prepare_custom_config_dict,
backend_config_dict)
prepared = prepare(
graph_module,
qconfig_dict,
is_qat,
tracer.node_name_to_scope,
prepare_custom_config_dict=prepare_custom_config_dict,
equalization_qconfig_dict=equalization_qconfig_dict,
backend_config_dict=backend_config_dict,
is_standalone_module=is_standalone_module,
is_qat=is_qat,
)
for attr_name in preserved_attributes:
@ -255,9 +261,9 @@ forward graph of the parent module,
def _prepare_standalone_module_fx(
model: torch.nn.Module,
qconfig_dict: Any,
is_qat: bool,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_qat: bool = False,
) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
@ -284,10 +290,10 @@ def _prepare_standalone_module_fx(
return _prepare_fx(
model,
qconfig_dict,
is_qat,
prepare_custom_config_dict,
backend_config_dict=backend_config_dict,
is_standalone_module=True,
is_qat=is_qat,
)
@ -332,7 +338,7 @@ def fuse_fx(
)
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
return _fuse_fx(graph_module, fuse_custom_config_dict)
return _fuse_fx(graph_module, False, fuse_custom_config_dict)
def prepare_fx(
@ -509,10 +515,10 @@ def prepare_fx(
return _prepare_fx(
model,
qconfig_dict,
False, # is_qat
prepare_custom_config_dict,
equalization_qconfig_dict,
backend_config_dict,
is_qat=False,
)
@ -558,9 +564,9 @@ def prepare_qat_fx(
return _prepare_fx(
model,
qconfig_dict,
True, # is_qat
prepare_custom_config_dict,
backend_config_dict=backend_config_dict,
is_qat=True,
)

View File

@ -1195,6 +1195,10 @@ class AnnotatedConvBnReLUModel(torch.nn.Module):
return x
def fuse_model(self):
# TODO: remove this check and define two fuse_modules function on this module
if self.training:
torch.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
else:
torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)
class TwoLayerConvModel(torch.nn.Module):
@ -1464,7 +1468,11 @@ class InnerModule(torch.nn.Module):
if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
fusable_layers.append([current_name,
named_children[idx + 1][0]])
torch.quantization.fuse_modules(self, fusable_layers, inplace=True)
# TODO: remove this check and define two fuse_modules function on this module
if self.training:
torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
else:
torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
class FunctionalLinear(torch.nn.Module):
def __init__(self):
@ -1955,7 +1963,11 @@ class ResNetBase(torch.nn.Module):
return out
def fuse_model(self):
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
# TODO: remove this check and define two fuse_model function on this module
if self.training:
torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
else:
torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
class ModelMultipleOps(torch.nn.Module):
def __init__(self):