From b8679ee1fc53f3931e3fd415276f2854005fb46b Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Tue, 18 Jan 2022 13:46:01 -0800 Subject: [PATCH] fix conv+bn folding issue when bn hasn't running states (#71259) Summary: Doing conv+bn folding which bn hasn't a running stats, there have error for JIT and FX path: ``` import torch import torch.nn as nn import torch.fx.experimental.optimization as optimization class M(nn.Module): def __init__(self): super(M, self).__init__() self.conv = nn.Conv2d(32, 64, 3, stride=2) self.bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) def forward(self, x): x = self.conv(x) x = self.bn(x) return x x = torch.randn([1, 32, 50, 50]) model = M().eval() ''' # jit path with torch.no_grad(): traced = torch.jit.trace(model, x).eval() traced = torch.jit.freeze(traced) ''' # FX path fused_model = optimization.fuse(model) ``` expected result: 1. JIT path ``` Traceback (most recent call last): File "bn_test.py", line 27, in traced = torch.jit.freeze(traced) File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 119, in freeze run_frozen_optimizations(out, optimize_numerics, preserved_methods) File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 167, in run_frozen_optimizations torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics) RuntimeError: Expected Tensor but got None ``` 2. FX path ``` Traceback (most recent call last): File "bn_test.py", line 31, in model = optimization.fuse(model, inplace=True) File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/fx/experimental/optimization.py", line 71, in fuse fused_conv = fuse_conv_bn_eval(conv, bn) File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 11, in fuse_conv_bn_eval fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 23, in fuse_conv_bn_weights bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) TypeError: unsupported operand type(s) for +: 'NoneType' and 'float' ``` This PR will fix this issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71259 Reviewed By: anjali411 Differential Revision: D33595049 Pulled By: davidberard98 fbshipit-source-id: 0fe56bb2bb25d6d54ebc53789d2ad22458da9012 (cherry picked from commit 5672c083784585e6e1ec5657f02bd3051afb2b50) --- test/jit/test_freezing.py | 10 +++++--- test/test_fx_experimental.py | 24 +++++++++++++++++++ torch/csrc/jit/passes/frozen_conv_folding.cpp | 9 +++++++ torch/fx/experimental/optimization.py | 2 ++ 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index a5b02c03cc9..331fd38f9af 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -1606,13 +1606,14 @@ class TestFrozenOptimizations(JitTestCase): conv_bias = [True, False] module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)] use_tracing = [True, False] + bn_running_stats = [True, False] - for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing): + for use_bias, modules, tracing, track_stats in product(conv_bias, module_pairs, use_tracing, bn_running_stats): class ConvBN(torch.nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super(ConvBN, self).__init__() self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs) - self.bn = modules[1](out_channels, eps=0.001) + self.bn = modules[1](out_channels, eps=0.001, track_running_stats=track_stats) def forward(self, x): x = self.conv(x) @@ -1644,7 +1645,10 @@ class TestFrozenOptimizations(JitTestCase): scripted_mod = torch.jit.freeze(scripted_mod) self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) - FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) + if track_stats: + FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) + else: + FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) self.assertEqual(mod_eager(inp), scripted_mod(inp)) self.assertEqual(mod_eager(inp), scripted_mod(inp)) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 75283ad6c9d..a35c94a3dc8 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -667,6 +667,30 @@ class TestFXExperimental(JitTestCase): self.assertEqual(fused(inp), rn18(inp)) + def test_conv_bn_fusion_not_running_state(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) + self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + model = M().eval() + + traced = symbolic_trace(model) + fused = optimization.fuse(traced) + inp = torch.randn([1, 32, 50, 50]) + + # bn need not be folded in conv + self.assertTrue( + any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) + ) + self.assertEqual(fused(inp), model(inp)) + def test_call_to_assert_no_msg(self): class M(torch.nn.Module): def forward(self, a, b): diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp index 0aa674e1a8c..4b5535af9ab 100644 --- a/torch/csrc/jit/passes/frozen_conv_folding.cpp +++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp @@ -62,6 +62,15 @@ void FoldFrozenConvBatchnorm(Block* b) { continue; } + auto bn_rm_ivalue = bn->namedInput("running_mean"); + auto bn_rv_ivalue = bn->namedInput("running_var"); + // check running_mean and running_var has value, if they are + // None(track_running_stats=False), skiping the folding path. + if (bn_rm_ivalue->type() == NoneType::get() && + bn_rv_ivalue->type() == NoneType::get()) { + continue; + } + auto bn_rm = constant_as(bn->namedInput("running_mean")).value(); auto bn_rv = constant_as(bn->namedInput("running_var")).value(); auto bn_eps = constant_as(bn->namedInput("eps")).value(); diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 7016556e396..595dbfa4308 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -68,6 +68,8 @@ def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module: continue conv = modules[node.args[0].target] bn = modules[node.target] + if not bn.track_running_stats: + continue fused_conv = fuse_conv_bn_eval(conv, bn) replace_node_module(node.args[0], modules, fused_conv) node.replace_all_uses_with(node.args[0])