Minor fix and turn off fold_convbn (#27403)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27403

In the fold_convbn pass we need to recompute the conv parameters (weight, bias), update the corresponding attributes on the conv module, and update the access to bias in conv. If the original conv has no bias, the `self.bias` access is inlined and replaced by the constant node `None = prim::Constant()`, so we need to change it to use `GetAttr[name="bias"]` for the fold to work. There is also ongoing work on handling constants, so we'll fix this pass after that is done.

Test Plan: Imported from OSS

Differential Revision: D18182918

fbshipit-source-id: bba510bc41ab58e0eb76f7b77335b6e3ffe2862d
parent d690521cf6
commit 5ac3df7712
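For context, the recomputation the summary refers to is the standard Conv+BatchNorm folding. The sketch below is a hypothetical eager-mode equivalent of that math, not the actual C++ helper `computeUpdatedConvWeightAndBias`; the function name and argument names are illustrative.

import torch

def fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn_w / torch.sqrt(bn_rv + eps)
    # New conv weight: scale each output-channel slice of the kernel
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
    # Treat a missing conv bias as zero -- exactly the case this PR has to patch up
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    # New conv bias folds in the BN shift
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b

With these fused values, a conv using `fused_w`/`fused_b` matches `bn(conv(x))` up to floating-point error, which is what the skipped tests below verify with `assertAlmostEqual`.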
@@ -1301,6 +1301,11 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
         FileCheck().run(input_str, graph)
 
     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+            constants are handled properly, this test should not be passing \
+            because bias is not handled properly, the reason it passes is because the \
+            parameters of bn are initialized to default values and the recomputed bias \
+            for conv is zero, which is equivalent to no bias")
     def test_foldbn_trivial(self):
         # Test trivial case
         class TestModule(torch.nn.Module):
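The skip message's reasoning can be checked directly: a freshly constructed BatchNorm2d has weight=1, bias=0, running_mean=0 and running_var=1, so the bias folded from a bias-free conv comes out exactly zero and the missing-bias bug has no observable effect. A minimal standalone check (illustrative only, not part of the test suite):

import torch

bn = torch.nn.BatchNorm2d(4)  # default init: weight=1, bias=0, mean=0, var=1
with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # The conv under test has no bias, so the folded bias starts from zero
    fused_bias = (torch.zeros(4) - bn.running_mean) * scale + bn.bias
print(fused_bias)  # tensor([0., 0., 0., 0.]) -- equivalent to no bias at all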
@@ -1338,6 +1343,8 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
         self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)
 
     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+            constants are handled properly")
     def test_foldbn_trivial_nobias(self):
         # Test trivial case
         class TestModule(torch.nn.Module):
@@ -1375,6 +1382,8 @@ graph(%packed_params_module, %a, %a_scale, %a_zero_point, %a_dtype, %r_scale, %r
         self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)
 
     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+            constants are handled properly")
     def test_foldbn_in_submodule(self):
         # Test that we find Conv-BN patterns in submodules
         class SubModule(torch.nn.Module):
@@ -883,8 +883,12 @@ graph(%self, %x):
         GRAPH_UPDATE("Deleting ", *matched_bn);
 
         auto new_w_b = computeUpdatedConvWeightAndBias(params);
-        params.conv_w.set_data(std::get<0>(new_w_b));
-        params.conv_b.set_data(std::get<1>(new_w_b));
+        conv_submodule.set_parameter("weight", std::get<0>(new_w_b));
+        if (conv_submodule.find_parameter("bias")) {
+          conv_submodule.set_parameter("bias", std::get<1>(new_w_b));
+        } else {
+          conv_submodule.register_parameter("bias", std::get<1>(new_w_b), false);
+        }
       }
 
       // Perform planned rewritings
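The C++ hunk above writes the folded values back onto the conv submodule, registering a brand-new `bias` parameter when the conv was created without one (the case where the inlined `self.bias` access had become a `None` constant). A rough eager-mode analogue of that branch, purely for illustration; the tensors are stand-ins for the folded values:

import torch

conv = torch.nn.Conv2d(3, 3, 3, bias=False)   # conv built without a bias
fused_w = conv.weight.detach().clone()        # stand-in for the folded weight
fused_b = torch.zeros(conv.out_channels)      # stand-in for the folded bias

conv.weight.data.copy_(fused_w)               # "set_parameter" on an existing param
if conv.bias is not None:
    conv.bias.data.copy_(fused_b)
else:
    # "register_parameter": the module gains a real bias attribute, so a scripted
    # graph can reach it through GetAttr[name="bias"] instead of a None constant
    conv.register_parameter("bias", torch.nn.Parameter(fused_b, requires_grad=False))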
@@ -132,7 +132,11 @@ def quantize_script(model, qconfig_dict, run_fn, run_args, inplace=False):
     if not inplace:
         model = model.copy()
     scripted_qconfig_dict = {k: script_qconfig(v) for k, v in qconfig_dict.items()}
-    torch._C._jit_pass_fold_convbn(model._c)
+    # We are not going to run the fold_convbn pass right now
+    # since it is not able to work correctly; we will
+    # revisit after constants are properly handled in
+    # the JIT.
+    # torch._C._jit_pass_fold_convbn(model._c)
     prepare_script(model, scripted_qconfig_dict, True)
     run_fn(model._c._get_method('forward'), *run_args)
     # When we mutating graph we didn't create a new ClassType
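Per the summary, the fold is only disabled inside `quantize_script`; the pass itself remains registered, so it can in principle still be invoked directly on a scripted module's underlying C++ module, as the (now skipped) tests do. A hypothetical direct call, for reference only:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

scripted = torch.jit.script(M())
# Runs the (currently known-broken) fold_convbn pass on the scripted module
torch._C._jit_pass_fold_convbn(scripted._c)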