fix conv+bn folding issue when bn hasn't running states (#71259)

Summary:
Doing conv+bn folding which bn hasn't a running stats, there have error for JIT and FX path:

```
import torch

import torch.nn as nn

import torch.fx.experimental.optimization as optimization

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv2d(32, 64, 3, stride=2)
        self.bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

x = torch.randn([1, 32, 50, 50])

model = M().eval()

'''
# jit path
with torch.no_grad():
    traced = torch.jit.trace(model, x).eval()
    traced = torch.jit.freeze(traced)
'''

# FX path
fused_model = optimization.fuse(model)
```

expected result:
1. JIT path
```
Traceback (most recent call last):
  File "bn_test.py", line 27, in <module>
    traced = torch.jit.freeze(traced)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 119, in freeze
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 167, in run_frozen_optimizations
    torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)
RuntimeError: Expected Tensor but got None
```
2. FX path
```
Traceback (most recent call last):
  File "bn_test.py", line 31, in <module>
    model = optimization.fuse(model, inplace=True)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/fx/experimental/optimization.py", line 71, in fuse
    fused_conv = fuse_conv_bn_eval(conv, bn)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 11, in fuse_conv_bn_eval
    fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 23, in fuse_conv_bn_weights
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'
```

This PR will fix this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71259

Reviewed By: anjali411

Differential Revision: D33595049

Pulled By: davidberard98

fbshipit-source-id: 0fe56bb2bb25d6d54ebc53789d2ad22458da9012
(cherry picked from commit 5672c08378)
This commit is contained in:
XiaobingSuper 2022-01-18 13:46:01 -08:00 committed by PyTorch MergeBot
parent a986154950
commit b8679ee1fc
4 changed files with 42 additions and 3 deletions

View File

@ -1606,13 +1606,14 @@ class TestFrozenOptimizations(JitTestCase):
conv_bias = [True, False]
module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)]
use_tracing = [True, False]
bn_running_stats = [True, False]
for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing):
for use_bias, modules, tracing, track_stats in product(conv_bias, module_pairs, use_tracing, bn_running_stats):
class ConvBN(torch.nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(ConvBN, self).__init__()
self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs)
self.bn = modules[1](out_channels, eps=0.001)
self.bn = modules[1](out_channels, eps=0.001, track_running_stats=track_stats)
def forward(self, x):
x = self.conv(x)
@ -1644,7 +1645,10 @@ class TestFrozenOptimizations(JitTestCase):
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
if track_stats:
FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
else:
FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)
self.assertEqual(mod_eager(inp), scripted_mod(inp))
self.assertEqual(mod_eager(inp), scripted_mod(inp))

View File

@ -667,6 +667,30 @@ class TestFXExperimental(JitTestCase):
self.assertEqual(fused(inp), rn18(inp))
def test_conv_bn_fusion_not_running_state(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
model = M().eval()
traced = symbolic_trace(model)
fused = optimization.fuse(traced)
inp = torch.randn([1, 32, 50, 50])
# bn need not be folded in conv
self.assertTrue(
any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
)
self.assertEqual(fused(inp), model(inp))
def test_call_to_assert_no_msg(self):
class M(torch.nn.Module):
def forward(self, a, b):

View File

@ -62,6 +62,15 @@ void FoldFrozenConvBatchnorm(Block* b) {
continue;
}
auto bn_rm_ivalue = bn->namedInput("running_mean");
auto bn_rv_ivalue = bn->namedInput("running_var");
// check running_mean and running_var has value, if they are
// None(track_running_stats=False), skiping the folding path.
if (bn_rm_ivalue->type() == NoneType::get() &&
bn_rv_ivalue->type() == NoneType::get()) {
continue;
}
auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();

View File

@ -68,6 +68,8 @@ def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
if not bn.track_running_stats:
continue
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])