pytorch/torch/nn/utils/fusion.py
Zafar Takhirov bb4f380f35 Optimizing out the division in the fusion
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23275

Test Plan: Imported from OSS

Differential Revision: D16450294

Pulled By: zafartahirov

fbshipit-source-id: 2f1ebf3673ed0467a9f6a912e08e5d95f9b27d0b
2019-08-12 11:35:37 -07:00

30 lines
787 B
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import torch
def fuse_conv_bn_eval(conv, bn):
    """Fold an eval-mode batch-norm layer into the convolution preceding it.

    A deep copy of ``conv`` is returned whose weight and bias are rescaled so
    that ``fused(x)`` matches ``bn(conv(x))`` when both modules are in eval
    mode. Neither input module is modified.

    Args:
        conv: Convolution module in eval mode. Its weight may have any rank
            (e.g. Conv1d/Conv2d/Conv3d); the per-channel scale is broadcast
            along dimension 0 of the weight. NOTE(review): transposed
            convolutions presumably need the scale on a different axis and
            are not handled here -- confirm before using with them.
        bn: Batch-norm module in eval mode providing ``running_mean``,
            ``running_var`` and ``eps``. ``weight``/``bias`` may be ``None``
            (``affine=False``), in which case an identity scale and zero
            shift are used.

    Returns:
        A copy of ``conv`` with the batch-norm statistics folded into fresh
        ``weight`` and ``bias`` parameters (a bias is created if ``conv`` had
        none).

    Raises:
        AssertionError: If either module is still in training mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    w_conv = fused_conv.weight
    b_conv = fused_conv.bias

    bn_mean = bn.running_mean
    # Reciprocal of the standard deviation: multiplying by it replaces the
    # division by sqrt(var + eps).
    bn_var_rsqrt = torch.rsqrt(bn.running_var + bn.eps)
    bn_weight = bn.weight
    bn_bias = bn.bias

    if b_conv is None:
        b_conv = bn_mean.new_zeros(bn_mean.shape)
    # affine=False batch norm has no learnable scale/shift; fall back to the
    # identity transform instead of failing on None.
    if bn_weight is None:
        bn_weight = torch.ones_like(bn_mean)
    if bn_bias is None:
        bn_bias = torch.zeros_like(bn_mean)

    # One scale per output channel, broadcast over every remaining weight
    # dimension regardless of how many spatial dims the convolution has.
    broadcast_shape = [-1] + [1] * (w_conv.dim() - 1)
    w_conv = w_conv * (bn_weight * bn_var_rsqrt).reshape(broadcast_shape)
    b_conv = (b_conv - bn_mean) * bn_var_rsqrt * bn_weight + bn_bias

    fused_conv.weight = torch.nn.Parameter(w_conv)
    fused_conv.bias = torch.nn.Parameter(b_conv)
    return fused_conv