Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23003

Adds torch.quantization.fuse_module, a fusion function that combines specific module sequences: (conv, bn) and (conv, bn, relu). In all cases the modules are replaced in place: the first module is swapped for the _intrinsic fused module and the remaining modules are replaced by nn.Identity. Both training and eval are supported; for training, the modules are "fused" into a Sequential container, which allows further module swaps for quantization-aware training. Also adds torch.nn._intrinsic modules for convRelu and LinearRelu.

TODO: Add tests for the _intrinsic modules.

The Conv+BN fusion code is based on DsKhudia's implementation.

Differential Revision: D16199720

fbshipit-source-id: 95fb9ffe72b361d280313b2ec57de2acd4f9dda2
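A minimal usage sketch of the module-fusion entry point described above. It assumes the API is exposed as torch.quantization.fuse_modules with a (model, modules_to_fuse, inplace) signature, as in later PyTorch releases; the exact name at the time of this PR may differ, and SmallNet is a hypothetical model used only for illustration.

import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

model = SmallNet().eval()
# Fuse the (conv, bn, relu) chain in place: 'conv' is replaced by the fused
# _intrinsic module, while 'bn' and 'relu' become nn.Identity.
torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']], inplace=True)

The eval-mode building block behind (conv, bn) fusion is the fuse_conv_bn_eval function in the file below.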
30 lines · 786 B · Python
from __future__ import absolute_import, division, print_function, unicode_literals

import copy

import torch


def fuse_conv_bn_eval(conv, bn):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    w_conv = fused_conv.weight
    b_conv = fused_conv.bias

    bn_mean = bn.running_mean
    bn_var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    bn_weight = bn.weight
    bn_bias = bn.bias

    # A conv layer without a bias behaves as if its bias were zero.
    if b_conv is None:
        b_conv = bn_mean.new_zeros(bn_mean.shape)

    # In eval mode, BN computes y = (x - mean) / sqrt(var + eps) * gamma + beta,
    # so it folds into the conv as a per-output-channel rescale of the weights
    # plus an affine correction of the bias:
    #   W' = W * gamma / sqrt(var + eps)
    #   b' = (b - mean) / sqrt(var + eps) * gamma + beta
    w_conv = w_conv * (bn_weight / bn_var_sqrt).reshape([-1, 1, 1, 1])
    b_conv = (b_conv - bn_mean) / bn_var_sqrt * bn_weight + bn_bias

    fused_conv.weight = torch.nn.Parameter(w_conv)
    fused_conv.bias = torch.nn.Parameter(b_conv)

    return fused_conv
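A quick sanity check, as a sketch appended below the function (the module sizes, input shape, randomized BN statistics, and tolerance are illustrative, not part of the original file): in eval mode, the fused conv should reproduce conv followed by bn.

conv = torch.nn.Conv2d(3, 8, 3).eval()
bn = torch.nn.BatchNorm2d(8).eval()
# Randomize the BN statistics so the check is non-trivial
# (a fresh BatchNorm2d has mean=0, var=1, gamma=1, beta=0).
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)

fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)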