from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.nn._intrinsic.modules.fused as torch_fused


def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type Conv2d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert conv.training == bn.training, \
        "Conv and BN must both be in the same mode (train or eval)."

    if conv.training:
        assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
        assert bn.num_features == conv.out_channels, \
            'Output channels of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, \
            'Only support fusing BatchNorm2d with track_running_stats set to True'
        return torch.nn._intrinsic.ConvBn2d(conv, bn)
    else:
        return torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn)


def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv, bn and relu modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type Conv2d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that follows the bn

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    assert conv.training == bn.training == relu.training, \
        "Conv, BN and ReLU must all be in the same mode (train or eval)."

    if conv.training:
        assert not relu.inplace, 'We only support fusion of non-inplace ReLU.'
        return torch_fused.ConvBnReLU2d(conv, bn, relu)
    else:
        return torch_fused.ConvReLU2d(
            torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)


def _fuse_modules(model, named_module_dict, modules_to_fuse, fuser_func=None):
    assert len(modules_to_fuse) == 2 or len(modules_to_fuse) == 3, \
        "Can fuse only 2 or 3 modules."

    # Maps a tuple of module types to the function that fuses them.
    OP_LIST_TO_FUSER_FUNC = {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
        (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
    }

    # Collect the modules to fuse along with their parent modules, so that
    # the fused module can be assigned back to the right attribute below.
    mod = []
    parent_mod = []
    for i in range(len(modules_to_fuse)):
        parent_module_name = '.'.join(modules_to_fuse[i].split('.')[:-1])
        mod.append(named_module_dict[modules_to_fuse[i]])
        parent_mod.append(named_module_dict.get(parent_module_name, model))

    if fuser_func is None:
        types = tuple(type(m) for m in mod)
        fuser_func = OP_LIST_TO_FUSER_FUNC.get(types, None)
        if fuser_func is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod = fuser_func(*mod)

    # Assign new_mod in place of the first module and set the remaining
    # modules to identity.
    if new_mod is not mod[0]:
        setattr(parent_mod[0], modules_to_fuse[0].split('.')[-1], new_mod)
        for i in range(1, len(modules_to_fuse)):
            setattr(parent_mod[i], modules_to_fuse[i].split('.')[-1],
                    torch.nn.Identity())


def fuse_modules(model, modules_to_fuse):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequences of modules:
    conv, bn
    conv, bn, relu
    All other sequences are left unchanged.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Arguments:
        model: Model containing the modules to be fused
        modules_to_fuse: list of lists of module names to fuse.

    Returns:
        None. The model is modified in place.

    Examples::

        >>> m = myModel()
        >>> # m is a module containing the sub-modules below
        >>> modules_to_fuse = [['conv1', 'bn1', 'relu1'],
        ...                    ['submodule.conv', 'submodule.bn']]
        >>> fuse_modules(m, modules_to_fuse)
        >>> output = m(input)
    """
    named_module_dict = {name: mod for name, mod in model.named_modules()}
    for module_list in modules_to_fuse:
        _fuse_modules(model, named_module_dict, module_list)
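

# A minimal usage sketch of ``fuse_modules`` (illustration only: ``DemoModel``
# and the tensor shapes below are hypothetical, defined just for this example).
# In eval mode the BatchNorm statistics are folded into the conv weights and
# the trailing modules are replaced with ``nn.Identity``.
if __name__ == '__main__':
    import torch.nn as nn

    class DemoModel(nn.Module):
        def __init__(self):
            super(DemoModel, self).__init__()
            self.conv1 = nn.Conv2d(3, 16, 3)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu1 = nn.ReLU(inplace=False)

        def forward(self, x):
            return self.relu1(self.bn1(self.conv1(x)))

    m = DemoModel().eval()  # all modules to fuse must share the same mode
    fuse_modules(m, [['conv1', 'bn1', 'relu1']])
    # m.conv1 is now a fused conv+relu module; m.bn1 and m.relu1 are Identity.
    out = m(torch.randn(1, 3, 8, 8))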