diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index 5415e2b03dc..17edc644864 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -820,7 +820,7 @@ class TestDistributed(QuantizationTestCase):
             torch.ao.quantization.DeQuantStub(),
         )

-        torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+        torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        torch.ao.quantization.prepare_qat(model, inplace=True)
@@ -861,7 +861,7 @@ class TestDistributed(QuantizationTestCase):
         model = Model()

         # fuse it
-        fused_model = torch.ao.quantization.fuse_modules_qat(
+        fused_model = torch.ao.quantization.fuse_modules(
             model,
             [['conv', 'bn']],
         )
diff --git a/test/quantization/eager/test_fuse_eager.py b/test/quantization/eager/test_fuse_eager.py
index 9cf09dedba6..003f0ff01d4 100644
--- a/test/quantization/eager/test_fuse_eager.py
+++ b/test/quantization/eager/test_fuse_eager.py
@@ -15,7 +15,6 @@ from torch.ao.quantization import (
     prepare_qat,
     quantize_qat,
     fuse_modules,
-    fuse_modules_qat,
     QConfig,
     default_qconfig,
     default_qat_qconfig,
@@ -44,8 +43,8 @@ class TestFuseEager(QuantizationTestCase):
     def test_fuse_module_train(self):
         model = ModelForFusion(default_qat_qconfig).train()
         # Test step by step fusion
-        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
-        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
+        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
+        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
         self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                          msg="Fused Conv + BN + Relu first layer")
         self.assertEqual(type(model.bn1), torch.nn.Identity,
@@ -92,10 +91,8 @@ class TestFuseEager(QuantizationTestCase):
             checkQuantized(model)

         model = ModelForFusion(default_qat_qconfig).train()
-        model = fuse_modules_qat(
-            model,
-            [['conv1', 'bn1', 'relu1'],
-             ['sub1.conv', 'sub1.bn']])
+        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+                             ['sub1.conv', 'sub1.bn']])
         model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
         with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
             checkQuantized(model)
@@ -104,13 +101,11 @@ class TestFuseEager(QuantizationTestCase):
     def test_fuse_module_eval(self):
         model = ModelForFusion(default_qconfig)
         model.eval()
-        model = fuse_modules(
-            model,
-            [['conv3', 'bn3', 'relu4'],
-             ['conv1', 'bn1', 'relu1'],
-             ['conv2', 'relu2'],
-             ['bn2', 'relu3'],
-             ['sub1.conv', 'sub1.bn']])
+        model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
+                             ['conv1', 'bn1', 'relu1'],
+                             ['conv2', 'relu2'],
+                             ['bn2', 'relu3'],
+                             ['sub1.conv', 'sub1.bn']])
         self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                          msg="Fused Conv + BN + Relu first layer (BN is folded)")
         self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@@ -173,13 +168,11 @@ class TestFuseEager(QuantizationTestCase):
             checkQuantized(model)

         model = ModelForFusion(default_qconfig).eval()
-        model = fuse_modules(
-            model,
-            [['conv1', 'bn1', 'relu1'],
-             ['conv2', 'relu2'],
-             ['bn2', 'relu3'],
-             ['sub1.conv', 'sub1.bn'],
-             ['conv3', 'bn3', 'relu4']])
+        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
+                             ['conv2', 'relu2'],
+                             ['bn2', 'relu3'],
+                             ['sub1.conv', 'sub1.bn'],
+                             ['conv3', 'bn3', 'relu4']])
         model = quantize(model, test_only_eval_fn, [self.img_data_1d])
         checkQuantized(model)

@@ -188,13 +181,11 @@ class TestFuseEager(QuantizationTestCase):
             with override_quantized_engine(qengine):
                 model = ModelWithSequentialFusion().train()
                 model.to(torch.float)
-                fuse_modules_qat(
-                    model, [['conv1', 'relu1'] ,
-                            ['features.0.0', 'features.0.1', 'features.0.2'],
-                            ['features.1.0', 'features.1.1', 'features.1.2'],
-                            ['features.2.0', 'features.2.1', 'features.2.2'],
-                            ['classifier.0', 'classifier.1']],
-                    inplace=True)
+                fuse_modules(model, [['conv1', 'relu1'] ,
+                             ['features.0.0', 'features.0.1', 'features.0.2'],
+                             ['features.1.0', 'features.1.1', 'features.1.2'],
+                             ['features.2.0', 'features.2.1', 'features.2.2'],
+                             ['classifier.0', 'classifier.1']], inplace=True)
                 self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                  msg="Fused Conv + Relu: nni.ConvReLU2d")
                 self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@@ -242,14 +233,11 @@ class TestFuseEager(QuantizationTestCase):
             with override_quantized_engine(qengine):
                 model = ModelWithSequentialFusion().eval()
                 model.to(torch.float)
-                fuse_modules(
-                    model,
-                    [['conv1', 'relu1'],
-                     ['features.0.0', 'features.0.1', 'features.0.2'],
-                     ['features.1.0', 'features.1.1', 'features.1.2'],
-                     ['features.2.0', 'features.2.1', 'features.2.2'],
-                     ['classifier.0', 'classifier.1']],
-                    inplace=True)
+                fuse_modules(model, [['conv1', 'relu1'] ,
+                             ['features.0.0', 'features.0.1', 'features.0.2'],
+                             ['features.1.0', 'features.1.1', 'features.1.2'],
+                             ['features.2.0', 'features.2.1', 'features.2.2'],
+                             ['classifier.0', 'classifier.1']], inplace=True)
                 self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                  msg="Fused Conv + Relu: nni.ConvReLU2d")
                 self.assertEqual(type(model.conv1[0]), nn.Conv2d,
@@ -298,10 +286,8 @@ class TestFuseEager(QuantizationTestCase):
         # fused model
         model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                      weight=torch.nn.Identity)
-        model = fuse_modules_qat(
-            model_orig,
-            [["conv1", "bn1", "relu1"],
-             ["conv2", "bn2"]])
+        model = fuse_modules(model_orig, [["conv1", "bn1", "relu1"],
+                             ["conv2", "bn2"]])
         prep_model = prepare_qat(model, inplace=False)
         # output with fusion but no observers.
         out_fused = prep_model(self.img_data_2d[0][0])
@@ -399,8 +385,8 @@ class TestFuseEager(QuantizationTestCase):
         self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
         self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

-        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
-        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
+        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
+        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
         fused = True
         before_fusion_pre_count = counter['pre_forwards']
diff --git a/test/quantization/eager/test_model_numerics.py b/test/quantization/eager/test_model_numerics.py
index b259e102b37..9d77a035e38 100644
--- a/test/quantization/eager/test_model_numerics.py
+++ b/test/quantization/eager/test_model_numerics.py
@@ -69,7 +69,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         fq_model = torch.ao.quantization.QuantWrapper(my_model)
         fq_model.train()
         fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
-        torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.ao.quantization.prepare_qat(fq_model)
         fq_model.eval()
         fq_model.apply(torch.ao.quantization.disable_fake_quant)
@@ -105,7 +105,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         fq_model = torch.ao.quantization.QuantWrapper(my_model)
         fq_model.train()
         fq_model.qconfig = qconfig
-        torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.ao.quantization.prepare_qat(fq_model)
         fq_model.eval()
         fq_model.apply(torch.ao.quantization.disable_fake_quant)
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 193af7d5740..5a89c279eef 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,7 +48,6 @@ from torch.ao.quantization import (
     get_default_qat_qconfig,
     get_default_qconfig_dict,
     fuse_modules,
-    fuse_modules_qat,
     prepare,
     prepare_qat,
     convert,
@@ -364,8 +363,6 @@ class TestFuseFx(QuantizationTestCase):

     @skipIfNoFBGEMM
     def test_qconfig_fused_module(self):
-        """ TODO: add test for all fused modules
-        """
         qconfig_dict = {
             "": None,
             "object_type": [(nn.Linear, default_qconfig),
@@ -893,20 +890,15 @@ class TestQuantizeFx(QuantizationTestCase):
             m_eager.eval()
             qconfig = get_default_qconfig(qengine)
             prepare_fn = prepare
-            is_qat = False
         else:
             m_eager.train()
             qconfig = get_default_qat_qconfig(qengine)
             prepare_fn = prepare_qat
-            is_qat = True

         fuse_list = ["conv", "bn"]
         if has_relu:
             fuse_list.append("relu")
-        if is_qat:
-            fuse_modules_qat(m_eager, fuse_list, inplace=True)
-        else:
-            fuse_modules(m_eager, fuse_list, inplace=True)
+        fuse_modules(m_eager, fuse_list, inplace=True)
         m_eager.qconfig = qconfig
         m_eager = prepare_fn(m_eager)
         m_eager(*self.img_data_dict[dim][0])
@@ -5837,7 +5829,6 @@ class TestQuantizeFxModels(QuantizationTestCase):
                 graph.eval()
                 calibrate_or_train = test_only_eval_fn
                 data = self.img_data_2d
-                is_qat = False
             else:
                 assert quant_type == QuantType.QAT
                 qconfig = default_qat_qconfig
@@ -5847,7 +5838,6 @@ class TestQuantizeFxModels(QuantizationTestCase):
                 graph.train()
                 calibrate_or_train = test_only_train_fn
                 data = self.img_data_2d_train
-                is_qat = True

             if hasattr(eager, "fuse_model"):
                 eager.fuse_model()
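The test changes above all collapse the old two-function call sites into a single fuse_modules call. A minimal sketch of the eager-mode QAT flow those tests now exercise; TinyModel is a hypothetical stand-in, not a module from the test suite:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

model = TinyModel().train()
# With this patch one call covers both PTQ and QAT: the fuser methods
# dispatch on module.training instead of a separate is_qat flag.
torch.ao.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu1']], inplace=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
torch.ao.quantization.prepare_qat(model, inplace=True)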
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 3f3bdc63092..a2da899a792 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -2,7 +2,6 @@

 from .fake_quantize import *  # noqa: F403
 from .fuse_modules import fuse_modules  # noqa: F403
-from .fuse_modules import fuse_modules_qat  # noqa: F403
 from .fuser_method_mappings import *  # noqa: F403
 from .observer import *  # noqa: F403
 from .qconfig import *  # noqa: F403
diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py
index 05705b161fc..c523d9bf225 100644
--- a/torch/ao/quantization/fuse_modules.py
+++ b/torch/ao/quantization/fuse_modules.py
@@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):
     setattr(cur_mod, tokens[-1], module)


-def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
+def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
     r"""Returns a list of modules that fuses the operations specified
      in the input module list.

@@ -46,7 +46,7 @@ def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
     if fuser_method is None:
         raise NotImplementedError("Cannot fuse modules: {}".format(types))
     new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
-    fused = fuser_method(is_qat, *mod_list)
+    fused = fuser_method(*mod_list)
     # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
     # Move pre forward hooks of the base module to resulting fused module
     for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
@@ -65,7 +65,7 @@ def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
     return new_mod


-def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
     if fuse_custom_config_dict is None:
         fuse_custom_config_dict = {}
     additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
@@ -74,25 +74,12 @@ def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_m
         mod_list.append(_get_module(model, item))

     # Fuse list of modules
-    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
+    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)

     # Replace original module list with fused module list
     for i, item in enumerate(modules_to_fuse):
         _set_module(model, item, new_mod_list[i])

-def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
-    if not inplace:
-        model = copy.deepcopy(model)
-
-    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
-        # Handle case of modules_to_fuse being a list
-        _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
-    else:
-        # Handle case of modules_to_fuse being a list of lists
-        for module_list in modules_to_fuse:
-            _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
-    return model
-
 def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
     r"""Fuses a list of modules into a single module

@@ -134,34 +121,27 @@

     Examples::
-            >>> m = M().eval()
-            >>> # m is a module containing the sub-modules below
+            >>> m = myModel()
+            >>> # m is a module containing the sub-modules below
             >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'],
                               ['submodule.conv', 'submodule.relu']]
             >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
             >>> output = fused_m(input)

-            >>> m = M().eval()
+            >>> m = myModel()
             >>> # Alternately provide a single list of modules to fuse
             >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
             >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
             >>> output = fused_m(input)

     """
-    return _fuse_modules(
-        model,
-        modules_to_fuse,
-        is_qat=False,
-        inplace=inplace,
-        fuser_func=fuse_known_modules,
-        fuse_custom_config_dict=None)
+    if not inplace:
+        model = copy.deepcopy(model)

-def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
-    """ QAT version for `fuse_modules`
-    """
-    return _fuse_modules(
-        model,
-        modules_to_fuse,
-        is_qat=True,
-        inplace=inplace,
-        fuser_func=fuse_known_modules,
-        fuse_custom_config_dict=None)
+    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+        # Handle case of modules_to_fuse being a list
+        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
+    else:
+        # Handle case of modules_to_fuse being a list of lists
+        for module_list in modules_to_fuse:
+            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
+    return model
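With the internal wrapper folded away, fuse_modules itself now handles both input shapes and the deep copy. A short sketch against the hypothetical TinyModel defined earlier:

m = TinyModel().eval()
# A flat list fuses one group; a list of lists fuses several groups.
fused_a = torch.ao.quantization.fuse_modules(m, ['conv1', 'bn1', 'relu1'])
fused_b = torch.ao.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu1']])
# inplace=False (the default) deep-copies the model first, so m is untouched.
assert type(m.conv1) is nn.Conv2d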
diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py
index 23e5a1f4c35..a33d3011454 100644
--- a/torch/ao/quantization/fuser_method_mappings.py
+++ b/torch/ao/quantization/fuser_method_mappings.py
@@ -7,12 +7,10 @@ from torch.ao.quantization.utils import Pattern
 from torch.ao.quantization.utils import get_combined_dict


-def fuse_conv_bn(is_qat, conv, bn):
+def fuse_conv_bn(conv, bn):
     r"""Given the conv and bn modules, fuses them and returns the fused module

     Args:
-        is_qat: a flag for whether we are using quantization aware training fusion
-        or post training quantization fusion
         conv: Module instance of type conv2d/conv3d
         bn: Spatial BN instance that needs to be fused with the conv

@@ -31,9 +29,7 @@
         nn.Conv3d: nni.ConvBn3d,
     }

-    if is_qat:
-        # TODO: remove the assert later
-        assert conv.training, "qat is only supported when conv.training is True currently"
+    if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
@@ -45,12 +41,10 @@
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)

-def fuse_conv_bn_relu(is_qat, conv, bn, relu):
+def fuse_conv_bn_relu(conv, bn, relu):
     r"""Given the conv and bn modules, fuses them and returns the fused module

     Args:
-        is_qat: a flag for whether we are using quantization aware training fusion
-        or post training quantization fusion
         conv: Module instance of type conv2d/conv3d
         bn: Spatial BN instance that needs to be fused with the conv

@@ -64,9 +58,7 @@
     assert(conv.training == bn.training == relu.training),\
         "Conv and BN both must be in the same mode (train or eval)."
     fused_module : Optional[Type[nn.Sequential]] = None
-    if is_qat:
-        # TODO: remove the assert later
-        assert conv.training, "qat is only supported when conv.training is True currently"
+    if conv.training:
         map_to_fused_module_train = {
             nn.Conv1d: nni.ConvBnReLU1d,
             nn.Conv2d: nni.ConvBnReLU2d,
@@ -93,12 +85,10 @@
     else:
         raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

-def fuse_linear_bn(is_qat, linear, bn):
+def fuse_linear_bn(linear, bn):
     r"""Given the linear and bn modules, fuses them and returns the fused module

     Args:
-        is_qat: a flag for whether we are using quantization aware training fusion
-        or post training quantization fusion
         linear: Module instance of type Linear
         bn: BatchNorm1d instance that needs to be fused with the linear layer

@@ -111,14 +101,13 @@
     assert(linear.training == bn.training),\
         "Linear and BN both must be in the same mode (train or eval)."

-    if is_qat:
-        # TODO: remove the assert later
-        assert linear.training, "qat is only supported when linear.training is True currently"
+    if linear.training:
         raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
     else:
         return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

-def fuse_convtranspose_bn(is_qat, convt, bn):
+
+def fuse_convtranspose_bn(convt, bn):
     r"""Given ConvTranspose and bn modules, fuses them and returns the fused module

     Args:
@@ -135,20 +124,11 @@
     assert(convt.training == bn.training),\
         "ConvTranspose and BN both must be in the same mode (train or eval)."

-    if is_qat:
-        assert convt.training, "qat is only supported when convt.training is True currently"
+    if convt.training:
         raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in training.")
     else:
         return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)

-def sequential_wrapper2(sequential):
-    """ Given a sequential class for two modules, return a function that takes
-    is_qat, and then two modules as argument, that ignores the is_qat flag
-    and always returns the sequential that combines the two input modules
-    """
-    def fuser_method(is_qat, m1, m2):
-        return sequential(m1, m2)
-    return fuser_method

 DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
@@ -157,13 +137,13 @@ DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
     (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
     (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
-    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
-    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
-    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
+    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
+    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
+    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
     (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
-    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
-    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
-    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
+    (nn.Linear, nn.ReLU): nni.LinearReLU,
+    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
+    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
     (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
     (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
     (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
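Because fuser methods no longer take is_qat, a mapping entry can be the intrinsic module class itself, called directly with the matched modules. A hedged sketch of the lookup path that fuse_known_modules follows after this change:

import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.ao.quantization.fuser_method_mappings import get_fuser_method

conv, relu = nn.Conv2d(3, 8, 3), nn.ReLU()
# The op-type tuple keys the mapping; here it resolves to nni.ConvReLU2d.
fuser_method = get_fuser_method((nn.Conv2d, nn.ReLU))
fused = fuser_method(conv, relu)  # no is_qat argument any more
assert isinstance(fused, nni.ConvReLU2d)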
@@ -181,25 +161,13 @@ def get_fuser_method(op_list,
                      additional_fuser_method_mapping=None):
     assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
     return fuser_method

-def reverse_sequential_wrapper2(sequential):
-    """ Given a sequential class for two modules, return a function that takes
-    is_qat, and then two modules as argument, that ignores the is_qat flag
-    and always returns the sequential that combines the two input modules, with
-    the order of two inputs reversed
-    """
-    def fuser_method(is_qat, m1, m2):
-        return sequential(m2, m1)
-    return fuser_method
-
 def reverse2(f):
-    def reversed(is_qat, x, y):
-        return f(is_qat, y, x)
-    return reversed
+    return lambda x, y: f(y, x)

 def reverse3(f):
-    def reversed(is_qat, x, w):
+    def reversed(x, w):
         y, z = w
-        return f(is_qat, z, y, x)
+        return f(z, y, x)
     return reversed

 DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
@@ -209,13 +177,13 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
     (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
     (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
-    (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d),
-    (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d),
-    (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d),
+    (nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
+    (nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
+    (nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
     (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
-    (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU),
-    (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d),
-    (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d),
+    (nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
+    (nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
+    (nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
     (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn),
     (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn),
     (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn),
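The simplified reverse2/reverse3 helpers keep their old job with one fewer argument: the FX matcher reports patterns outer-op first, so the wrappers restore the order the underlying fuser method expects. A tiny worked example; the trace helper is illustrative only:

def reverse3(f):
    def reversed_fn(x, w):
        y, z = w
        return f(z, y, x)
    return reversed_fn

trace = lambda conv, bn, relu: (conv, bn, relu)
# Pattern (ReLU, (BN, Conv)) is matched as (relu, (bn, conv)) and
# reverse3 flips it back into fuse_conv_bn_relu's (conv, bn, relu) order.
assert reverse3(trace)('relu', ('bn', 'conv')) == ('conv', 'bn', 'relu')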
diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py
index 60e7ccd28a5..02a43f99680 100644
--- a/torch/ao/quantization/fx/fuse.py
+++ b/torch/ao/quantization/fx/fuse.py
@@ -28,7 +28,6 @@ class Fuser:
     def fuse(
         self,
         model: GraphModule,
-        is_qat: bool,
         fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
     ) -> GraphModule:
@@ -73,7 +72,7 @@ class Fuser:
                     root_node = get_root_node(matched_node_pattern)  # type: ignore[index]
                     env[node.name] = obj.fuse(
                         self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
-                        fuse_custom_config_dict, fuser_method_mapping, is_qat)
+                        fuse_custom_config_dict, fuser_method_mapping)
                 elif maybe_last_node is None:
                     env[node.name] = self.fused_graph.node_copy(node, load_arg)
             # node matched in patterns and is not root is removed here
diff --git a/torch/ao/quantization/fx/fusion_patterns.py b/torch/ao/quantization/fx/fusion_patterns.py
index 2a0b9ff6f1e..d744856bcb4 100644
--- a/torch/ao/quantization/fx/fusion_patterns.py
+++ b/torch/ao/quantization/fx/fusion_patterns.py
@@ -28,8 +28,7 @@ class FuseHandler(ABC):
              root_node: Node,
              matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
-             is_qat: bool) -> Node:
+             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
         pass

 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@@ -70,8 +69,7 @@ class DefaultFuseHandler(FuseHandler):
              root_node: Node,
              matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]],
-             is_qat: bool) -> Node:
+             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
         additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
         assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
         root_module = quantizer.modules[root_node.target]
@@ -115,7 +113,7 @@ class DefaultFuseHandler(FuseHandler):
         fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
         # TODO: change the signature for fuser_method to take matched module patterns
         # as input
-        fused_module = fuser_method(is_qat, *matched_modules)
+        fused_module = fuser_method(*matched_modules)
         # TODO: maybe add a pass to cleanup bn modules?
         setattr(quantizer.modules[module_parent_name], module_name, fused_module)
         return quantizer.fused_graph.node_copy(root_node, load_arg)
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index d0d951ce7aa..f10469ac04b 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -1198,7 +1198,6 @@ def insert_observers_for_model(

 def run_prepare_fx_on_standalone_modules(
     model: torch.nn.Module,
-    is_qat: bool,
     modules: Dict[str, torch.nn.Module],
     matches: Any,
     prepare_custom_config_dict: Dict[str, Any],
@@ -1229,7 +1228,6 @@ def run_prepare_fx_on_standalone_modules(
         prepare(
             standalone_module,
             sm_qconfig_dict,
-            is_qat,
             sm_prepare_config_dict,
             backend_config_dict=sm_backend_config_dict)
         preserved_attributes = \
@@ -1266,12 +1264,12 @@ def save_state(

 def prepare(
         model: GraphModule,
         qconfig_dict: Any,
-        is_qat: bool,
         node_name_to_scope: Dict[str, Tuple[str, type]],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
-        is_standalone_module: bool = False) -> ObservedGraphModule:
+        is_standalone_module: bool = False,
+        is_qat: bool = False) -> ObservedGraphModule:
     """ standalone_module means it a submodule that is not inlined in parent module,
         and will be quantized separately as one unit.
@@ -1390,7 +1388,7 @@ def prepare(
             "output_quantized_idxs", [])

     run_prepare_fx_on_standalone_modules(
-        model, is_qat, modules, matches, prepare_custom_config_dict, backend_config_dict)
+        model, modules, matches, prepare_custom_config_dict, backend_config_dict)

     # record names for the set of observed node, so that in convert step
     # we know whether we need to convert a floating point module to reference
diff --git a/torch/ao/quantization/fx/qconfig_utils.py b/torch/ao/quantization/fx/qconfig_utils.py
index 80afa562a10..4f4d2cbaf39 100644
--- a/torch/ao/quantization/fx/qconfig_utils.py
+++ b/torch/ao/quantization/fx/qconfig_utils.py
@@ -11,9 +11,9 @@ from torch.fx import (
 from torch.fx.graph import (
     Graph,
 )
-from torch.nn.intrinsic import _FusedModule

 from ..utils import _parent_name
+from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD
 from ..qconfig_dict_utils import (
     get_object_type_qconfig,
     maybe_adjust_qconfig_for_module_type_or_name,
@@ -56,29 +56,23 @@ def update_qconfig_for_fusion(

     for node in model.graph.nodes:
         if node.op == 'call_module' and node.target in modules:
-            maybe_fused_module = modules[str(node.target)]
-            if not isinstance(maybe_fused_module, _FusedModule):
+            module_type = type(modules[str(node.target)])
+            if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
                 continue

-            ops = list(maybe_fused_module._modules.values())
-            fused_qconfig = object_type_dict.get(type(ops[0]), None)
+            for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
+                if module_type == fuser:
+                    fused_qconfig = object_type_dict.get(ops[0], None)

-            # Raise an error if the modules in the fused module have
-            # different qconfigs specified in the qconfig_dict
-            # TODO: currently it only works for modules,
-            # need to make this work for torch.nn.functional.relu
-            # TODO: currently it only works for object_type configurations,
-            # ideally it should work for different types of configurations,
-            # maybe we want to redesign this part
-            for op in ops[1:]:
-                if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
-                    raise LookupError(
-                        "During fusion, we need to specify the same " +
-                        f"qconfigs for all module types in {type(maybe_fused_module)} " +
-                        f"offending type: {type(op)}")
+                    # Raise an error if the modules in the fused module have
+                    # different qconfigs specified in the qconfig_dict
+                    for op in ops:
+                        if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
+                            raise LookupError("During fusion, we need to specify the same " +
+                                              f"qconfigs for both modules in {module_type}.")

-            if fused_qconfig is not None:
-                object_type_dict[type(maybe_fused_module)] = fused_qconfig
+                    if fused_qconfig is not None:
+                        object_type_dict[module_type] = fused_qconfig

     return qconfig_dict
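The rewritten update_qconfig_for_fusion recovers the constituent op list from DEFAULT_OP_LIST_TO_FUSER_METHOD instead of inspecting _FusedModule children, but the user-facing contract appears unchanged: every member of a fused pattern must carry the same qconfig, which is then propagated to the fused module type. A sketch of a qconfig_dict that satisfies this, mirroring the test_qconfig_fused_module test above:

import torch.nn as nn
from torch.ao.quantization import get_default_qconfig

qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {
    "": None,
    "object_type": [
        (nn.Conv2d, qconfig),
        (nn.ReLU, qconfig),  # must equal the Conv2d qconfig, else LookupError
    ],
}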
diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py
index 07a1eb6755b..b3f6d6acfc2 100644
--- a/torch/ao/quantization/quantize_fx.py
+++ b/torch/ao/quantization/quantize_fx.py
@@ -47,7 +46,6 @@ def _swap_ff_with_fxff(model: torch.nn.Module) -> None:

 def _fuse_fx(
     graph_module: GraphModule,
-    is_qat: bool,
     fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
     backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:
@@ -58,8 +57,7 @@ def _fuse_fx(
     """
     _check_is_graph_module(graph_module)
     fuser = Fuser()
-    return fuser.fuse(
-        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
+    return fuser.fuse(graph_module, fuse_custom_config_dict, backend_config_dict)


 class Scope(object):
@@ -177,11 +175,11 @@ class QuantizationTracer(Tracer):
 def _prepare_fx(
     model: torch.nn.Module,
     qconfig_dict: Any,
-    is_qat: bool,
     prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
     equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
     backend_config_dict: Optional[Dict[str, Any]] = None,
     is_standalone_module: bool = False,
+    is_qat: bool = False,
 ) -> ObservedGraphModule:
     r""" Internal helper function for prepare_fx
         Args:
@@ -237,20 +235,16 @@ forward graph of the parent module,
     graph_module = GraphModule(model, tracer.trace(model))
     for attr_name in preserved_attributes:
         setattr(graph_module, attr_name, getattr(model, attr_name))
-    graph_module = _fuse_fx(
-        graph_module,
-        is_qat,
-        prepare_custom_config_dict,
-        backend_config_dict)
+    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict, backend_config_dict)
     prepared = prepare(
         graph_module,
         qconfig_dict,
-        is_qat,
         tracer.node_name_to_scope,
         prepare_custom_config_dict=prepare_custom_config_dict,
         equalization_qconfig_dict=equalization_qconfig_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=is_standalone_module,
+        is_qat=is_qat,
     )

     for attr_name in preserved_attributes:
@@ -261,9 +255,9 @@ forward graph of the parent module,
 def _prepare_standalone_module_fx(
     model: torch.nn.Module,
     qconfig_dict: Any,
-    is_qat: bool,
     prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
     backend_config_dict: Optional[Dict[str, Any]] = None,
+    is_qat: bool = False,
 ) -> GraphModule:
     r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
     parent module.
@@ -290,10 +284,10 @@ def _prepare_standalone_module_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
-        is_qat,
         prepare_custom_config_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=True,
+        is_qat=is_qat,
     )


@@ -338,7 +332,7 @@ def fuse_fx(
     )
     for attr_name in preserved_attributes:
         setattr(graph_module, attr_name, getattr(model, attr_name))
-    return _fuse_fx(graph_module, False, fuse_custom_config_dict)
+    return _fuse_fx(graph_module, fuse_custom_config_dict)


 def prepare_fx(
@@ -515,10 +509,10 @@ def prepare_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
-        False,  # is_qat
         prepare_custom_config_dict,
         equalization_qconfig_dict,
         backend_config_dict,
+        is_qat=False,
     )


@@ -564,9 +558,9 @@ def prepare_qat_fx(
     return _prepare_fx(
         model,
         qconfig_dict,
-        True,  # is_qat
         prepare_custom_config_dict,
         backend_config_dict=backend_config_dict,
+        is_qat=True,
     )
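The public FX entry points keep their signatures; is_qat survives only as an internal keyword that prepare_fx/prepare_qat_fx forward to prepare. A minimal end-to-end sketch, reusing the hypothetical TinyModel from earlier:

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = TinyModel().eval()
qconfig_dict = {"": get_default_qconfig('fbgemm')}
prepared = prepare_fx(model, qconfig_dict)   # fusion happens inside via _fuse_fx
prepared(torch.randn(1, 3, 32, 32))          # calibrate with representative data
quantized = convert_fx(prepared)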
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index fcf18d0e1e0..bcdf57e67ee 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -1195,11 +1195,7 @@ class AnnotatedConvBnReLUModel(torch.nn.Module):
         return x

     def fuse_model(self):
-        # TODO: remove this check and define two fuse_modules function on this module
-        if self.training:
-            torch.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
-        else:
-            torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)
+        torch.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

 class TwoLayerConvModel(torch.nn.Module):
     def __init__(self):
@@ -1468,11 +1464,7 @@ class InnerModule(torch.nn.Module):
             if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
                 fusable_layers.append([current_name,
                                        named_children[idx + 1][0]])
-        # TODO: remove this check and define two fuse_modules function on this module
-        if self.training:
-            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
-        else:
-            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
+        torch.quantization.fuse_modules(self, fusable_layers, inplace=True)

 class FunctionalLinear(torch.nn.Module):
     def __init__(self):
@@ -1963,11 +1955,7 @@ class ResNetBase(torch.nn.Module):
         return out

     def fuse_model(self):
-        # TODO: remove this check and define two fuse_model function on this module
-        if self.training:
-            torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
-        else:
-            torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)

 class ModelMultipleOps(torch.nn.Module):
     def __init__(self):