mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Align the conv+bn folding behavior with jit path for mixed type case: always keep conv's weight and bias dtype after folding. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99696 Approved by: https://github.com/jgong5, https://github.com/jansel
56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
|
|
|
|
import copy
|
|
import torch
|
|
|
|
def fuse_conv_bn_eval(conv, bn, transpose=False):
    """Return a copy of ``conv`` with the eval-mode batch norm ``bn`` folded in.

    Neither input module is modified: ``conv`` is deep-copied, and only the
    copy's ``weight`` and ``bias`` attributes are replaced. The new values come
    from ``fuse_conv_bn_weights`` applied to ``bn``'s running statistics,
    ``eps`` and affine parameters.

    Args:
        conv: the convolution module to fold into (must be in eval mode).
        bn: the batch-norm module to fold (must be in eval mode).
        transpose: forwarded to ``fuse_conv_bn_weights``. ``True`` broadcasts
            the per-channel factor along weight dim 1 instead of dim 0, which
            is the layout used for transposed convolutions.

    Returns:
        The deep-copied convolution module carrying the folded parameters.

    Raises:
        AssertionError: if either module is in training mode (the check is
            skipped when Python runs with ``-O``).
    """
    # Folding relies on frozen running statistics, so only eval mode is valid.
    assert not conv.training and not bn.training, "Fusion only for eval!"

    folded = copy.deepcopy(conv)
    new_weight, new_bias = fuse_conv_bn_weights(
        folded.weight,
        folded.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose=transpose,
    )
    folded.weight = new_weight
    folded.bias = new_bias
    return folded
|
|
|
|
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    """Fold batch-norm parameters into convolution weight and bias tensors.

    With ``inv_std = rsqrt(bn_rv + bn_eps)`` the folded tensors are::

        weight = conv_w * (bn_w * inv_std)            # broadcast per channel
        bias   = (conv_b - bn_rm) * inv_std * bn_w + bn_b

    Args:
        conv_w: convolution weight tensor.
        conv_b: convolution bias, or ``None`` (treated as ``zeros_like(bn_rm)``).
        bn_rm: batch-norm running mean.
        bn_rv: batch-norm running variance.
        bn_eps: value added to ``bn_rv`` before the reciprocal square root.
        bn_w: batch-norm affine weight, or ``None`` (treated as ones).
        bn_b: batch-norm affine bias, or ``None`` (treated as zeros).
        transpose: when ``True`` the per-channel factor is broadcast along
            dim 1 of ``conv_w`` instead of dim 0.

    Returns:
        A ``(weight, bias)`` pair of new ``torch.nn.Parameter`` objects.
        - The weight keeps ``conv_w``'s dtype and ``requires_grad``.
        - The bias keeps ``conv_b``'s dtype, or ``conv_w``'s dtype when
          ``conv_b`` is ``None``.
        - The bias keeps ``conv_b``'s ``requires_grad``. When ``conv_b`` is
          ``None`` this is ``False``, because the substituted zeros do not
          require grad.
    """
    # Capture the target dtypes first. The BN tensors may have a different
    # dtype, and type promotion would otherwise leak into the results.
    target_w_dtype = conv_w.dtype
    target_b_dtype = target_w_dtype if conv_b is None else conv_b.dtype

    # Substitute identity values for any optional tensor that is absent.
    conv_b = torch.zeros_like(bn_rm) if conv_b is None else conv_b
    bn_w = torch.ones_like(bn_rm) if bn_w is None else bn_w
    bn_b = torch.zeros_like(bn_rm) if bn_b is None else bn_b

    inv_std = torch.rsqrt(bn_rv + bn_eps)

    # Place the channel axis at dim 0 (or dim 1 for transposed convs) and give
    # every remaining weight dim size 1 so the factor broadcasts.
    channel_axes = [1, -1] if transpose else [-1, 1]
    scale_shape = channel_axes + [1] * (conv_w.dim() - 2)

    new_w = conv_w * (bn_w * inv_std).reshape(scale_shape)
    new_b = (conv_b - bn_rm) * inv_std * bn_w + bn_b

    return (
        torch.nn.Parameter(new_w.to(dtype=target_w_dtype), conv_w.requires_grad),
        torch.nn.Parameter(new_b.to(dtype=target_b_dtype), conv_b.requires_grad),
    )
