[reland][bc-breaking][quant][be] Refactor fuser_method to include is_qat argument (#71956)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71956
Pull Request resolved: https://github.com/facebookresearch/mobile-vision/pull/59

Original commit changeset: f3912e210e8c
Original Phabricator Diff: D33178977 (ef501e8fed)

Test Plan: Please see original diff for test plans

Static Docs Preview: classyvision
Full Site: https://our.intern.facebook.com/intern/staticdocs/eph/D33833203/V3/classyvision/

Reviewed By: andrewor14

Differential Revision: D33833203

fbshipit-source-id: 74a8f22730b00aafa6a173b208e635c1d696959e
(cherry picked from commit fb88772b18)
This commit is contained in:
parent 847dbb8684
commit 082ff25f37
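In short, fuser methods are now called as fuser_method(is_qat, *modules) instead of fuser_method(*modules) (the bc-breaking part), and eager-mode QAT fusion gets its own entry point, fuse_modules_qat. A minimal sketch of the resulting user-facing split, using an assumed toy Sequential that is not part of this diff:

    import torch.nn as nn
    import torch.ao.quantization as tq

    # Toy Conv -> BN -> ReLU stack; '0', '1', '2' are the default Sequential child names.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # Post-training fusion keeps using fuse_modules on an eval-mode model ...
    ptq_fused = tq.fuse_modules(model.eval(), [['0', '1', '2']])

    # ... while QAT fusion now goes through fuse_modules_qat on a train-mode model,
    # which threads is_qat=True down to the underlying fuser methods.
    qat_fused = tq.fuse_modules_qat(model.train(), [['0', '1', '2']])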
@@ -820,7 +820,7 @@ class TestDistributed(QuantizationTestCase):
 torch.ao.quantization.DeQuantStub(),
 )

-torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)

 model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
 torch.ao.quantization.prepare_qat(model, inplace=True)

@@ -861,7 +861,7 @@ class TestDistributed(QuantizationTestCase):

 model = Model()
 # fuse it
-fused_model = torch.ao.quantization.fuse_modules(
+fused_model = torch.ao.quantization.fuse_modules_qat(
     model,
     [['conv', 'bn']],
 )

@@ -15,6 +15,7 @@ from torch.ao.quantization import (
     prepare_qat,
     quantize_qat,
     fuse_modules,
+    fuse_modules_qat,
     QConfig,
     default_qconfig,
     default_qat_qconfig,

@@ -43,8 +44,8 @@ class TestFuseEager(QuantizationTestCase):
 def test_fuse_module_train(self):
     model = ModelForFusion(default_qat_qconfig).train()
     # Test step by step fusion
-    model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
-    model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+    model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
+    model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
     self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                      msg="Fused Conv + BN + Relu first layer")
     self.assertEqual(type(model.bn1), torch.nn.Identity,

@@ -91,7 +92,9 @@ class TestFuseEager(QuantizationTestCase):
 checkQuantized(model)

 model = ModelForFusion(default_qat_qconfig).train()
-model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+model = fuse_modules_qat(
+    model,
+    [['conv1', 'bn1', 'relu1'],
      ['sub1.conv', 'sub1.bn']])
 model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
 with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):

@@ -101,7 +104,9 @@ class TestFuseEager(QuantizationTestCase):
 def test_fuse_module_eval(self):
     model = ModelForFusion(default_qconfig)
     model.eval()
-    model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
+    model = fuse_modules(
+        model,
+        [['conv3', 'bn3', 'relu4'],
          ['conv1', 'bn1', 'relu1'],
          ['conv2', 'relu2'],
          ['bn2', 'relu3'],

@@ -168,7 +173,9 @@ class TestFuseEager(QuantizationTestCase):
 checkQuantized(model)

 model = ModelForFusion(default_qconfig).eval()
-model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+model = fuse_modules(
+    model,
+    [['conv1', 'bn1', 'relu1'],
      ['conv2', 'relu2'],
      ['bn2', 'relu3'],
      ['sub1.conv', 'sub1.bn'],

@@ -181,11 +188,13 @@ class TestFuseEager(QuantizationTestCase):
 with override_quantized_engine(qengine):
     model = ModelWithSequentialFusion().train()
     model.to(torch.float)
-    fuse_modules(model, [['conv1', 'relu1'] ,
+    fuse_modules_qat(
+        model, [['conv1', 'relu1'] ,
         ['features.0.0', 'features.0.1', 'features.0.2'],
         ['features.1.0', 'features.1.1', 'features.1.2'],
         ['features.2.0', 'features.2.1', 'features.2.2'],
-        ['classifier.0', 'classifier.1']], inplace=True)
+        ['classifier.0', 'classifier.1']],
+        inplace=True)
     self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                      msg="Fused Conv + Relu: nni.ConvReLU2d")
     self.assertEqual(type(model.conv1[0]), nn.Conv2d,

@@ -233,11 +242,14 @@ class TestFuseEager(QuantizationTestCase):
 with override_quantized_engine(qengine):
     model = ModelWithSequentialFusion().eval()
     model.to(torch.float)
-    fuse_modules(model, [['conv1', 'relu1'] ,
+    fuse_modules(
+        model,
+        [['conv1', 'relu1'],
         ['features.0.0', 'features.0.1', 'features.0.2'],
         ['features.1.0', 'features.1.1', 'features.1.2'],
         ['features.2.0', 'features.2.1', 'features.2.2'],
-        ['classifier.0', 'classifier.1']], inplace=True)
+        ['classifier.0', 'classifier.1']],
+        inplace=True)
     self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                      msg="Fused Conv + Relu: nni.ConvReLU2d")
     self.assertEqual(type(model.conv1[0]), nn.Conv2d,

@@ -286,7 +298,9 @@ class TestFuseEager(QuantizationTestCase):
 # fused model
 model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                              weight=torch.nn.Identity)
-model = fuse_modules(model_orig, [["conv1", "bn1", "relu1"],
+model = fuse_modules_qat(
+    model_orig,
+    [["conv1", "bn1", "relu1"],
      ["conv2", "bn2"]])
 prep_model = prepare_qat(model, inplace=False)
 # output with fusion but no observers.

@@ -385,8 +399,8 @@ class TestFuseEager(QuantizationTestCase):
 self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
 self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

-model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
-model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
+model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

 fused = True
 before_fusion_pre_count = counter['pre_forwards']

@@ -69,7 +69,7 @@ class TestModelNumericsEager(QuantizationTestCase):
 fq_model = torch.ao.quantization.QuantWrapper(my_model)
 fq_model.train()
 fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
-torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
 torch.ao.quantization.prepare_qat(fq_model)
 fq_model.eval()
 fq_model.apply(torch.ao.quantization.disable_fake_quant)

@@ -105,7 +105,7 @@ class TestModelNumericsEager(QuantizationTestCase):
 fq_model = torch.ao.quantization.QuantWrapper(my_model)
 fq_model.train()
 fq_model.qconfig = qconfig
-torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
 torch.ao.quantization.prepare_qat(fq_model)
 fq_model.eval()
 fq_model.apply(torch.ao.quantization.disable_fake_quant)

@@ -48,6 +48,7 @@ from torch.ao.quantization import (
     get_default_qat_qconfig,
     get_default_qconfig_dict,
     fuse_modules,
+    fuse_modules_qat,
    prepare,
     prepare_qat,
     convert,

@@ -363,6 +364,8 @@ class TestFuseFx(QuantizationTestCase):

 @skipIfNoFBGEMM
 def test_qconfig_fused_module(self):
     """ TODO: add test for all fused modules
     """
     qconfig_dict = {
         "": None,
         "object_type": [(nn.Linear, default_qconfig),

@@ -890,14 +893,19 @@ class TestQuantizeFx(QuantizationTestCase):
     m_eager.eval()
     qconfig = get_default_qconfig(qengine)
     prepare_fn = prepare
+    is_qat = False
 else:
     m_eager.train()
     qconfig = get_default_qat_qconfig(qengine)
     prepare_fn = prepare_qat
+    is_qat = True

 fuse_list = ["conv", "bn"]
 if has_relu:
     fuse_list.append("relu")
+if is_qat:
+    fuse_modules_qat(m_eager, fuse_list, inplace=True)
+else:
     fuse_modules(m_eager, fuse_list, inplace=True)
 m_eager.qconfig = qconfig
 m_eager = prepare_fn(m_eager)

@@ -5847,6 +5855,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
     graph.eval()
     calibrate_or_train = test_only_eval_fn
     data = self.img_data_2d
+    is_qat = False
 else:
     assert quant_type == QuantType.QAT
     qconfig = default_qat_qconfig

@@ -5856,6 +5865,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
     graph.train()
     calibrate_or_train = test_only_train_fn
     data = self.img_data_2d_train
+    is_qat = True

 if hasattr(eager, "fuse_model"):
     eager.fuse_model()

@@ -2,6 +2,7 @@

 from .fake_quantize import * # noqa: F403
 from .fuse_modules import fuse_modules # noqa: F403
+from .fuse_modules import fuse_modules_qat # noqa: F403
 from .fuser_method_mappings import * # noqa: F403
 from .observer import * # noqa: F403
 from .qconfig import * # noqa: F403

@@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):

     setattr(cur_mod, tokens[-1], module)

-def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
+def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
     r"""Returns a list of modules that fuses the operations specified
     in the input module list.


@@ -46,7 +46,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
     if fuser_method is None:
         raise NotImplementedError("Cannot fuse modules: {}".format(types))
     new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
-    fused = fuser_method(*mod_list)
+    fused = fuser_method(is_qat, *mod_list)
     # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
     # Move pre forward hooks of the base module to resulting fused module
     for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():

@@ -65,7 +65,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):

     return new_mod

-def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
     if fuse_custom_config_dict is None:
         fuse_custom_config_dict = {}
     additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})

@@ -74,12 +74,25 @@ def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
         mod_list.append(_get_module(model, item))

     # Fuse list of modules
-    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
+    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

     # Replace original module list with fused module list
     for i, item in enumerate(modules_to_fuse):
         _set_module(model, item, new_mod_list[i])

+def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+        # Handle case of modules_to_fuse being a list
+        _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
+    else:
+        # Handle case of modules_to_fuse being a list of lists
+        for module_list in modules_to_fuse:
+            _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
+    return model
+
 def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
     r"""Fuses a list of modules into a single module


@@ -121,27 +134,34 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):

     Examples::

-    >>> m = myModel()
+    >>> m = M().eval()
     >>> # m is a module containing the sub-modules below
     >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
     >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
     >>> output = fused_m(input)

-    >>> m = myModel()
+    >>> m = M().eval()
     >>> # Alternately provide a single list of modules to fuse
     >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
     >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
     >>> output = fused_m(input)

     """
-    if not inplace:
-        model = copy.deepcopy(model)
+    return _fuse_modules(
+        model,
+        modules_to_fuse,
+        is_qat=False,
+        inplace=inplace,
+        fuser_func=fuse_known_modules,
+        fuse_custom_config_dict=None)

-    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
-        # Handle case of modules_to_fuse being a list
-        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
-    else:
-        # Handle case of modules_to_fuse being a list of lists
-        for module_list in modules_to_fuse:
-            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
-    return model
+def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    """ QAT version for `fuse_modules`
+    """
+    return _fuse_modules(
+        model,
+        modules_to_fuse,
+        is_qat=True,
+        inplace=inplace,
+        fuser_func=fuse_known_modules,
+        fuse_custom_config_dict=None)

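Anything that calls fuse_known_modules directly (for example as a custom fuser_func) now has to pass the flag through as the new second positional argument. A hedged sketch with throwaway modules that are not part of this diff:

    import torch.nn as nn
    from torch.ao.quantization.fuse_modules import fuse_known_modules

    conv, bn, relu = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()

    # is_qat=False selects the post-training (eval) fusion path of the
    # underlying fuser method; the modules are therefore put in eval mode.
    new_mods = fuse_known_modules([conv.eval(), bn.eval(), relu.eval()], is_qat=False)

    # new_mods[0] is the fused module (nni.ConvReLU2d here); the remaining
    # slots are identity modules so the original attribute names stay valid.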
@@ -7,10 +7,12 @@ from torch.ao.quantization.utils import Pattern
 from torch.ao.quantization.utils import get_combined_dict


-def fuse_conv_bn(conv, bn):
+def fuse_conv_bn(is_qat, conv, bn):
     r"""Given the conv and bn modules, fuses them and returns the fused module

     Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+                or post training quantization fusion
         conv: Module instance of type conv2d/conv3d
         bn: Spatial BN instance that needs to be fused with the conv

@@ -29,7 +31,9 @@ def fuse_conv_bn(conv, bn):
         nn.Conv3d: nni.ConvBn3d,
     }

-    if conv.training:
+    if is_qat:
+        # TODO: remove the assert later
+        assert conv.training, "qat is only supported when conv.training is True currently"
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'

@@ -41,10 +45,12 @@ def fuse_conv_bn(conv, bn):
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)

-def fuse_conv_bn_relu(conv, bn, relu):
+def fuse_conv_bn_relu(is_qat, conv, bn, relu):
     r"""Given the conv and bn modules, fuses them and returns the fused module

     Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+                or post training quantization fusion
         conv: Module instance of type conv2d/conv3d
         bn: Spatial BN instance that needs to be fused with the conv

@@ -58,7 +64,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
     assert(conv.training == bn.training == relu.training),\
         "Conv and BN both must be in the same mode (train or eval)."
     fused_module : Optional[Type[nn.Sequential]] = None
-    if conv.training:
+    if is_qat:
+        # TODO: remove the assert later
+        assert conv.training, "qat is only supported when conv.training is True currently"
         map_to_fused_module_train = {
             nn.Conv1d: nni.ConvBnReLU1d,
             nn.Conv2d: nni.ConvBnReLU2d,

@@ -85,10 +93,12 @@ def fuse_conv_bn_relu(conv, bn, relu):
     else:
         raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

-def fuse_linear_bn(linear, bn):
+def fuse_linear_bn(is_qat, linear, bn):
     r"""Given the linear and bn modules, fuses them and returns the fused module

     Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+                or post training quantization fusion
         linear: Module instance of type Linear
         bn: BatchNorm1d instance that needs to be fused with the linear layer

@@ -101,13 +111,14 @@ def fuse_linear_bn(linear, bn):
     assert(linear.training == bn.training),\
         "Linear and BN both must be in the same mode (train or eval)."

-    if linear.training:
+    if is_qat:
+        # TODO: remove the assert later
+        assert linear.training, "qat is only supported when linear.training is True currently"
         raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
     else:
         return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)


-def fuse_convtranspose_bn(convt, bn):
+def fuse_convtranspose_bn(is_qat, convt, bn):
     r"""Given ConvTranspose and bn modules, fuses them and returns the fused module

     Args:

@@ -124,11 +135,20 @@ def fuse_convtranspose_bn(convt, bn):
     assert(convt.training == bn.training),\
         "ConvTranspose and BN both must be in the same mode (train or eval)."

-    if convt.training:
+    if is_qat:
+        assert convt.training, "qat is only supported when convt.training is True currently"
         raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in training.")
     else:
         return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)

+def sequential_wrapper2(sequential):
+    """ Given a sequential class for two modules, return a function that takes
+    is_qat, and then two modules as argument, that ignores the is_qat flag
+    and always returns the sequential that combines the two input modules
+    """
+    def fuser_method(is_qat, m1, m2):
+        return sequential(m1, m2)
+    return fuser_method
+
 DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,

@@ -137,13 +157,13 @@ DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
     (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
     (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
-    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
-    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
+    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
+    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
+    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
     (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
-    (nn.Linear, nn.ReLU): nni.LinearReLU,
-    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
-    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
+    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
+    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
+    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
     (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
     (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
     (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,

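A hedged sketch of the new calling convention for the default fuser methods (the toy conv/bn below are illustrative, not from the diff):

    import torch.nn as nn
    from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn, get_fuser_method

    conv, bn = nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4)

    # PTQ fusion folds BN into the conv weights; modules are expected in eval mode.
    folded = fuse_conv_bn(False, conv.eval(), bn.eval())

    # QAT fusion wraps the pair into nni.ConvBn2d; modules are expected in train mode.
    qat_fused = fuse_conv_bn(True, conv.train(), bn.train())

    # The mapping lookup itself is unchanged; only the call site passes is_qat first.
    assert get_fuser_method((nn.Conv2d, nn.BatchNorm2d)) is fuse_conv_bn

sequential_wrapper2 exists so that plain two-module sequentials like nni.ConvReLU2d can keep living in the same mapping: it returns a fuser_method(is_qat, m1, m2) that simply ignores the flag.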
@@ -161,13 +181,25 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
     assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
     return fuser_method

+def reverse_sequential_wrapper2(sequential):
+    """ Given a sequential class for two modules, return a function that takes
+    is_qat, and then two modules as argument, that ignores the is_qat flag
+    and always returns the sequential that combines the two input modules, with
+    the order of two inputs reversed
+    """
+    def fuser_method(is_qat, m1, m2):
+        return sequential(m2, m1)
+    return fuser_method
+
 def reverse2(f):
-    return lambda x, y: f(y, x)
+    def reversed(is_qat, x, y):
+        return f(is_qat, y, x)
+    return reversed

 def reverse3(f):
-    def reversed(x, w):
+    def reversed(is_qat, x, w):
         y, z = w
-        return f(z, y, x)
+        return f(is_qat, z, y, x)
     return reversed

 DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {

@@ -177,13 +209,13 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
     (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
     (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
-    (nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
-    (nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
-    (nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
+    (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
+    (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
+    (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
     (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
-    (nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
-    (nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
-    (nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
+    (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
+    (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
+    (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
     (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
     (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
     (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),

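Same idea for the reversed-pattern table used by FX, where patterns are stored outermost-first; a small illustrative sketch:

    import torch.nn as nn
    import torch.nn.intrinsic as nni
    from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

    conv, relu = nn.Conv2d(4, 4, 3).eval(), nn.ReLU().eval()

    # The (nn.ReLU, nn.Conv2d) pattern delivers (relu, conv); the wrapper accepts
    # and ignores is_qat, then flips the order back before building the sequential.
    fuser = reverse_sequential_wrapper2(nni.ConvReLU2d)
    fused = fuser(False, relu, conv)   # -> nni.ConvReLU2d(conv, relu)

reverse2/reverse3 do the analogous flip for real fuser methods and now forward is_qat, e.g. reverse2(fuse_linear_bn)(is_qat, bn, linear) ends up calling fuse_linear_bn(is_qat, linear, bn).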
@@ -28,6 +28,7 @@ class Fuser:
 def fuse(
     self,
     model: GraphModule,
+    is_qat: bool,
     fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
     backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:

@@ -72,7 +73,7 @@ class Fuser:
     root_node = get_root_node(matched_node_pattern) # type: ignore[index]
     env[node.name] = obj.fuse(
         self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
-        fuse_custom_config_dict, fuser_method_mapping)
+        fuse_custom_config_dict, fuser_method_mapping, is_qat)
 elif maybe_last_node is None:
     env[node.name] = self.fused_graph.node_copy(node, load_arg)
 # node matched in patterns and is not root is removed here

@@ -28,7 +28,8 @@ class FuseHandler(ABC):
              root_node: Node,
              matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
+             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+             is_qat: bool) -> Node:
     pass

 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))

@@ -69,7 +70,8 @@ class DefaultFuseHandler(FuseHandler):
              root_node: Node,
              matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
+             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+             is_qat: bool) -> Node:
     additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
     assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
     root_module = quantizer.modules[root_node.target]

@@ -113,7 +115,7 @@ class DefaultFuseHandler(FuseHandler):
     fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
     # TODO: change the signature for fuser_method to take matched module patterns
     # as input
-    fused_module = fuser_method(*matched_modules)
+    fused_module = fuser_method(is_qat, *matched_modules)
     # TODO: maybe add a pass to cleanup bn modules?
     setattr(quantizer.modules[module_parent_name], module_name, fused_module)
     return quantizer.fused_graph.node_copy(root_node, load_arg)

@@ -1198,6 +1198,7 @@ def insert_observers_for_model(

 def run_prepare_fx_on_standalone_modules(
         model: torch.nn.Module,
+        is_qat: bool,
         modules: Dict[str, torch.nn.Module],
         matches: Any,
         prepare_custom_config_dict: Dict[str, Any],

@@ -1228,6 +1229,7 @@ def run_prepare_fx_on_standalone_modules(
     prepare(
         standalone_module,
         sm_qconfig_dict,
+        is_qat,
         sm_prepare_config_dict,
         backend_config_dict=sm_backend_config_dict)
     preserved_attributes = \

@@ -1264,12 +1266,12 @@ def save_state(
 def prepare(
         model: GraphModule,
         qconfig_dict: Any,
+        is_qat: bool,
         node_name_to_scope: Dict[str, Tuple[str, type]],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
-        is_standalone_module: bool = False,
-        is_qat: bool = False) -> ObservedGraphModule:
+        is_standalone_module: bool = False) -> ObservedGraphModule:
     """ standalone_module means it a submodule that is not inlined in
     parent module, and will be quantized separately as one unit.

@@ -1388,7 +1390,7 @@ def prepare(
         "output_quantized_idxs", [])

     run_prepare_fx_on_standalone_modules(
-        model, modules, matches, prepare_custom_config_dict, backend_config_dict)
+        model, is_qat, modules, matches, prepare_custom_config_dict, backend_config_dict)

     # record names for the set of observed node, so that in convert step
     # we know whether we need to convert a floating point module to reference

@@ -11,9 +11,9 @@ from torch.fx import (
 from torch.fx.graph import (
     Graph,
 )
+from torch.nn.intrinsic import _FusedModule

 from ..utils import _parent_name
-from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
 from ..qconfig_dict_utils import (
     get_object_type_qconfig,
     maybe_adjust_qconfig_for_module_type_or_name,

@@ -56,23 +56,29 @@ def update_qconfig_for_fusion(

 for node in model.graph.nodes:
     if node.op == 'call_module' and node.target in modules:
-        module_type = type(modules[str(node.target)])
-        if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
+        maybe_fused_module = modules[str(node.target)]
+        if not isinstance(maybe_fused_module, _FusedModule):
             continue

-        for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
-            if module_type == fuser:
-                fused_qconfig = object_type_dict.get(ops[0], None)
+        ops = list(maybe_fused_module._modules.values())
+        fused_qconfig = object_type_dict.get(type(ops[0]), None)

         # Raise an error if the modules in the fused module have
         # different qconfigs specified in the qconfig_dict
-        for op in ops:
-            if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
-                raise LookupError("During fusion, we need to specify the same " +
-                                  f"qconfigs for both modules in {module_type}.")
-        # TODO: currently it only works for modules,
-        # need to make this work for torch.nn.functional.relu
+        # TODO: currently it only works for object_type configurations,
+        # ideally it should work for different types of configurations,
+        # maybe we want to redesign this part
+        for op in ops[1:]:
+            if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
+                raise LookupError(
+                    "During fusion, we need to specify the same " +
+                    f"qconfigs for all module types in {type(maybe_fused_module)} " +
+                    f"offending type: {type(op)}")

         if fused_qconfig is not None:
-            object_type_dict[module_type] = fused_qconfig
+            object_type_dict[type(maybe_fused_module)] = fused_qconfig

 return qconfig_dict

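The practical effect: after fusion the graph holds _FusedModule instances (e.g. nni.ConvReLU2d), so the qconfig is propagated from, and validated against, the types of the contained modules. A sketch of a qconfig_dict that satisfies the check (illustrative only, not from the diff):

    import torch.nn as nn
    from torch.ao.quantization import default_qconfig

    qconfig_dict = {
        "": None,
        "object_type": [
            (nn.Conv2d, default_qconfig),
            (nn.ReLU, default_qconfig),  # must equal the Conv2d entry, or a LookupError is raised
        ],
    }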
@@ -47,6 +47,7 @@ def _swap_ff_with_fxff(model: torch.nn.Module) -> None:

 def _fuse_fx(
         graph_module: GraphModule,
+        is_qat: bool,
         fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:

@@ -57,7 +58,8 @@ def _fuse_fx(
     """
     _check_is_graph_module(graph_module)
     fuser = Fuser()
-    return fuser.fuse(graph_module, fuse_custom_config_dict, backend_config_dict)
+    return fuser.fuse(
+        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)


 class Scope(object):

@@ -175,11 +177,11 @@ class QuantizationTracer(Tracer):
 def _prepare_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
+        is_qat: bool,
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
         is_standalone_module: bool = False,
-        is_qat: bool = False,
 ) -> ObservedGraphModule:
     r""" Internal helper function for prepare_fx
     Args:

@@ -235,16 +237,20 @@ forward graph of the parent module,
     graph_module = GraphModule(model, tracer.trace(model))
     for attr_name in preserved_attributes:
         setattr(graph_module, attr_name, getattr(model, attr_name))
-    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict, backend_config_dict)
+    graph_module = _fuse_fx(
+        graph_module,
+        is_qat,
+        prepare_custom_config_dict,
+        backend_config_dict)
     prepared = prepare(
         graph_module,
         qconfig_dict,
+        is_qat,
         tracer.node_name_to_scope,
         prepare_custom_config_dict=prepare_custom_config_dict,
         equalization_qconfig_dict=equalization_qconfig_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=is_standalone_module,
-        is_qat=is_qat,
     )

     for attr_name in preserved_attributes:

@@ -255,9 +261,9 @@ forward graph of the parent module,
 def _prepare_standalone_module_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
+        is_qat: bool,
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
-        is_qat: bool = False,
 ) -> GraphModule:
     r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
     parent module.

@@ -284,10 +290,10 @@ def _prepare_standalone_module_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
+        is_qat,
         prepare_custom_config_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=True,
-        is_qat=is_qat,
     )


@@ -332,7 +338,7 @@ def fuse_fx(
     )
     for attr_name in preserved_attributes:
         setattr(graph_module, attr_name, getattr(model, attr_name))
-    return _fuse_fx(graph_module, fuse_custom_config_dict)
+    return _fuse_fx(graph_module, False, fuse_custom_config_dict)


 def prepare_fx(

@@ -509,10 +515,10 @@ def prepare_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
+        False,  # is_qat
         prepare_custom_config_dict,
         equalization_qconfig_dict,
         backend_config_dict,
-        is_qat=False,
     )


@@ -558,9 +564,9 @@ def prepare_qat_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
+        True,  # is_qat
         prepare_custom_config_dict,
         backend_config_dict=backend_config_dict,
-        is_qat=True,
     )

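At the public FX level the behavior is unchanged; prepare_fx and prepare_qat_fx simply pass the flag positionally now. A hedged sketch with an assumed toy model and the qconfig_dict-style API of this branch (fbgemm backend assumed):

    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx

    def make_model():
        return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # prepare_fx threads False for the new positional is_qat argument ...
    prepared = prepare_fx(make_model().eval(), {"": get_default_qconfig("fbgemm")})

    # ... and prepare_qat_fx threads True, replacing the old is_qat=... keyword.
    prepared_qat = prepare_qat_fx(make_model().train(), {"": get_default_qat_qconfig("fbgemm")})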
@@ -1195,6 +1195,10 @@ class AnnotatedConvBnReLUModel(torch.nn.Module):
     return x

 def fuse_model(self):
+    # TODO: remove this check and define two fuse_modules function on this module
+    if self.training:
+        torch.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
+    else:
         torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

 class TwoLayerConvModel(torch.nn.Module):

@@ -1464,7 +1468,11 @@ class InnerModule(torch.nn.Module):
     if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
         fusable_layers.append([current_name,
                                named_children[idx + 1][0]])
-    torch.quantization.fuse_modules(self, fusable_layers, inplace=True)
+    # TODO: remove this check and define two fuse_modules function on this module
+    if self.training:
+        torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
+    else:
+        torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)

 class FunctionalLinear(torch.nn.Module):
     def __init__(self):

@@ -1955,7 +1963,11 @@ class ResNetBase(torch.nn.Module):
     return out

 def fuse_model(self):
-    torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
+    # TODO: remove this check and define two fuse_model function on this module
+    if self.training:
+        torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
+    else:
+        torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)

 class ModelMultipleOps(torch.nn.Module):
     def __init__(self):