pytorch/torch/nn/utils/fusion.py
Zafar Takhirov 058645acb1 Fusion and _intrinsic modules (#23003)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23003

Adds torch.quantization.fuse_module and the torch.nn._intrinsic modules convRelu and LinearRelu.

Fusion function to combine specific module sequences: (conv, bn) and (conv, bn, relu).
In all cases, modules are replaced in place: the first module is replaced with the _intrinsic fused module and the remaining modules are replaced by nn.Identity.
Both training and eval are supported. For training, the modules are "fused" into a sequential container, which allows further module swaps for quantization-aware training.
Also adds torch.nn._intrinsic modules for convRelu and LinearRelu.
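The in-place replacement described above can be sketched roughly as follows. `fuse_in_place` is a hypothetical helper written only for illustration; it is not the `torch.quantization` API. It shows the training-mode path: the first child becomes an `nn.Sequential` that wraps the original modules, and the remaining children become `nn.Identity`, so the parent's forward pass is unchanged.

```python
import torch
import torch.nn as nn


def fuse_in_place(parent, names, fused):
    # Hypothetical helper (illustration only): put `fused` where the first
    # named child was, and turn every other named child into a no-op.
    setattr(parent, names[0], fused)
    for name in names[1:]:
        setattr(parent, name, nn.Identity())
    return parent


torch.manual_seed(0)
model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 8, kernel_size=3))
model.add_module("bn", nn.BatchNorm2d(8))
model.add_module("relu", nn.ReLU())

x = torch.randn(2, 3, 8, 8)
before = model(x)

# Training-style "fusion": wrap the original modules in a Sequential so they
# stay individually swappable (e.g. for quantization-aware training).
fused = nn.Sequential(model.conv, model.bn, model.relu)
fuse_in_place(model, ["conv", "bn", "relu"], fused)

after = model(x)
print(type(model.conv).__name__, type(model.bn).__name__, type(model.relu).__name__)
# Sequential Identity Identity
```

Because `setattr` on an existing child name keeps that child's position in `_modules`, the Sequential still runs its slots in the original order. In eval mode the first slot would instead receive a folded module, such as the one `fuse_conv_bn_eval` returns.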

TODO: Add tests for _intrinsic modules.

The conv-BN fusion code is based on DsKhudia's implementation.
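The folding that the conv-BN fusion code performs follows from rewriting eval-mode batch norm as an affine map of the conv output. For each output channel $c$, let $\gamma$ = `bn.weight`, $\beta$ = `bn.bias`, $\mu$ = `bn.running_mean`, $\sigma^2$ = `bn.running_var`, $\epsilon$ = `bn.eps`, and let $b$ be the conv bias (zero when the conv has none):

```latex
y_c = \gamma_c \,\frac{(W_c * x + b_c) - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
    = \underbrace{\frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}\, W_c}_{W'_c} * x
      \;+\; \underbrace{\frac{(b_c - \mu_c)\,\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c}_{b'_c}
```

$W'_c$ and $b'_c$ correspond to the `w_conv` and `b_conv` updates in `fuse_conv_bn_eval`. The identity holds only while $\mu$ and $\sigma^2$ are fixed running statistics, which is why fusion is restricted to eval mode.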

Differential Revision: D16199720

fbshipit-source-id: 95fb9ffe72b361d280313b2ec57de2acd4f9dda2
2019-07-23 14:54:19 -07:00


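from __future__ import absolute_import, division, print_function, unicode_literals

import copy

import torch


def fuse_conv_bn_eval(conv, bn):
    # Folding is only valid in eval mode, where BN normalizes with its frozen
    # running statistics rather than per-batch statistics.
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    # Work on a copy so the caller's conv module is left untouched.
    fused_conv = copy.deepcopy(conv)

    w_conv = fused_conv.weight
    b_conv = fused_conv.bias
    bn_mean = bn.running_mean
    bn_var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    bn_weight = bn.weight
    bn_bias = bn.bias

    # A conv built with bias=False has no bias; treat it as zero.
    if b_conv is None:
        b_conv = bn_mean.new_zeros(bn_mean.shape)

    # Scale each output channel's filter by gamma / sqrt(var + eps). Reshaping to
    # (out_channels, 1, ..., 1) broadcasts for non-transposed 1-D, 2-D and 3-D conv weights.
    w_conv = w_conv * (bn_weight / bn_var_sqrt).reshape([-1] + [1] * (w_conv.dim() - 1))
    # Shift the bias so the fused conv reproduces BN's normalize-then-affine step.
    b_conv = (b_conv - bn_mean) / bn_var_sqrt * bn_weight + bn_bias

    fused_conv.weight = torch.nn.Parameter(w_conv)
    fused_conv.bias = torch.nn.Parameter(b_conv)
    return fused_conv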
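One way to check the fold is to compare a conv followed by eval-mode BN against the fused module on random input. This sketch assumes a PyTorch build in which this file is importable as `torch.nn.utils.fusion`. The channel counts, input size, random BN statistics, and tolerance are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

conv = nn.Conv2d(3, 4, kernel_size=3, bias=False)  # exercises the bias=None branch
bn = nn.BatchNorm2d(4)

# Give BN non-trivial statistics and affine params so the check is meaningful.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fusion is only valid in eval mode; the function asserts this.
conv.eval()
bn.eval()
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 10, 10)
with torch.no_grad():
    reference = bn(conv(x))  # unfused: conv, then BN with running stats
    out = fused(x)           # single fused conv

print(torch.allclose(reference, out, atol=1e-5))
```

Using `bias=False` for the conv exercises the zero-bias path, and the fused module comes back with a real bias parameter.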