|
|
|
|
def fuse_linear_bn_eval(linear, bn):
    """Return a copy of ``linear`` with the eval-mode batch norm ``bn`` folded in.

    Neither input module is modified: ``linear`` is deep-copied, and only the
    copy's ``weight`` and ``bias`` are replaced. The new values come from
    ``fuse_linear_bn_weights`` applied to ``bn``'s running statistics, ``eps``
    and affine parameters.

    The folding broadcasts BN's per-feature statistics against the linear
    layer's output features (the rows of its weight). So ``bn.num_features``
    must equal ``linear.out_features`` or be 1. Any other pairing would either
    fail to broadcast or silently yield a weight of the wrong shape, so it is
    rejected before any work is done.

    Args:
        linear: the linear module to fold into (must be in eval mode).
        bn: the batch-norm module to fold (must be in eval mode).

    Returns:
        The deep-copied linear module carrying the folded parameters.

    Raises:
        AssertionError: if either module is in training mode, or if the feature
            counts are incompatible (checks are skipped under ``python -O``).
    """
    assert(not (linear.training or bn.training)), "Fusion only for eval!"
    assert linear.out_features == bn.num_features or bn.num_features == 1, \
        "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"

    fused_linear = copy.deepcopy(linear)

    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight, fused_linear.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_linear
|
|
|
|
def fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold batch-norm parameters into linear-layer weight and bias tensors.

    With ``scale = bn_w * rsqrt(bn_rv + bn_eps)`` the folded tensors are::

        weight = linear_w * scale[:, None]    # one factor per output row
        bias   = (linear_b - bn_rm) * scale + bn_b

    Optional inputs are handled the same way as in ``fuse_conv_bn_weights``:
    ``bn_w=None`` means ones and ``bn_b=None`` means zeros. Previously a
    ``None`` ``bn_w``/``bn_b`` (affine=False batch norm) raised a ``TypeError``.

    Args:
        linear_w: linear weight, ``(out_features, in_features)``.
        linear_b: linear bias, or ``None`` (treated as zeros, one per row of
            ``linear_w``).
        bn_rm: batch-norm running mean.
        bn_rv: batch-norm running variance.
        bn_eps: value added to ``bn_rv`` before the reciprocal square root.
        bn_w: batch-norm affine weight, or ``None``.
        bn_b: batch-norm affine bias, or ``None``.

    Returns:
        A ``(weight, bias)`` pair of new ``torch.nn.Parameter`` objects.
        - The weight keeps ``linear_w``'s dtype and ``requires_grad``.
        - The bias keeps ``linear_b``'s dtype, or ``linear_w``'s dtype when
          ``linear_b`` is ``None``.
        - The bias keeps ``linear_b``'s ``requires_grad``. When ``linear_b`` is
          ``None`` this is ``False``, because the substituted zeros do not
          require grad.
    """
    # Keep the layer's own dtypes (e.g. bf16 weights with fp32 BN stats) so the
    # fused module matches the original, as fuse_conv_bn_weights does.
    linear_weight_dtype = linear_w.dtype
    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
    if linear_b is None:
        # One zero per output feature, so the fused bias has the layer's bias
        # shape even when BN has a single feature that merely broadcasts.
        linear_b = linear_w.new_zeros(linear_w.shape[0])
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)

    fused_w = (linear_w * bn_scale.unsqueeze(-1)).to(dtype=linear_weight_dtype)
    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)

    return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(fused_b, linear_b.requires_grad)
|