mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D33178977: [bc-breaking][quant][be] Refactor fuser_method to include is_qat argument
Test Plan: revert-hammer Differential Revision: D33178977 (ef501e8fed) Original commit changeset: 0c1499c45526 Original Phabricator Diff: D33178977 (ef501e8fed) fbshipit-source-id: f3912e210e8c588fdbdc9c3c5f4acf2aa8fe6678 (cherry picked from commitcd62183414)
This commit is contained in:
parent
bf69a61293
commit
56511f859a
|
|
@ -820,7 +820,7 @@ class TestDistributed(QuantizationTestCase):
|
|||
torch.ao.quantization.DeQuantStub(),
|
||||
)
|
||||
|
||||
torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
|
||||
|
||||
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
|
||||
torch.ao.quantization.prepare_qat(model, inplace=True)
|
||||
|
|
@ -861,7 +861,7 @@ class TestDistributed(QuantizationTestCase):
|
|||
|
||||
model = Model()
|
||||
# fuse it
|
||||
fused_model = torch.ao.quantization.fuse_modules_qat(
|
||||
fused_model = torch.ao.quantization.fuse_modules(
|
||||
model,
|
||||
[['conv', 'bn']],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from torch.ao.quantization import (
|
|||
prepare_qat,
|
||||
quantize_qat,
|
||||
fuse_modules,
|
||||
fuse_modules_qat,
|
||||
QConfig,
|
||||
default_qconfig,
|
||||
default_qat_qconfig,
|
||||
|
|
@ -44,8 +43,8 @@ class TestFuseEager(QuantizationTestCase):
|
|||
def test_fuse_module_train(self):
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
# Test step by step fusion
|
||||
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
|
||||
model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
|
||||
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer")
|
||||
self.assertEqual(type(model.bn1), torch.nn.Identity,
|
||||
|
|
@ -92,10 +91,8 @@ class TestFuseEager(QuantizationTestCase):
|
|||
checkQuantized(model)
|
||||
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
model = fuse_modules_qat(
|
||||
model,
|
||||
[['conv1', 'bn1', 'relu1'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
|
||||
checkQuantized(model)
|
||||
|
|
@ -104,13 +101,11 @@ class TestFuseEager(QuantizationTestCase):
|
|||
def test_fuse_module_eval(self):
|
||||
model = ModelForFusion(default_qconfig)
|
||||
model.eval()
|
||||
model = fuse_modules(
|
||||
model,
|
||||
[['conv3', 'bn3', 'relu4'],
|
||||
['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
|
||||
['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer (BN is folded)")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
|
|
@ -173,13 +168,11 @@ class TestFuseEager(QuantizationTestCase):
|
|||
checkQuantized(model)
|
||||
|
||||
model = ModelForFusion(default_qconfig).eval()
|
||||
model = fuse_modules(
|
||||
model,
|
||||
[['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn'],
|
||||
['conv3', 'bn3', 'relu4']])
|
||||
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn'],
|
||||
['conv3', 'bn3', 'relu4']])
|
||||
model = quantize(model, test_only_eval_fn, [self.img_data_1d])
|
||||
checkQuantized(model)
|
||||
|
||||
|
|
@ -188,13 +181,11 @@ class TestFuseEager(QuantizationTestCase):
|
|||
with override_quantized_engine(qengine):
|
||||
model = ModelWithSequentialFusion().train()
|
||||
model.to(torch.float)
|
||||
fuse_modules_qat(
|
||||
model, [['conv1', 'relu1'] ,
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']],
|
||||
inplace=True)
|
||||
fuse_modules(model, [['conv1', 'relu1'] ,
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']], inplace=True)
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
|
|
@ -242,14 +233,11 @@ class TestFuseEager(QuantizationTestCase):
|
|||
with override_quantized_engine(qengine):
|
||||
model = ModelWithSequentialFusion().eval()
|
||||
model.to(torch.float)
|
||||
fuse_modules(
|
||||
model,
|
||||
[['conv1', 'relu1'],
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']],
|
||||
inplace=True)
|
||||
fuse_modules(model, [['conv1', 'relu1'] ,
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']], inplace=True)
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
|
|
@ -298,10 +286,8 @@ class TestFuseEager(QuantizationTestCase):
|
|||
# fused model
|
||||
model_orig.qconfig = QConfig(activation=torch.nn.Identity,
|
||||
weight=torch.nn.Identity)
|
||||
model = fuse_modules_qat(
|
||||
model_orig,
|
||||
[["conv1", "bn1", "relu1"],
|
||||
["conv2", "bn2"]])
|
||||
model = fuse_modules(model_orig, [["conv1", "bn1", "relu1"],
|
||||
["conv2", "bn2"]])
|
||||
prep_model = prepare_qat(model, inplace=False)
|
||||
# output with fusion but no observers.
|
||||
out_fused = prep_model(self.img_data_2d[0][0])
|
||||
|
|
@ -399,8 +385,8 @@ class TestFuseEager(QuantizationTestCase):
|
|||
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
|
||||
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
|
||||
|
||||
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
|
||||
model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
|
||||
|
||||
fused = True
|
||||
before_fusion_pre_count = counter['pre_forwards']
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class TestModelNumericsEager(QuantizationTestCase):
|
|||
fq_model = torch.ao.quantization.QuantWrapper(my_model)
|
||||
fq_model.train()
|
||||
fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
|
||||
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.prepare_qat(fq_model)
|
||||
fq_model.eval()
|
||||
fq_model.apply(torch.ao.quantization.disable_fake_quant)
|
||||
|
|
@ -105,7 +105,7 @@ class TestModelNumericsEager(QuantizationTestCase):
|
|||
fq_model = torch.ao.quantization.QuantWrapper(my_model)
|
||||
fq_model.train()
|
||||
fq_model.qconfig = qconfig
|
||||
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.prepare_qat(fq_model)
|
||||
fq_model.eval()
|
||||
fq_model.apply(torch.ao.quantization.disable_fake_quant)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ from torch.ao.quantization import (
|
|||
get_default_qat_qconfig,
|
||||
get_default_qconfig_dict,
|
||||
fuse_modules,
|
||||
fuse_modules_qat,
|
||||
prepare,
|
||||
prepare_qat,
|
||||
convert,
|
||||
|
|
@ -364,8 +363,6 @@ class TestFuseFx(QuantizationTestCase):
|
|||
|
||||
@skipIfNoFBGEMM
|
||||
def test_qconfig_fused_module(self):
|
||||
""" TODO: add test for all fused modules
|
||||
"""
|
||||
qconfig_dict = {
|
||||
"": None,
|
||||
"object_type": [(nn.Linear, default_qconfig),
|
||||
|
|
@ -893,20 +890,15 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
m_eager.eval()
|
||||
qconfig = get_default_qconfig(qengine)
|
||||
prepare_fn = prepare
|
||||
is_qat = False
|
||||
else:
|
||||
m_eager.train()
|
||||
qconfig = get_default_qat_qconfig(qengine)
|
||||
prepare_fn = prepare_qat
|
||||
is_qat = True
|
||||
|
||||
fuse_list = ["conv", "bn"]
|
||||
if has_relu:
|
||||
fuse_list.append("relu")
|
||||
if is_qat:
|
||||
fuse_modules_qat(m_eager, fuse_list, inplace=True)
|
||||
else:
|
||||
fuse_modules(m_eager, fuse_list, inplace=True)
|
||||
fuse_modules(m_eager, fuse_list, inplace=True)
|
||||
m_eager.qconfig = qconfig
|
||||
m_eager = prepare_fn(m_eager)
|
||||
m_eager(*self.img_data_dict[dim][0])
|
||||
|
|
@ -5837,7 +5829,6 @@ class TestQuantizeFxModels(QuantizationTestCase):
|
|||
graph.eval()
|
||||
calibrate_or_train = test_only_eval_fn
|
||||
data = self.img_data_2d
|
||||
is_qat = False
|
||||
else:
|
||||
assert quant_type == QuantType.QAT
|
||||
qconfig = default_qat_qconfig
|
||||
|
|
@ -5847,7 +5838,6 @@ class TestQuantizeFxModels(QuantizationTestCase):
|
|||
graph.train()
|
||||
calibrate_or_train = test_only_train_fn
|
||||
data = self.img_data_2d_train
|
||||
is_qat = True
|
||||
|
||||
if hasattr(eager, "fuse_model"):
|
||||
eager.fuse_model()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from .fake_quantize import * # noqa: F403
|
||||
from .fuse_modules import fuse_modules # noqa: F403
|
||||
from .fuse_modules import fuse_modules_qat # noqa: F403
|
||||
from .fuser_method_mappings import * # noqa: F403
|
||||
from .observer import * # noqa: F403
|
||||
from .qconfig import * # noqa: F403
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):
|
|||
|
||||
setattr(cur_mod, tokens[-1], module)
|
||||
|
||||
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
|
||||
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
|
||||
r"""Returns a list of modules that fuses the operations specified
|
||||
in the input module list.
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
|
|||
if fuser_method is None:
|
||||
raise NotImplementedError("Cannot fuse modules: {}".format(types))
|
||||
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
|
||||
fused = fuser_method(is_qat, *mod_list)
|
||||
fused = fuser_method(*mod_list)
|
||||
# NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
|
||||
# Move pre forward hooks of the base module to resulting fused module
|
||||
for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
|
||||
|
|
@ -65,7 +65,7 @@ def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
|
|||
|
||||
return new_mod
|
||||
|
||||
def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
if fuse_custom_config_dict is None:
|
||||
fuse_custom_config_dict = {}
|
||||
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
|
||||
|
|
@ -74,25 +74,12 @@ def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_m
|
|||
mod_list.append(_get_module(model, item))
|
||||
|
||||
# Fuse list of modules
|
||||
new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
|
||||
new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
|
||||
|
||||
# Replace original module list with fused module list
|
||||
for i, item in enumerate(modules_to_fuse):
|
||||
_set_module(model, item, new_mod_list[i])
|
||||
|
||||
def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
|
||||
# Handle case of modules_to_fuse being a list
|
||||
_fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
|
||||
else:
|
||||
# Handle case of modules_to_fuse being a list of lists
|
||||
for module_list in modules_to_fuse:
|
||||
_fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
|
||||
return model
|
||||
|
||||
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
r"""Fuses a list of modules into a single module
|
||||
|
||||
|
|
@ -134,34 +121,27 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
|
|||
|
||||
Examples::
|
||||
|
||||
>>> m = M().eval()
|
||||
>>> # m is a module containing the sub-modules below
|
||||
>>> m = myModel()
|
||||
>>> # m is a module containing the sub-modules below
|
||||
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
|
||||
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
|
||||
>>> output = fused_m(input)
|
||||
|
||||
>>> m = M().eval()
|
||||
>>> m = myModel()
|
||||
>>> # Alternately provide a single list of modules to fuse
|
||||
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
|
||||
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
|
||||
>>> output = fused_m(input)
|
||||
|
||||
"""
|
||||
return _fuse_modules(
|
||||
model,
|
||||
modules_to_fuse,
|
||||
is_qat=False,
|
||||
inplace=inplace,
|
||||
fuser_func=fuse_known_modules,
|
||||
fuse_custom_config_dict=None)
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
""" QAT version for `fuse_modules`
|
||||
"""
|
||||
return _fuse_modules(
|
||||
model,
|
||||
modules_to_fuse,
|
||||
is_qat=True,
|
||||
inplace=inplace,
|
||||
fuser_func=fuse_known_modules,
|
||||
fuse_custom_config_dict=None)
|
||||
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
|
||||
# Handle case of modules_to_fuse being a list
|
||||
_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
|
||||
else:
|
||||
# Handle case of modules_to_fuse being a list of lists
|
||||
for module_list in modules_to_fuse:
|
||||
_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -7,12 +7,10 @@ from torch.ao.quantization.utils import Pattern
|
|||
from torch.ao.quantization.utils import get_combined_dict
|
||||
|
||||
|
||||
def fuse_conv_bn(is_qat, conv, bn):
|
||||
def fuse_conv_bn(conv, bn):
|
||||
r"""Given the conv and bn modules, fuses them and returns the fused module
|
||||
|
||||
Args:
|
||||
is_qat: a flag for whether we are using quantization aware training fusion
|
||||
or post training quantization fusion
|
||||
conv: Module instance of type conv2d/conv3d
|
||||
bn: Spatial BN instance that needs to be fused with the conv
|
||||
|
||||
|
|
@ -31,9 +29,7 @@ def fuse_conv_bn(is_qat, conv, bn):
|
|||
nn.Conv3d: nni.ConvBn3d,
|
||||
}
|
||||
|
||||
if is_qat:
|
||||
# TODO: remove the assert later
|
||||
assert conv.training, "qat is only supported when conv.training is True currently"
|
||||
if conv.training:
|
||||
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
|
||||
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
|
||||
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
|
||||
|
|
@ -45,12 +41,10 @@ def fuse_conv_bn(is_qat, conv, bn):
|
|||
else:
|
||||
return nn.utils.fuse_conv_bn_eval(conv, bn)
|
||||
|
||||
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
|
||||
def fuse_conv_bn_relu(conv, bn, relu):
|
||||
r"""Given the conv and bn modules, fuses them and returns the fused module
|
||||
|
||||
Args:
|
||||
is_qat: a flag for whether we are using quantization aware training fusion
|
||||
or post training quantization fusion
|
||||
conv: Module instance of type conv2d/conv3d
|
||||
bn: Spatial BN instance that needs to be fused with the conv
|
||||
|
||||
|
|
@ -64,9 +58,7 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
|
|||
assert(conv.training == bn.training == relu.training),\
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
fused_module : Optional[Type[nn.Sequential]] = None
|
||||
if is_qat:
|
||||
# TODO: remove the assert later
|
||||
assert conv.training, "qat is only supported when conv.training is True currently"
|
||||
if conv.training:
|
||||
map_to_fused_module_train = {
|
||||
nn.Conv1d: nni.ConvBnReLU1d,
|
||||
nn.Conv2d: nni.ConvBnReLU2d,
|
||||
|
|
@ -93,12 +85,10 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu):
|
|||
else:
|
||||
raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))
|
||||
|
||||
def fuse_linear_bn(is_qat, linear, bn):
|
||||
def fuse_linear_bn(linear, bn):
|
||||
r"""Given the linear and bn modules, fuses them and returns the fused module
|
||||
|
||||
Args:
|
||||
is_qat: a flag for whether we are using quantization aware training fusion
|
||||
or post training quantization fusion
|
||||
linear: Module instance of type Linear
|
||||
bn: BatchNorm1d instance that needs to be fused with the linear layer
|
||||
|
||||
|
|
@ -111,14 +101,13 @@ def fuse_linear_bn(is_qat, linear, bn):
|
|||
assert(linear.training == bn.training),\
|
||||
"Linear and BN both must be in the same mode (train or eval)."
|
||||
|
||||
if is_qat:
|
||||
# TODO: remove the assert later
|
||||
assert linear.training, "qat is only supported when linear.training is True currently"
|
||||
if linear.training:
|
||||
raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
|
||||
else:
|
||||
return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
|
||||
|
||||
def fuse_convtranspose_bn(is_qat, convt, bn):
|
||||
|
||||
def fuse_convtranspose_bn(convt, bn):
|
||||
r"""Given ConvTranspose and bn modules, fuses them and returns the fused module
|
||||
|
||||
Args:
|
||||
|
|
@ -135,20 +124,11 @@ def fuse_convtranspose_bn(is_qat, convt, bn):
|
|||
assert(convt.training == bn.training),\
|
||||
"ConvTranspose and BN both must be in the same mode (train or eval)."
|
||||
|
||||
if is_qat:
|
||||
assert convt.training, "qat is only supported when convt.training is True currently"
|
||||
if convt.training:
|
||||
raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in training.")
|
||||
else:
|
||||
return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
|
||||
|
||||
def sequential_wrapper2(sequential):
|
||||
""" Given a sequential class for two modules, return a function that takes
|
||||
is_qat, and then two modules as argument, that ignores the is_qat flag
|
||||
and always returns the sequential that combines the two input modules
|
||||
"""
|
||||
def fuser_method(is_qat, m1, m2):
|
||||
return sequential(m1, m2)
|
||||
return fuser_method
|
||||
|
||||
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
|
||||
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
|
||||
|
|
@ -157,13 +137,13 @@ DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
|
|||
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
|
||||
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
|
||||
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
|
||||
(nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
|
||||
(nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
|
||||
(nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
|
||||
(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
|
||||
(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
|
||||
(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
|
||||
(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
|
||||
(nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
|
||||
(nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
|
||||
(nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
|
||||
(nn.Linear, nn.ReLU): nni.LinearReLU,
|
||||
(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
|
||||
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
|
||||
(nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
|
||||
(nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
|
||||
(nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
|
||||
|
|
@ -181,25 +161,13 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
|
|||
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
|
||||
return fuser_method
|
||||
|
||||
def reverse_sequential_wrapper2(sequential):
|
||||
""" Given a sequential class for two modules, return a function that takes
|
||||
is_qat, and then two modules as argument, that ignores the is_qat flag
|
||||
and always returns the sequential that combines the two input modules, with
|
||||
the order of two inputs reversed
|
||||
"""
|
||||
def fuser_method(is_qat, m1, m2):
|
||||
return sequential(m2, m1)
|
||||
return fuser_method
|
||||
|
||||
def reverse2(f):
|
||||
def reversed(is_qat, x, y):
|
||||
return f(is_qat, y, x)
|
||||
return reversed
|
||||
return lambda x, y: f(y, x)
|
||||
|
||||
def reverse3(f):
|
||||
def reversed(is_qat, x, w):
|
||||
def reversed(x, w):
|
||||
y, z = w
|
||||
return f(is_qat, z, y, x)
|
||||
return f(z, y, x)
|
||||
return reversed
|
||||
|
||||
DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
|
||||
|
|
@ -209,13 +177,13 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
|
|||
(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
|
||||
(nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
|
||||
(nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
|
||||
(nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
|
||||
(nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
|
||||
(nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
|
||||
(nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
|
||||
(nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
|
||||
(nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
|
||||
(nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
|
||||
(nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
|
||||
(nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
|
||||
(nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
|
||||
(nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
|
||||
(nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
|
||||
(nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
|
||||
(nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
|
||||
(nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
|
||||
(nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ class Fuser:
|
|||
def fuse(
|
||||
self,
|
||||
model: GraphModule,
|
||||
is_qat: bool,
|
||||
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> GraphModule:
|
||||
|
|
@ -73,7 +72,7 @@ class Fuser:
|
|||
root_node = get_root_node(matched_node_pattern) # type: ignore[index]
|
||||
env[node.name] = obj.fuse(
|
||||
self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
|
||||
fuse_custom_config_dict, fuser_method_mapping, is_qat)
|
||||
fuse_custom_config_dict, fuser_method_mapping)
|
||||
elif maybe_last_node is None:
|
||||
env[node.name] = self.fused_graph.node_copy(node, load_arg)
|
||||
# node matched in patterns and is not root is removed here
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ class FuseHandler(ABC):
|
|||
root_node: Node,
|
||||
matched_node_pattern: NodePattern,
|
||||
fuse_custom_config_dict: Dict[str, Any],
|
||||
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
|
||||
is_qat: bool) -> Node:
|
||||
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
|
||||
pass
|
||||
|
||||
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
|
||||
|
|
@ -70,8 +69,7 @@ class DefaultFuseHandler(FuseHandler):
|
|||
root_node: Node,
|
||||
matched_node_pattern: NodePattern,
|
||||
fuse_custom_config_dict: Dict[str, Any],
|
||||
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
|
||||
is_qat: bool) -> Node:
|
||||
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
|
||||
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
|
||||
assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
|
||||
root_module = quantizer.modules[root_node.target]
|
||||
|
|
@ -115,7 +113,7 @@ class DefaultFuseHandler(FuseHandler):
|
|||
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
|
||||
# TODO: change the signature for fuser_method to take matched module patterns
|
||||
# as input
|
||||
fused_module = fuser_method(is_qat, *matched_modules)
|
||||
fused_module = fuser_method(*matched_modules)
|
||||
# TODO: maybe add a pass to cleanup bn modules?
|
||||
setattr(quantizer.modules[module_parent_name], module_name, fused_module)
|
||||
return quantizer.fused_graph.node_copy(root_node, load_arg)
|
||||
|
|
|
|||
|
|
@ -1198,7 +1198,6 @@ def insert_observers_for_model(
|
|||
|
||||
def run_prepare_fx_on_standalone_modules(
|
||||
model: torch.nn.Module,
|
||||
is_qat: bool,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
matches: Any,
|
||||
prepare_custom_config_dict: Dict[str, Any],
|
||||
|
|
@ -1229,7 +1228,6 @@ def run_prepare_fx_on_standalone_modules(
|
|||
prepare(
|
||||
standalone_module,
|
||||
sm_qconfig_dict,
|
||||
is_qat,
|
||||
sm_prepare_config_dict,
|
||||
backend_config_dict=sm_backend_config_dict)
|
||||
preserved_attributes = \
|
||||
|
|
@ -1266,12 +1264,12 @@ def save_state(
|
|||
def prepare(
|
||||
model: GraphModule,
|
||||
qconfig_dict: Any,
|
||||
is_qat: bool,
|
||||
node_name_to_scope: Dict[str, Tuple[str, type]],
|
||||
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
|
||||
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
is_standalone_module: bool = False) -> ObservedGraphModule:
|
||||
is_standalone_module: bool = False,
|
||||
is_qat: bool = False) -> ObservedGraphModule:
|
||||
""" standalone_module means it a submodule that is not inlined in
|
||||
parent module, and will be quantized separately as one unit.
|
||||
|
||||
|
|
@ -1390,7 +1388,7 @@ def prepare(
|
|||
"output_quantized_idxs", [])
|
||||
|
||||
run_prepare_fx_on_standalone_modules(
|
||||
model, is_qat, modules, matches, prepare_custom_config_dict, backend_config_dict)
|
||||
model, modules, matches, prepare_custom_config_dict, backend_config_dict)
|
||||
|
||||
# record names for the set of observed node, so that in convert step
|
||||
# we know whether we need to convert a floating point module to reference
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ from torch.fx import (
|
|||
from torch.fx.graph import (
|
||||
Graph,
|
||||
)
|
||||
from torch.nn.intrinsic import _FusedModule
|
||||
|
||||
from ..utils import _parent_name
|
||||
from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
|
||||
from ..qconfig_dict_utils import (
|
||||
get_object_type_qconfig,
|
||||
maybe_adjust_qconfig_for_module_type_or_name,
|
||||
|
|
@ -56,29 +56,23 @@ def update_qconfig_for_fusion(
|
|||
|
||||
for node in model.graph.nodes:
|
||||
if node.op == 'call_module' and node.target in modules:
|
||||
maybe_fused_module = modules[str(node.target)]
|
||||
if not isinstance(maybe_fused_module, _FusedModule):
|
||||
module_type = type(modules[str(node.target)])
|
||||
if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
|
||||
continue
|
||||
|
||||
ops = list(maybe_fused_module._modules.values())
|
||||
fused_qconfig = object_type_dict.get(type(ops[0]), None)
|
||||
for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
|
||||
if module_type == fuser:
|
||||
fused_qconfig = object_type_dict.get(ops[0], None)
|
||||
|
||||
# Raise an error if the modules in the fused module have
|
||||
# different qconfigs specified in the qconfig_dict
|
||||
# TODO: currently it only works for modules,
|
||||
# need to make this work for torch.nn.functional.relu
|
||||
# TODO: currently it only works for object_type configurations,
|
||||
# ideally it should work for different types of configurations,
|
||||
# maybe we want to redesign this part
|
||||
for op in ops[1:]:
|
||||
if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
|
||||
raise LookupError(
|
||||
"During fusion, we need to specify the same " +
|
||||
f"qconfigs for all module types in {type(maybe_fused_module)} " +
|
||||
f"offending type: {type(op)}")
|
||||
# Raise an error if the modules in the fused module have
|
||||
# different qconfigs specified in the qconfig_dict
|
||||
for op in ops:
|
||||
if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
|
||||
raise LookupError("During fusion, we need to specify the same " +
|
||||
f"qconfigs for both modules in {module_type}.")
|
||||
|
||||
if fused_qconfig is not None:
|
||||
object_type_dict[type(maybe_fused_module)] = fused_qconfig
|
||||
if fused_qconfig is not None:
|
||||
object_type_dict[module_type] = fused_qconfig
|
||||
|
||||
return qconfig_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
|
|||
|
||||
def _fuse_fx(
|
||||
graph_module: GraphModule,
|
||||
is_qat: bool,
|
||||
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> GraphModule:
|
||||
|
|
@ -58,8 +57,7 @@ def _fuse_fx(
|
|||
"""
|
||||
_check_is_graph_module(graph_module)
|
||||
fuser = Fuser()
|
||||
return fuser.fuse(
|
||||
graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
|
||||
return fuser.fuse(graph_module, fuse_custom_config_dict, backend_config_dict)
|
||||
|
||||
|
||||
class Scope(object):
|
||||
|
|
@ -177,11 +175,11 @@ class QuantizationTracer(Tracer):
|
|||
def _prepare_fx(
|
||||
model: torch.nn.Module,
|
||||
qconfig_dict: Any,
|
||||
is_qat: bool,
|
||||
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
|
||||
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
is_standalone_module: bool = False,
|
||||
is_qat: bool = False,
|
||||
) -> ObservedGraphModule:
|
||||
r""" Internal helper function for prepare_fx
|
||||
Args:
|
||||
|
|
@ -237,20 +235,16 @@ forward graph of the parent module,
|
|||
graph_module = GraphModule(model, tracer.trace(model))
|
||||
for attr_name in preserved_attributes:
|
||||
setattr(graph_module, attr_name, getattr(model, attr_name))
|
||||
graph_module = _fuse_fx(
|
||||
graph_module,
|
||||
is_qat,
|
||||
prepare_custom_config_dict,
|
||||
backend_config_dict)
|
||||
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict, backend_config_dict)
|
||||
prepared = prepare(
|
||||
graph_module,
|
||||
qconfig_dict,
|
||||
is_qat,
|
||||
tracer.node_name_to_scope,
|
||||
prepare_custom_config_dict=prepare_custom_config_dict,
|
||||
equalization_qconfig_dict=equalization_qconfig_dict,
|
||||
backend_config_dict=backend_config_dict,
|
||||
is_standalone_module=is_standalone_module,
|
||||
is_qat=is_qat,
|
||||
)
|
||||
|
||||
for attr_name in preserved_attributes:
|
||||
|
|
@ -261,9 +255,9 @@ forward graph of the parent module,
|
|||
def _prepare_standalone_module_fx(
|
||||
model: torch.nn.Module,
|
||||
qconfig_dict: Any,
|
||||
is_qat: bool,
|
||||
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
|
||||
backend_config_dict: Optional[Dict[str, Any]] = None,
|
||||
is_qat: bool = False,
|
||||
) -> GraphModule:
|
||||
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
|
||||
parent module.
|
||||
|
|
@ -290,10 +284,10 @@ def _prepare_standalone_module_fx(
|
|||
return _prepare_fx(
|
||||
model,
|
||||
qconfig_dict,
|
||||
is_qat,
|
||||
prepare_custom_config_dict,
|
||||
backend_config_dict=backend_config_dict,
|
||||
is_standalone_module=True,
|
||||
is_qat=is_qat,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -338,7 +332,7 @@ def fuse_fx(
|
|||
)
|
||||
for attr_name in preserved_attributes:
|
||||
setattr(graph_module, attr_name, getattr(model, attr_name))
|
||||
return _fuse_fx(graph_module, False, fuse_custom_config_dict)
|
||||
return _fuse_fx(graph_module, fuse_custom_config_dict)
|
||||
|
||||
|
||||
def prepare_fx(
|
||||
|
|
@ -515,10 +509,10 @@ def prepare_fx(
|
|||
return _prepare_fx(
|
||||
model,
|
||||
qconfig_dict,
|
||||
False, # is_qat
|
||||
prepare_custom_config_dict,
|
||||
equalization_qconfig_dict,
|
||||
backend_config_dict,
|
||||
is_qat=False,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -564,9 +558,9 @@ def prepare_qat_fx(
|
|||
return _prepare_fx(
|
||||
model,
|
||||
qconfig_dict,
|
||||
True, # is_qat
|
||||
prepare_custom_config_dict,
|
||||
backend_config_dict=backend_config_dict,
|
||||
is_qat=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1195,11 +1195,7 @@ class AnnotatedConvBnReLUModel(torch.nn.Module):
|
|||
return x
|
||||
|
||||
def fuse_model(self):
|
||||
# TODO: remove this check and define two fuse_modules function on this module
|
||||
if self.training:
|
||||
torch.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
|
||||
else:
|
||||
torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)
|
||||
torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)
|
||||
|
||||
class TwoLayerConvModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -1468,11 +1464,7 @@ class InnerModule(torch.nn.Module):
|
|||
if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
|
||||
fusable_layers.append([current_name,
|
||||
named_children[idx + 1][0]])
|
||||
# TODO: remove this check and define two fuse_modules function on this module
|
||||
if self.training:
|
||||
torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
|
||||
else:
|
||||
torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
|
||||
torch.quantization.fuse_modules(self, fusable_layers, inplace=True)
|
||||
|
||||
class FunctionalLinear(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -1963,11 +1955,7 @@ class ResNetBase(torch.nn.Module):
|
|||
return out
|
||||
|
||||
def fuse_model(self):
|
||||
# TODO: remove this check and define two fuse_model function on this module
|
||||
if self.training:
|
||||
torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
else:
|
||||
torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
|
||||
class ModelMultipleOps(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user