mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45474 When batchnorm affine is set to False, its weight and bias are set to None, which was not supported by the conv + batchnorm fusion. Added a fix that sets the weight to 1 and the bias to 0 when they are not set. Test Plan: Add a unit test for fusing conv and batchnorm where batchnorm is in affine=False mode. Reviewed By: z-a-f Differential Revision: D23977080 fbshipit-source-id: 2782be626dc67553f3d27d8f8b1ddc7dea022c2a
29 lines
921 B
Python
29 lines
921 B
Python
|
|
|
|
import copy
|
|
import torch
|
|
|
|
def fuse_conv_bn_eval(conv, bn):
    """Return a copy of ``conv`` with the batch-norm ``bn`` folded into it.

    Both modules must be in eval mode (``training`` is falsy); otherwise an
    ``AssertionError`` is raised. ``conv`` is deep-copied first, so the new
    weight and bias are assigned to the copy rather than to the module that
    was passed in.

    Args:
        conv: Module exposing ``weight``, ``bias`` and ``training``.
        bn: Module exposing ``running_mean``, ``running_var``, ``eps``,
            ``weight``, ``bias`` and ``training``.

    Returns:
        The deep copy of ``conv`` carrying the fused weight and bias.
    """
    assert not conv.training and not bn.training, "Fusion only for eval!"

    result = copy.deepcopy(conv)
    new_weight, new_bias = fuse_conv_bn_weights(
        result.weight,
        result.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )
    # Assign weight before bias, mirroring the original tuple assignment.
    result.weight = new_weight
    result.bias = new_bias
    return result
|
|
|
|
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    """Fold batch-norm statistics and affine terms into convolution parameters.

    With ``inv_std = rsqrt(bn_rv + bn_eps)`` the fused parameters are::

        fused_w = conv_w * (bn_w * inv_std)          # broadcast per channel
        fused_b = (conv_b - bn_rm) * inv_std * bn_w + bn_b

    Args:
        conv_w: Convolution weight tensor.
        conv_b: Convolution bias tensor, or ``None`` (treated as zeros).
        bn_rm: Batch-norm running mean.
        bn_rv: Batch-norm running variance.
        bn_eps: Epsilon added to the variance before the reciprocal sqrt.
        bn_w: Batch-norm weight, or ``None`` (treated as ones), e.g. when the
            batch-norm was built with ``affine=False``.
        bn_b: Batch-norm bias, or ``None`` (treated as zeros).
        transpose: When ``False`` (default) the per-channel scale is broadcast
            along dim 0 of ``conv_w``. When ``True`` it is broadcast along
            dim 1, which is where transposed-convolution weights keep their
            output channels.
            NOTE(review): with ``groups > 1`` dim 1 holds
            ``out_channels / groups`` entries, so the broadcast will not
            line up -- confirm callers only use this with ``groups == 1``.

    Returns:
        Tuple ``(fused_w, fused_b)`` of ``torch.nn.Parameter``. Each is cast
        back to the dtype of the corresponding convolution input (the bias
        uses the weight's dtype when ``conv_b`` is ``None``).
    """
    # Record the convolution dtypes up front. Mixing e.g. a lower-precision
    # conv with higher-precision batch-norm statistics promotes the
    # arithmetic below, and the fused parameters must still match the conv.
    weight_dtype = conv_w.dtype
    bias_dtype = weight_dtype if conv_b is None else conv_b.dtype

    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # Shape that places the channel axis where conv_w keeps its output
    # channels and leaves every other axis as size 1 for broadcasting.
    if transpose:
        scale_shape = [1, -1] + [1] * (conv_w.dim() - 2)
    else:
        scale_shape = [-1] + [1] * (conv_w.dim() - 1)

    fused_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(scale_shape)).to(dtype=weight_dtype)
    fused_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=bias_dtype)

    return torch.nn.Parameter(fused_w), torch.nn.Parameter(fused_b)
|