Minor fix and turn off fold_convbn (#27403)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27403

In fold_convbn pass, we need to recompute the parameter(weight, bias) for
conv, update the attribute of conv and update the access of bias in conv
because if the original conv have no bias, the `self.bias` access will be
inline and replaced by Constant node `None = prim::Constant()`, we need to
update this to use `GetAttr[name="bias"]` to make this work. But there is
also some work going on the handle constants, so we'll fix this pass after
that is done.

Test Plan:
.

Imported from OSS

Differential Revision: D18182918

fbshipit-source-id: bba510bc41ab58e0eb76f7b77335b6e3ffe2862d
This commit is contained in:
Jerry Zhang 2019-11-01 12:13:18 -07:00 committed by Facebook Github Bot
parent d690521cf6
commit 5ac3df7712
3 changed files with 20 additions and 3 deletions

View File

@ -1301,6 +1301,11 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
FileCheck().run(input_str, graph)
@_tmp_donotuse_dont_inline_everything
@unittest.skip("Temporarily turn off fold_convbn tests until \
constants are handled properly, this test should not be passing \
because bias is not handled properly, the reason is passes is because the \
parameters of bn are initialized to default values and the recomputed bias \
for conv is zero, which is equivalent to no bias")
def test_foldbn_trivial(self):
# Test trivial case
class TestModule(torch.nn.Module):
@ -1338,6 +1343,8 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)
@_tmp_donotuse_dont_inline_everything
@unittest.skip("Temporarily turn off fold_convbn tests until \
constants are handled properly")
def test_foldbn_trivial_nobias(self):
# Test trivial case
class TestModule(torch.nn.Module):
@ -1375,6 +1382,8 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)
@_tmp_donotuse_dont_inline_everything
@unittest.skip("Temporarily turn off fold_convbn tests until \
constants are handled properly")
def test_foldbn_in_submodule(self):
# Test that we find Conv-BN patterns in submodules
class SubModule(torch.nn.Module):

View File

@ -883,8 +883,12 @@ graph(%self, %x):
GRAPH_UPDATE("Deleting ", *matched_bn);
auto new_w_b = computeUpdatedConvWeightAndBias(params);
params.conv_w.set_data(std::get<0>(new_w_b));
params.conv_b.set_data(std::get<1>(new_w_b));
conv_submodule.set_parameter("weight", std::get<0>(new_w_b));
if (conv_submodule.find_parameter("bias")) {
conv_submodule.set_parameter("bias", std::get<1>(new_w_b));
} else {
conv_submodule.register_parameter("bias", std::get<1>(new_w_b), false);
}
}
// Perform planned rewritings

View File

@ -132,7 +132,11 @@ def quantize_script(model, qconfig_dict, run_fn, run_args, inplace=False):
if not inplace:
model = model.copy()
scripted_qconfig_dict = {k: script_qconfig(v) for k, v in qconfig_dict.items()}
torch._C._jit_pass_fold_convbn(model._c)
# We are not going to run fold_convbn pass right now
# since it is not able to work correctly, we will
# revisit after constants is properly handled in
# JIT
# torch._C._jit_pass_fold_convbn(model._c)
prepare_script(model, scripted_qconfig_dict, True)
run_fn(model._c._get_method('forward'), *run_args)
# When we mutating graph we didn't create a new ClassType