mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
ns for fx: allow comparing int8 to int8 for functionals (#56742)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56742 Fixes a bug to allow shadowing of linear and conv functionals. The bug is to only detach tensors, not all objects. Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_int8_shadows_int8_fun ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D27960767 fbshipit-source-id: abc911ca4b9edafd1effb9dada7731981538c2df
This commit is contained in:
parent
92c7aec5f5
commit
502c58ad84
|
|
@ -1119,12 +1119,10 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
|
|||
m, (torch.randn(4, 4),),
|
||||
results_len=2)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_int8_shadows_int8(self):
|
||||
def _test_int8_shadows_int8_impl(self, m):
|
||||
"""
|
||||
Verify that shadowing works where both modules are int8
|
||||
"""
|
||||
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
|
||||
qconfig_dict = {'': torch.quantization.default_qconfig}
|
||||
mp = prepare_fx(m, qconfig_dict)
|
||||
mp(torch.randn(4, 1, 4, 4))
|
||||
|
|
@ -1137,6 +1135,16 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
|
|||
self.assertTrue(len(act_compare_dict) == 1)
|
||||
self.assert_ns_compare_dict_valid(act_compare_dict)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_int8_shadows_int8_mod(self):
|
||||
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
|
||||
self._test_int8_shadows_int8_impl(m)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_int8_shadows_int8_fun(self):
|
||||
m = LinearFunctional().eval()
|
||||
self._test_int8_shadows_int8_impl(m)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_user_module_scriptable(self):
|
||||
# Logging of the output of this class is not supported, because it is
|
||||
|
|
|
|||
|
|
@ -245,7 +245,9 @@ def _copy_node_from_a_to_c(
|
|||
node_a_copy_name = \
|
||||
get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b) # type: ignore
|
||||
node_a_obj = getattr_from_fqn(gm_a, node_a.target) # type: ignore
|
||||
setattr(gm_b, node_a_copy_name, node_a_obj.detach())
|
||||
if torch.is_tensor(node_a_obj):
|
||||
node_a_obj = node_a_obj.detach()
|
||||
setattr(gm_b, node_a_copy_name, node_a_obj)
|
||||
node_a_copy = graph_c.create_node(
|
||||
node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
|
||||
return node_a_copy
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user