diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index a5b02c03cc9..331fd38f9af 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -1606,13 +1606,14 @@ class TestFrozenOptimizations(JitTestCase): conv_bias = [True, False] module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)] use_tracing = [True, False] + bn_running_stats = [True, False] - for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing): + for use_bias, modules, tracing, track_stats in product(conv_bias, module_pairs, use_tracing, bn_running_stats): class ConvBN(torch.nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super(ConvBN, self).__init__() self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs) - self.bn = modules[1](out_channels, eps=0.001) + self.bn = modules[1](out_channels, eps=0.001, track_running_stats=track_stats) def forward(self, x): x = self.conv(x) @@ -1644,7 +1645,10 @@ class TestFrozenOptimizations(JitTestCase): scripted_mod = torch.jit.freeze(scripted_mod) self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) - FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) + if track_stats: + FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) + else: + FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) self.assertEqual(mod_eager(inp), scripted_mod(inp)) self.assertEqual(mod_eager(inp), scripted_mod(inp)) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 75283ad6c9d..a35c94a3dc8 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -667,6 +667,30 @@ class TestFXExperimental(JitTestCase): self.assertEqual(fused(inp), rn18(inp)) + def test_conv_bn_fusion_not_running_state(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) + self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + model = M().eval() + + traced = symbolic_trace(model) + fused = optimization.fuse(traced) + inp = torch.randn([1, 32, 50, 50]) + + # bn need not be folded in conv + self.assertTrue( + any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) + ) + self.assertEqual(fused(inp), model(inp)) + def test_call_to_assert_no_msg(self): class M(torch.nn.Module): def forward(self, a, b): diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp index 0aa674e1a8c..4b5535af9ab 100644 --- a/torch/csrc/jit/passes/frozen_conv_folding.cpp +++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp @@ -62,6 +62,15 @@ void FoldFrozenConvBatchnorm(Block* b) { continue; } + auto bn_rm_ivalue = bn->namedInput("running_mean"); + auto bn_rv_ivalue = bn->namedInput("running_var"); + // check running_mean and running_var has value, if they are + // None(track_running_stats=False), skiping the folding path. + if (bn_rm_ivalue->type() == NoneType::get() && + bn_rv_ivalue->type() == NoneType::get()) { + continue; + } + auto bn_rm = constant_as(bn->namedInput("running_mean")).value(); auto bn_rv = constant_as(bn->namedInput("running_var")).value(); auto bn_eps = constant_as(bn->namedInput("eps")).value(); diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 7016556e396..595dbfa4308 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -68,6 +68,8 @@ def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module: continue conv = modules[node.args[0].target] bn = modules[node.target] + if not bn.track_running_stats: + continue fused_conv = fuse_conv_bn_eval(conv, bn) replace_node_module(node.args[0], modules, fused_conv) node.replace_all_uses_with(node.args[0])