ns for fx: allow comparing int8 to int8 for functionals (#56742)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56742

Fixes a bug that prevented shadowing of linear and conv functionals.
The fix is to only detach tensors, not all objects, when copying node
attributes: functional nodes can reference non-tensor attributes
(e.g. packed parameters), which do not support `.detach()`.
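
As a rough standalone illustration of the failure mode (a hypothetical repro, not code from this PR): an unconditional `.detach()` works for tensor attributes such as weights, but raises as soon as the copied attribute is any other kind of object.

```
import torch

weight = torch.randn(4, 4, requires_grad=True)
weight.detach()  # fine: tensors support detach()

# stand-in for a non-tensor attribute (e.g. the packed params of a
# quantized functional); plain objects have no detach() method
packed = object()
try:
    packed.detach()  # the old, unconditional behavior
except AttributeError as e:
    print(e)  # 'object' object has no attribute 'detach'
```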

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_int8_shadows_int8_fun
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27960767

fbshipit-source-id: abc911ca4b9edafd1effb9dada7731981538c2df
Author: Vasiliy Kuznetsov (committed by Facebook GitHub Bot)
Date: 2021-04-26 16:58:35 -07:00
Commit: 502c58ad84 (parent: 92c7aec5f5)
2 changed files with 14 additions and 4 deletions


```
@@ -1119,12 +1119,10 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
             m, (torch.randn(4, 4),),
             results_len=2)
 
-    @skipIfNoFBGEMM
-    def test_int8_shadows_int8(self):
+    def _test_int8_shadows_int8_impl(self, m):
         """
         Verify that shadowing works where both modules are int8
         """
-        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
         qconfig_dict = {'': torch.quantization.default_qconfig}
         mp = prepare_fx(m, qconfig_dict)
         mp(torch.randn(4, 1, 4, 4))
@@ -1137,6 +1135,16 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         self.assertTrue(len(act_compare_dict) == 1)
         self.assert_ns_compare_dict_valid(act_compare_dict)
 
+    @skipIfNoFBGEMM
+    def test_int8_shadows_int8_mod(self):
+        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
+        self._test_int8_shadows_int8_impl(m)
+
+    @skipIfNoFBGEMM
+    def test_int8_shadows_int8_fun(self):
+        m = LinearFunctional().eval()
+        self._test_int8_shadows_int8_impl(m)
+
     @skipIfNoFBGEMM
     def test_user_module_scriptable(self):
         # Logging of the output of this class is not supported, because it is
```
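
`LinearFunctional` is defined elsewhere in the test file; a minimal sketch of what such a module plausibly looks like (the parameter names and shapes here are assumptions, not copied from the test suite) is a module that calls `F.linear` directly instead of using `nn.Linear`, so that FX graph mode quantization produces a quantized functional call rather than a quantized module:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFunctional(nn.Module):
    """Calls F.linear directly, so FX quantization emits a functional op."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.bias = nn.Parameter(torch.zeros(4))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
```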


```
@@ -245,7 +245,9 @@ def _copy_node_from_a_to_c(
     node_a_copy_name = \
         get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)  # type: ignore
     node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore
-    setattr(gm_b, node_a_copy_name, node_a_obj.detach())
+    if torch.is_tensor(node_a_obj):
+        node_a_obj = node_a_obj.detach()
+    setattr(gm_b, node_a_copy_name, node_a_obj)
     node_a_copy = graph_c.create_node(
         node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
     return node_a_copy
```
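
The same guard, extracted into a self-contained sketch (the `copy_attr_value` helper name is invented for illustration): tensors are detached so the shadow copy carries no autograd history, and every other object is passed through unchanged.

```
import torch

def copy_attr_value(obj):
    # Mirrors the fixed logic above: detach only if the value is a tensor.
    if torch.is_tensor(obj):
        obj = obj.detach()
    return obj

t = torch.randn(2, 2, requires_grad=True)
print(copy_attr_value(t).requires_grad)       # False: detached copy
print(copy_attr_value("packed-params stub"))  # non-tensors pass through as-is
```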