Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23003

Adds torch.quantization.fuse_module and the torch.nn._intrinsic ConvReLU and LinearReLU modules. The fusion function combines the specific module sequences (conv, bn) and (conv, bn, relu). In all cases the modules are replaced in place: the first module is replaced with the _intrinsic fused module and the remaining modules are replaced by nn.Identity. Both training and eval are supported. For training, the modules are "fused" with a sequential container, which allows further module swaps for quantization-aware training.

TODO: Add tests for the _intrinsic modules.

Conv-BN fusion code is based on DsKhudia's implementation.

Differential Revision: D16199720

fbshipit-source-id: 95fb9ffe72b361d280313b2ec57de2acd4f9dda2
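A minimal usage sketch of the API described above; the model, its module names, and the import path are illustrative assumptions rather than part of this change:

import torch
import torch.nn as nn
from torch.quantization import fuse_modules  # assumed import path

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

m = MyModel().eval()
fuse_modules(m, [['conv1', 'bn1', 'relu1']])  # rewrites m in place
# m.conv1 now holds the fused module (with the BN folded in, since m is in
# eval mode); m.bn1 and m.relu1 have been replaced with nn.Identity.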
112 lines
3.7 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.nn._intrinsic.modules.fused as torch_fused


def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type nn.Conv2d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    if conv.training:
        # Training: keep conv and bn as separate submodules inside the fused
        # container so that quantization-aware training can still swap them.
        return torch.nn._intrinsic.ConvBn2d(conv, bn)
    else:
        # Eval: fold the BN statistics directly into the conv weights and bias.
        return torch.nn.utils.fuse_conv_bn_eval(conv, bn)
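
# For reference, the eval-mode path relies on the standard Conv-BN folding
# identity. With BN affine parameters gamma, beta and running statistics
# mean, var, the per-output-channel folded parameters are:
#     W_fused = W * gamma / sqrt(var + eps)
#     b_fused = (b - mean) * gamma / sqrt(var + eps) + beta
# so the folded convolution reproduces conv followed by batch norm exactly
# when both are in eval mode; torch.nn.utils.fuse_conv_bn_eval computes this.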


def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv, bn and relu modules, fuses them and returns the fused module

    Args:
        conv: Module instance of type nn.Conv2d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that needs to be fused with the conv and bn

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv, BN and ReLU must all be in the same mode (train or eval)."

    if conv.training:
        # Training: keep all three modules inside the fused container for QAT.
        return torch_fused.ConvBnReLU2d(conv, bn, relu)
    else:
        # Eval: fold the BN into the conv first, then pair it with the ReLU.
        return torch_fused.ConvReLU2d(
            torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)


def _fuse_modules(model, named_module_dict, modules_to_fuse, fuser_func=None):
    assert(len(modules_to_fuse) == 2 or len(modules_to_fuse) == 3),\
        "Can fuse only 2 or 3 modules."

    OP_LIST_TO_FUSER_FUNC = {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
        (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu
    }

    mod = []
    parent_mod = []
    for i in range(len(modules_to_fuse)):
        parent_module_name = '.'.join(modules_to_fuse[i].split('.')[:-1])
        mod.append(named_module_dict[modules_to_fuse[i]])
        parent_mod.append(named_module_dict.get(parent_module_name, model))
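    # Worked example of the name handling above (the module path is hypothetical):
    # for modules_to_fuse[i] == 'submodule.conv', parent_module_name is 'submodule',
    # named_module_dict['submodule.conv'] is the conv module itself, and
    # named_module_dict.get('submodule', model) is the parent that setattr targets
    # below. A top-level name such as 'conv1' gives parent_module_name == '', which
    # named_modules() maps to the model itself.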

    new_mod = mod[0]
    if fuser_func is None:
        types = tuple(type(m) for m in mod)
        fuser_func = OP_LIST_TO_FUSER_FUNC.get(types, None)
        if fuser_func is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod = fuser_func(*mod)

    # Assign new_mod to module and set remaining modules to identity
    if new_mod is not mod[0]:
        setattr(parent_mod[0], modules_to_fuse[0].split('.')[-1], new_mod)
        for i in range(1, len(modules_to_fuse)):
            setattr(parent_mod[i], modules_to_fuse[i].split('.')[-1], torch.nn.Identity())


def fuse_modules(model, modules_to_fuse):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequences of modules:
    conv, bn
    conv, bn, relu
    Any other sequence raises a NotImplementedError.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Arguments:
        model: Model containing the modules to be fused
        modules_to_fuse: list of lists of module names to fuse.

    Returns:
        Nothing. The model is modified in place.

    Examples::

        >>> m = myModel()
        >>> # m is a module containing the sub-modules below
        >>> modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.bn']]
        >>> torch.quantization.fuse_modules(m, modules_to_fuse)
        >>> output = m(input)

    """
    named_module_dict = {name: mod for name, mod in model.named_modules()}
    for module_list in modules_to_fuse:
        _fuse_modules(model, named_module_dict, module_list)