mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23275 Test Plan: Imported from OSS Differential Revision: D16450294 Pulled By: zafartahirov fbshipit-source-id: 2f1ebf3673ed0467a9f6a912e08e5d95f9b27d0b
30 lines
787 B
Python
30 lines
787 B
Python
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import copy
|
|
import torch
|
|
|
|
def fuse_conv_bn_eval(conv, bn):
    """Fuse a conv module and a BatchNorm module into a single conv module.

    The returned conv applies the convolution followed by the frozen
    batch-norm affine transform, using the BN running statistics. Fusion
    is only valid in eval mode, because it bakes ``running_mean`` and
    ``running_var`` into the conv weights.

    Args:
        conv: a ``torch.nn.Conv1d``/``Conv2d``/``Conv3d`` module (any
            conv whose weight's dim 0 is the output-channel dim works).
        bn: the matching ``torch.nn.BatchNorm1d``/``2d``/``3d`` module.

    Returns:
        A new conv module (a deep copy — the inputs are not modified)
        whose weight and bias include the batch-norm transform.

    Raises:
        AssertionError: if either module is in training mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    w_conv = fused_conv.weight
    b_conv = fused_conv.bias

    bn_mean = bn.running_mean
    # 1 / sqrt(var + eps): the per-channel normalization factor.
    bn_var_sqrt = torch.rsqrt(bn.running_var + bn.eps)

    bn_weight = bn.weight
    bn_bias = bn.bias

    # BatchNorm built with affine=False has weight/bias of None; treat
    # that as the identity affine transform (gamma=1, beta=0).
    if bn_weight is None:
        bn_weight = bn_mean.new_ones(bn_mean.shape)
    if bn_bias is None:
        bn_bias = bn_mean.new_zeros(bn_mean.shape)

    # A conv without bias fuses as if its bias were zero.
    if b_conv is None:
        b_conv = bn_mean.new_zeros(bn_mean.shape)

    # Scale each output channel by gamma / sqrt(var + eps). Broadcasting
    # over all trailing weight dims (instead of hard-coding [-1,1,1,1])
    # supports Conv1d/Conv3d as well as Conv2d, with identical results
    # for the 2d case.
    scale = bn_weight * bn_var_sqrt
    w_conv = w_conv * scale.reshape([-1] + [1] * (w_conv.dim() - 1))
    b_conv = (b_conv - bn_mean) * bn_var_sqrt * bn_weight + bn_bias

    fused_conv.weight = torch.nn.Parameter(w_conv)
    fused_conv.bias = torch.nn.Parameter(b_conv)

    return fused_conv
|