Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
[reland][bc-breaking][quant][be] Refactor fuser_method to include is_qat argument (#71956)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71956
Pull Request resolved: https://github.com/facebookresearch/mobile-vision/pull/59
Original commit changeset: f3912e210e8c
Original Phabricator Diff: D33178977 (ef501e8fed)

Test Plan: Please see original diff for test plans

Static Docs Preview: classyvision
Full Site: https://our.intern.facebook.com/intern/staticdocs/eph/D33833203/V3/classyvision/

Reviewed By: andrewor14

Differential Revision: D33833203

fbshipit-source-id: 74a8f22730b00aafa6a173b208e635c1d696959e
(cherry picked from commit fb88772b18)
This commit is contained in:
parent 847dbb8684
commit 082ff25f37
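The user-facing half of this change is exercised throughout the test updates below: eager-mode fusion for quantization aware training now goes through the new torch.ao.quantization.fuse_modules_qat, while fuse_modules is kept for post-training (eval-mode) fusion, instead of the behavior being inferred from the modules' training flag. A minimal sketch of the two entry points (the M module below is a made-up example, not part of the patch):

    import torch.nn as nn
    import torch.ao.quantization as tq

    # Toy Conv-BN-ReLU stack, for illustration only.
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)
            self.bn = nn.BatchNorm2d(3)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    # Post-training quantization: fuse an eval-mode model with fuse_modules.
    m_ptq = tq.fuse_modules(M().eval(), [['conv', 'bn', 'relu']])

    # Quantization-aware training: fuse a train-mode model with the new fuse_modules_qat.
    m_qat = tq.fuse_modules_qat(M().train(), [['conv', 'bn', 'relu']])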
@@ -820,7 +820,7 @@ class TestDistributed(QuantizationTestCase):
 torch.ao.quantization.DeQuantStub(),
 )

-torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)

 model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
 torch.ao.quantization.prepare_qat(model, inplace=True)

@@ -861,7 +861,7 @@ class TestDistributed(QuantizationTestCase):

 model = Model()
 # fuse it
-fused_model = torch.ao.quantization.fuse_modules(
+fused_model = torch.ao.quantization.fuse_modules_qat(
 model,
 [['conv', 'bn']],
 )

@@ -15,6 +15,7 @@ from torch.ao.quantization import (
 prepare_qat,
 quantize_qat,
 fuse_modules,
+fuse_modules_qat,
 QConfig,
 default_qconfig,
 default_qat_qconfig,

@@ -43,8 +44,8 @@ class TestFuseEager(QuantizationTestCase):
 def test_fuse_module_train(self):
 model = ModelForFusion(default_qat_qconfig).train()
 # Test step by step fusion
-model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
-model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
+model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
 self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
 msg="Fused Conv + BN + Relu first layer")
 self.assertEqual(type(model.bn1), torch.nn.Identity,

@@ -91,7 +92,9 @@ class TestFuseEager(QuantizationTestCase):
 checkQuantized(model)

 model = ModelForFusion(default_qat_qconfig).train()
-model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+model = fuse_modules_qat(
+model,
+[['conv1', 'bn1', 'relu1'],
 ['sub1.conv', 'sub1.bn']])
 model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
 with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):

@@ -101,7 +104,9 @@ class TestFuseEager(QuantizationTestCase):
 def test_fuse_module_eval(self):
 model = ModelForFusion(default_qconfig)
 model.eval()
-model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
+model = fuse_modules(
+model,
+[['conv3', 'bn3', 'relu4'],
 ['conv1', 'bn1', 'relu1'],
 ['conv2', 'relu2'],
 ['bn2', 'relu3'],

@@ -168,7 +173,9 @@ class TestFuseEager(QuantizationTestCase):
 checkQuantized(model)

 model = ModelForFusion(default_qconfig).eval()
-model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+model = fuse_modules(
+model,
+[['conv1', 'bn1', 'relu1'],
 ['conv2', 'relu2'],
 ['bn2', 'relu3'],
 ['sub1.conv', 'sub1.bn'],

@@ -181,11 +188,13 @@ class TestFuseEager(QuantizationTestCase):
 with override_quantized_engine(qengine):
 model = ModelWithSequentialFusion().train()
 model.to(torch.float)
-fuse_modules(model, [['conv1', 'relu1'] ,
+fuse_modules_qat(
+model, [['conv1', 'relu1'] ,
 ['features.0.0', 'features.0.1', 'features.0.2'],
 ['features.1.0', 'features.1.1', 'features.1.2'],
 ['features.2.0', 'features.2.1', 'features.2.2'],
-['classifier.0', 'classifier.1']], inplace=True)
+['classifier.0', 'classifier.1']],
+inplace=True)
 self.assertEqual(type(model.conv1), nni.ConvReLU2d,
 msg="Fused Conv + Relu: nni.ConvReLU2d")
 self.assertEqual(type(model.conv1[0]), nn.Conv2d,

@@ -233,11 +242,14 @@ class TestFuseEager(QuantizationTestCase):
 with override_quantized_engine(qengine):
 model = ModelWithSequentialFusion().eval()
 model.to(torch.float)
-fuse_modules(model, [['conv1', 'relu1'] ,
+fuse_modules(
+model,
+[['conv1', 'relu1'],
 ['features.0.0', 'features.0.1', 'features.0.2'],
 ['features.1.0', 'features.1.1', 'features.1.2'],
 ['features.2.0', 'features.2.1', 'features.2.2'],
-['classifier.0', 'classifier.1']], inplace=True)
+['classifier.0', 'classifier.1']],
+inplace=True)
 self.assertEqual(type(model.conv1), nni.ConvReLU2d,
 msg="Fused Conv + Relu: nni.ConvReLU2d")
 self.assertEqual(type(model.conv1[0]), nn.Conv2d,

@@ -286,7 +298,9 @@ class TestFuseEager(QuantizationTestCase):
 # fused model
 model_orig.qconfig = QConfig(activation=torch.nn.Identity,
 weight=torch.nn.Identity)
-model = fuse_modules(model_orig, [["conv1", "bn1", "relu1"],
+model = fuse_modules_qat(
+model_orig,
+[["conv1", "bn1", "relu1"],
 ["conv2", "bn2"]])
 prep_model = prepare_qat(model, inplace=False)
 # output with fusion but no observers.

@@ -385,8 +399,8 @@ class TestFuseEager(QuantizationTestCase):
 self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
 self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

-model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
-model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
+model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
+model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

 fused = True
 before_fusion_pre_count = counter['pre_forwards']

@@ -69,7 +69,7 @@ class TestModelNumericsEager(QuantizationTestCase):
 fq_model = torch.ao.quantization.QuantWrapper(my_model)
 fq_model.train()
 fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
-torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
 torch.ao.quantization.prepare_qat(fq_model)
 fq_model.eval()
 fq_model.apply(torch.ao.quantization.disable_fake_quant)

@@ -105,7 +105,7 @@ class TestModelNumericsEager(QuantizationTestCase):
 fq_model = torch.ao.quantization.QuantWrapper(my_model)
 fq_model.train()
 fq_model.qconfig = qconfig
-torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
 torch.ao.quantization.prepare_qat(fq_model)
 fq_model.eval()
 fq_model.apply(torch.ao.quantization.disable_fake_quant)

@@ -48,6 +48,7 @@ from torch.ao.quantization import (
 get_default_qat_qconfig,
 get_default_qconfig_dict,
 fuse_modules,
+fuse_modules_qat,
 prepare,
 prepare_qat,
 convert,

@@ -363,6 +364,8 @@ class TestFuseFx(QuantizationTestCase):

 @skipIfNoFBGEMM
 def test_qconfig_fused_module(self):
+""" TODO: add test for all fused modules
+"""
 qconfig_dict = {
 "": None,
 "object_type": [(nn.Linear, default_qconfig),

@@ -890,14 +893,19 @@ class TestQuantizeFx(QuantizationTestCase):
 m_eager.eval()
 qconfig = get_default_qconfig(qengine)
 prepare_fn = prepare
+is_qat = False
 else:
 m_eager.train()
 qconfig = get_default_qat_qconfig(qengine)
 prepare_fn = prepare_qat
+is_qat = True

 fuse_list = ["conv", "bn"]
 if has_relu:
 fuse_list.append("relu")
+if is_qat:
+fuse_modules_qat(m_eager, fuse_list, inplace=True)
+else:
 fuse_modules(m_eager, fuse_list, inplace=True)
 m_eager.qconfig = qconfig
 m_eager = prepare_fn(m_eager)

@@ -5847,6 +5855,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
 graph.eval()
 calibrate_or_train = test_only_eval_fn
 data = self.img_data_2d
+is_qat = False
 else:
 assert quant_type == QuantType.QAT
 qconfig = default_qat_qconfig

@@ -5856,6 +5865,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
 graph.train()
 calibrate_or_train = test_only_train_fn
 data = self.img_data_2d_train
+is_qat = True

 if hasattr(eager, "fuse_model"):
 eager.fuse_model()

@@ -2,6 +2,7 @@

 from .fake_quantize import * # noqa: F403
 from .fuse_modules import fuse_modules # noqa: F403
+from .fuse_modules import fuse_modules_qat # noqa: F403
 from .fuser_method_mappings import * # noqa: F403
 from .observer import * # noqa: F403
 from .qconfig import * # noqa: F403

@@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):

 setattr(cur_mod, tokens[-1], module)

-def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
+def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
 r"""Returns a list of modules that fuses the operations specified
 in the input module list.

@@ -46,7 +46,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
 if fuser_method is None:
 raise NotImplementedError("Cannot fuse modules: {}".format(types))
 new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
-fused = fuser_method(*mod_list)
+fused = fuser_method(is_qat, *mod_list)
 # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
 # Move pre forward hooks of the base module to resulting fused module
 for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():

@@ -65,7 +65,7 @@ def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):

 return new_mod

-def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
 if fuse_custom_config_dict is None:
 fuse_custom_config_dict = {}
 additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})

@@ -74,12 +74,25 @@ def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_cu
 mod_list.append(_get_module(model, item))

 # Fuse list of modules
-new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
+new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

 # Replace original module list with fused module list
 for i, item in enumerate(modules_to_fuse):
 _set_module(model, item, new_mod_list[i])

+def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+if not inplace:
+model = copy.deepcopy(model)
+
+if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+# Handle case of modules_to_fuse being a list
+_fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
+else:
+# Handle case of modules_to_fuse being a list of lists
+for module_list in modules_to_fuse:
+_fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
+return model
+
 def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
 r"""Fuses a list of modules into a single module

@@ -121,27 +134,34 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo

 Examples::

->>> m = myModel()
+>>> m = M().eval()
 >>> # m is a module containing the sub-modules below
 >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
 >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
 >>> output = fused_m(input)

->>> m = myModel()
+>>> m = M().eval()
 >>> # Alternately provide a single list of modules to fuse
 >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
 >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
 >>> output = fused_m(input)

 """
-if not inplace:
-model = copy.deepcopy(model)
+return _fuse_modules(
+model,
+modules_to_fuse,
+is_qat=False,
+inplace=inplace,
+fuser_func=fuse_known_modules,
+fuse_custom_config_dict=None)

-if all(isinstance(module_element, str) for module_element in modules_to_fuse):
-# Handle case of modules_to_fuse being a list
-_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
-else:
-# Handle case of modules_to_fuse being a list of lists
-for module_list in modules_to_fuse:
-_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
-return model
+def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+""" QAT version for `fuse_modules`
+"""
+return _fuse_modules(
+model,
+modules_to_fuse,
+is_qat=True,
+inplace=inplace,
+fuser_func=fuse_known_modules,
+fuse_custom_config_dict=None)
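The hunks above split the old eager helper into _fuse_modules_helper plus a new _fuse_modules wrapper, and both public entry points now pin is_qat explicitly (False for fuse_modules, True for fuse_modules_qat). The other contract change is that fuse_known_modules takes is_qat as its second argument and invokes the looked-up fuser as fuser_method(is_qat, *mod_list). A hedged sketch of a direct call under the new contract (the module instances below are illustrative, not from the patch):

    import torch.nn as nn
    from torch.ao.quantization.fuse_modules import fuse_known_modules

    # is_qat is now the second argument of fuse_known_modules.
    conv, bn, relu = nn.Conv2d(3, 3, 1).eval(), nn.BatchNorm2d(3).eval(), nn.ReLU().eval()
    new_mods = fuse_known_modules([conv, bn, relu], is_qat=False)
    # new_mods[0] is the fused ConvReLU2d built from the folded conv;
    # the remaining slots are nn.Identity placeholders, as before this change.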
@@ -7,10 +7,12 @@ from torch.ao.quantization.utils import Pattern
 from torch.ao.quantization.utils import get_combined_dict


-def fuse_conv_bn(conv, bn):
+def fuse_conv_bn(is_qat, conv, bn):
 r"""Given the conv and bn modules, fuses them and returns the fused module

 Args:
+is_qat: a flag for whether we are using quantization aware training fusion
+or post training quantization fusion
 conv: Module instance of type conv2d/conv3d
 bn: Spatial BN instance that needs to be fused with the conv

@@ -29,7 +31,9 @@ def fuse_conv_bn(conv, bn):
 nn.Conv3d: nni.ConvBn3d,
 }

-if conv.training:
+if is_qat:
+# TODO: remove the assert later
+assert conv.training, "qat is only supported when conv.training is True currently"
 assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
 assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
 assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'

@@ -41,10 +45,12 @@ def fuse_conv_bn(conv, bn):
 else:
 return nn.utils.fuse_conv_bn_eval(conv, bn)

-def fuse_conv_bn_relu(conv, bn, relu):
+def fuse_conv_bn_relu(is_qat, conv, bn, relu):
 r"""Given the conv and bn modules, fuses them and returns the fused module

 Args:
+is_qat: a flag for whether we are using quantization aware training fusion
+or post training quantization fusion
 conv: Module instance of type conv2d/conv3d
 bn: Spatial BN instance that needs to be fused with the conv

@@ -58,7 +64,9 @@ def fuse_conv_bn_relu(conv, bn, relu):
 assert(conv.training == bn.training == relu.training),\
 "Conv and BN both must be in the same mode (train or eval)."
 fused_module : Optional[Type[nn.Sequential]] = None
-if conv.training:
+if is_qat:
+# TODO: remove the assert later
+assert conv.training, "qat is only supported when conv.training is True currently"
 map_to_fused_module_train = {
 nn.Conv1d: nni.ConvBnReLU1d,
 nn.Conv2d: nni.ConvBnReLU2d,

@@ -85,10 +93,12 @@ def fuse_conv_bn_relu(conv, bn, relu):
 else:
 raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

-def fuse_linear_bn(linear, bn):
+def fuse_linear_bn(is_qat, linear, bn):
 r"""Given the linear and bn modules, fuses them and returns the fused module

 Args:
+is_qat: a flag for whether we are using quantization aware training fusion
+or post training quantization fusion
 linear: Module instance of type Linear
 bn: BatchNorm1d instance that needs to be fused with the linear layer

@@ -101,13 +111,14 @@ def fuse_linear_bn(linear, bn):
 assert(linear.training == bn.training),\
 "Linear and BN both must be in the same mode (train or eval)."

-if linear.training:
+if is_qat:
+# TODO: remove the assert later
+assert linear.training, "qat is only supported when linear.training is True currently"
 raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
 else:
 return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

-def fuse_convtranspose_bn(convt, bn):
+def fuse_convtranspose_bn(is_qat, convt, bn):
 r"""Given ConvTranspose and bn modules, fuses them and returns the fused module

 Args:

@@ -124,11 +135,20 @@ def fuse_convtranspose_bn(convt, bn):
 assert(convt.training == bn.training),\
 "ConvTranspose and BN both must be in the same mode (train or eval)."

-if convt.training:
+if is_qat:
+assert convt.training, "qat is only supported when convt.training is True currently"
 raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in training.")
 else:
 return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)

+def sequential_wrapper2(sequential):
+""" Given a sequential class for two modules, return a function that takes
+is_qat, and then two modules as argument, that ignores the is_qat flag
+and always returns the sequential that combines the two input modules
+"""
+def fuser_method(is_qat, m1, m2):
+return sequential(m1, m2)
+return fuser_method

 DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
 (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,

@@ -137,13 +157,13 @@ DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
 (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
 (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
 (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
-(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
-(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
-(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
+(nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
+(nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
+(nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
 (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
-(nn.Linear, nn.ReLU): nni.LinearReLU,
-(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
-(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
+(nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
+(nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
+(nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
 (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
 (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
 (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
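Because every fuser method now receives is_qat first, mapping entries that used to point directly at an nn.Sequential subclass (nni.ConvReLU2d, nni.LinearReLU, ...) are wrapped with sequential_wrapper2, which accepts and ignores the flag. A small hedged illustration of what a wrapped entry does (the module instances are made up):

    import torch.nn as nn
    import torch.nn.intrinsic as nni
    from torch.ao.quantization.fuser_method_mappings import sequential_wrapper2

    conv, relu = nn.Conv2d(3, 3, 1), nn.ReLU()
    fuser = sequential_wrapper2(nni.ConvReLU2d)   # replaces the bare nni.ConvReLU2d entry
    fused = fuser(False, conv, relu)              # is_qat is accepted but ignored here
    # fused is an nni.ConvReLU2d(conv, relu), the same result the old mapping produced.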
@@ -161,13 +181,25 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
 assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
 return fuser_method

+def reverse_sequential_wrapper2(sequential):
+""" Given a sequential class for two modules, return a function that takes
+is_qat, and then two modules as argument, that ignores the is_qat flag
+and always returns the sequential that combines the two input modules, with
+the order of two inputs reversed
+"""
+def fuser_method(is_qat, m1, m2):
+return sequential(m2, m1)
+return fuser_method
+
 def reverse2(f):
-return lambda x, y: f(y, x)
+def reversed(is_qat, x, y):
+return f(is_qat, y, x)
+return reversed

 def reverse3(f):
-def reversed(x, w):
+def reversed(is_qat, x, w):
 y, z = w
-return f(z, y, x)
+return f(is_qat, z, y, x)
 return reversed

 DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
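The reversed-pattern helpers get the same treatment: reverse2 and reverse3 now forward is_qat to the wrapped fuser, and reverse_sequential_wrapper2 is the (later-op, earlier-op) counterpart of sequential_wrapper2. A brief hedged illustration (instances are made up):

    import torch.nn as nn
    import torch.nn.intrinsic as nni
    from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

    relu, conv = nn.ReLU(), nn.Conv2d(3, 3, 1)
    fuser = reverse_sequential_wrapper2(nni.ConvReLU2d)  # used for the (nn.ReLU, nn.Conv2d) pattern
    fused = fuser(True, relu, conv)  # arguments arrive in pattern order; the wrapper swaps them
    # fused is nni.ConvReLU2d(conv, relu); the is_qat flag is accepted but unused.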
@@ -177,13 +209,13 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
 (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
 (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
 (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
-(nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
-(nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
-(nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
+(nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
+(nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
+(nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
 (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
-(nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
-(nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
-(nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
+(nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
+(nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
+(nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
 (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
 (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
 (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),

@@ -28,6 +28,7 @@ class Fuser:
 def fuse(
 self,
 model: GraphModule,
+is_qat: bool,
 fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:

@@ -72,7 +73,7 @@ class Fuser:
 root_node = get_root_node(matched_node_pattern) # type: ignore[index]
 env[node.name] = obj.fuse(
 self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
-fuse_custom_config_dict, fuser_method_mapping)
+fuse_custom_config_dict, fuser_method_mapping, is_qat)
 elif maybe_last_node is None:
 env[node.name] = self.fused_graph.node_copy(node, load_arg)
 # node matched in patterns and is not root is removed here

@@ -28,7 +28,8 @@ class FuseHandler(ABC):
 root_node: Node,
 matched_node_pattern: NodePattern,
 fuse_custom_config_dict: Dict[str, Any],
-fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
+fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+is_qat: bool) -> Node:
 pass

 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))

@@ -69,7 +70,8 @@ class DefaultFuseHandler(FuseHandler):
 root_node: Node,
 matched_node_pattern: NodePattern,
 fuse_custom_config_dict: Dict[str, Any],
-fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
+fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
+is_qat: bool) -> Node:
 additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
 assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
 root_module = quantizer.modules[root_node.target]

@@ -113,7 +115,7 @@ class DefaultFuseHandler(FuseHandler):
 fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
 # TODO: change the signature for fuser_method to take matched module patterns
 # as input
-fused_module = fuser_method(*matched_modules)
+fused_module = fuser_method(is_qat, *matched_modules)
 # TODO: maybe add a pass to cleanup bn modules?
 setattr(quantizer.modules[module_parent_name], module_name, fused_module)
 return quantizer.fused_graph.node_copy(root_node, load_arg)

@@ -1198,6 +1198,7 @@ def insert_observers_for_model(

 def run_prepare_fx_on_standalone_modules(
 model: torch.nn.Module,
+is_qat: bool,
 modules: Dict[str, torch.nn.Module],
 matches: Any,
 prepare_custom_config_dict: Dict[str, Any],

@@ -1228,6 +1229,7 @@ def run_prepare_fx_on_standalone_modules(
 prepare(
 standalone_module,
 sm_qconfig_dict,
+is_qat,
 sm_prepare_config_dict,
 backend_config_dict=sm_backend_config_dict)
 preserved_attributes = \

@@ -1264,12 +1266,12 @@ def save_state(
 def prepare(
 model: GraphModule,
 qconfig_dict: Any,
+is_qat: bool,
 node_name_to_scope: Dict[str, Tuple[str, type]],
 prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
 equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
-is_standalone_module: bool = False,
-is_qat: bool = False) -> ObservedGraphModule:
+is_standalone_module: bool = False) -> ObservedGraphModule:
 """ standalone_module means it a submodule that is not inlined in
 parent module, and will be quantized separately as one unit.

@@ -1388,7 +1390,7 @@ def prepare(
 "output_quantized_idxs", [])

 run_prepare_fx_on_standalone_modules(
-model, modules, matches, prepare_custom_config_dict, backend_config_dict)
+model, is_qat, modules, matches, prepare_custom_config_dict, backend_config_dict)

 # record names for the set of observed node, so that in convert step
 # we know whether we need to convert a floating point module to reference

@@ -11,9 +11,9 @@ from torch.fx import (
 from torch.fx.graph import (
 Graph,
 )
+from torch.nn.intrinsic import _FusedModule

 from ..utils import _parent_name
-from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
 from ..qconfig_dict_utils import (
 get_object_type_qconfig,
 maybe_adjust_qconfig_for_module_type_or_name,

@@ -56,23 +56,29 @@ def update_qconfig_for_fusion(

 for node in model.graph.nodes:
 if node.op == 'call_module' and node.target in modules:
-module_type = type(modules[str(node.target)])
-if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
+maybe_fused_module = modules[str(node.target)]
+if not isinstance(maybe_fused_module, _FusedModule):
 continue

-for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
-if module_type == fuser:
-fused_qconfig = object_type_dict.get(ops[0], None)
+ops = list(maybe_fused_module._modules.values())
+fused_qconfig = object_type_dict.get(type(ops[0]), None)

 # Raise an error if the modules in the fused module have
 # different qconfigs specified in the qconfig_dict
-for op in ops:
-if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
-raise LookupError("During fusion, we need to specify the same " +
-f"qconfigs for both modules in {module_type}.")
+# TODO: currently it only works for modules,
+# need to make this work for torch.nn.functional.relu
+# TODO: currently it only works for object_type configurations,
+# ideally it should work for different types of configurations,
+# maybe we want to redesign this part
+for op in ops[1:]:
+if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
+raise LookupError(
+"During fusion, we need to specify the same " +
+f"qconfigs for all module types in {type(maybe_fused_module)} " +
+f"offending type: {type(op)}")

 if fused_qconfig is not None:
-object_type_dict[module_type] = fused_qconfig
+object_type_dict[type(maybe_fused_module)] = fused_qconfig

 return qconfig_dict

@@ -47,6 +47,7 @@ def _swap_ff_with_fxff(model: torch.nn.Module) -> None:

 def _fuse_fx(
 graph_module: GraphModule,
+is_qat: bool,
 fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:

@@ -57,7 +58,8 @@ def _fuse_fx(
 """
 _check_is_graph_module(graph_module)
 fuser = Fuser()
-return fuser.fuse(graph_module, fuse_custom_config_dict, backend_config_dict)
+return fuser.fuse(
+graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)


 class Scope(object):

@@ -175,11 +177,11 @@ class QuantizationTracer(Tracer):
 def _prepare_fx(
 model: torch.nn.Module,
 qconfig_dict: Any,
+is_qat: bool,
 prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
 equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
 is_standalone_module: bool = False,
-is_qat: bool = False,
 ) -> ObservedGraphModule:
 r""" Internal helper function for prepare_fx
 Args:

@@ -235,16 +237,20 @@ forward graph of the parent module,
 graph_module = GraphModule(model, tracer.trace(model))
 for attr_name in preserved_attributes:
 setattr(graph_module, attr_name, getattr(model, attr_name))
-graph_module = _fuse_fx(graph_module, prepare_custom_config_dict, backend_config_dict)
+graph_module = _fuse_fx(
+graph_module,
+is_qat,
+prepare_custom_config_dict,
+backend_config_dict)
 prepared = prepare(
 graph_module,
 qconfig_dict,
+is_qat,
 tracer.node_name_to_scope,
 prepare_custom_config_dict=prepare_custom_config_dict,
 equalization_qconfig_dict=equalization_qconfig_dict,
 backend_config_dict=backend_config_dict,
 is_standalone_module=is_standalone_module,
-is_qat=is_qat,
 )

 for attr_name in preserved_attributes:

@@ -255,9 +261,9 @@ forward graph of the parent module,
 def _prepare_standalone_module_fx(
 model: torch.nn.Module,
 qconfig_dict: Any,
+is_qat: bool,
 prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
 backend_config_dict: Optional[Dict[str, Any]] = None,
-is_qat: bool = False,
 ) -> GraphModule:
 r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
 parent module.

@@ -284,10 +290,10 @@ def _prepare_standalone_module_fx(
 return _prepare_fx(
 model,
 qconfig_dict,
+is_qat,
 prepare_custom_config_dict,
 backend_config_dict=backend_config_dict,
 is_standalone_module=True,
-is_qat=is_qat,
 )

@@ -332,7 +338,7 @@ def fuse_fx(
 )
 for attr_name in preserved_attributes:
 setattr(graph_module, attr_name, getattr(model, attr_name))
-return _fuse_fx(graph_module, fuse_custom_config_dict)
+return _fuse_fx(graph_module, False, fuse_custom_config_dict)


 def prepare_fx(

@@ -509,10 +515,10 @@ def prepare_fx(
 return _prepare_fx(
 model,
 qconfig_dict,
+False, # is_qat
 prepare_custom_config_dict,
 equalization_qconfig_dict,
 backend_config_dict,
-is_qat=False,
 )


@@ -558,9 +564,9 @@ def prepare_qat_fx(
 return _prepare_fx(
 model,
 qconfig_dict,
+True, # is_qat
 prepare_custom_config_dict,
 backend_config_dict=backend_config_dict,
-is_qat=True,
 )

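On the FX graph mode side the public entry points keep their signatures: prepare_fx pins is_qat to False and prepare_qat_fx pins it to True, and the flag is threaded through _fuse_fx, Fuser.fuse, and prepare rather than being passed by callers. A minimal usage sketch, assuming the toy M module from the earlier example and an available fbgemm backend:

    from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx

    # Post-training flow: is_qat=False is supplied internally by prepare_fx.
    prepared = prepare_fx(M().eval(), {"": get_default_qconfig("fbgemm")})

    # QAT flow: is_qat=True is supplied internally by prepare_qat_fx.
    prepared_qat = prepare_qat_fx(M().train(), {"": get_default_qat_qconfig("fbgemm")})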
@@ -1195,6 +1195,10 @@ class AnnotatedConvBnReLUModel(torch.nn.Module):
 return x

 def fuse_model(self):
+# TODO: remove this check and define two fuse_modules function on this module
+if self.training:
+torch.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
+else:
 torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

 class TwoLayerConvModel(torch.nn.Module):

@@ -1464,7 +1468,11 @@ class InnerModule(torch.nn.Module):
 if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
 fusable_layers.append([current_name,
 named_children[idx + 1][0]])
-torch.quantization.fuse_modules(self, fusable_layers, inplace=True)
+# TODO: remove this check and define two fuse_modules function on this module
+if self.training:
+torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
+else:
+torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)

 class FunctionalLinear(torch.nn.Module):
 def __init__(self):

@@ -1955,7 +1963,11 @@ class ResNetBase(torch.nn.Module):
 return out

 def fuse_model(self):
-torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
+# TODO: remove this check and define two fuse_model function on this module
+if self.training:
+torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
+else:
+torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)

 class ModelMultipleOps(torch.nn.Module):
 def __init__(self):