[quant][fx][pyper] Get first linear use of quantize_per_tensor for FQN (#54859)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54859

This is applicable to the case where a call_function linear op is one of the users of a quantize op.
In order to map the qparams of quantize_per_tensor to the qparams of the linear operator
that consumes it, we need to use the FQN of the module containing the linear op for the qparams of quantize_per_tensor.
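
As a rough standalone sketch of the first-use selection (illustrative only; the toy module below is not from this PR, and assumes torch.fx is available):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.ones(5, 5)

        def forward(self, x):
            # x has two users: a sum (created first) and a linear call_function
            return x.sum() + torch.nn.functional.linear(x, self.w)

    gm = torch.fx.symbolic_trace(M())
    x_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
    users = list(x_node.users)
    linear_node = next(
        (n for n in users
         if n.op == "call_function" and n.target == torch.nn.functional.linear),
        None)
    # prefer the linear user over the plain first user when present
    first_linear_use_or_first_use = linear_node or (users[0] if users else None)
    print(first_linear_use_or_first_use.name)  # the linear node, not the sum

Here the linear user wins even though another user of x comes first, which is what lets the input qparams pick up the FQN of the module with the linear op.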

Test Plan:
python test/test_quantization.py test_qparams_fqn

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27390505

fbshipit-source-id: a47af0e5ac016f2b2df74fbdf45afe99dc04be46
Author: Supriya Rao
Date: 2021-03-30 08:30:18 -07:00
Committed by: Facebook GitHub Bot
Parent: c690ed0ae8
Commit: a7dc0ab845
2 changed files with 59 additions and 5 deletions


@@ -2012,6 +2012,52 @@ class TestQuantizeFx(QuantizationTestCase):
         m = prepare_fx(m, qconfig_dict)
         m = convert_fx(m)
 
+    def test_qparams_fqn(self):
+        """ Test that the FQN of input_scale/zero_point is set
+        to that of first linear use. """
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(
+                    Linear(),
+                    Linear()
+                )
+
+            def forward(self, x):
+                x = torch.cat((x,), 1)
+                tmp = x.size()
+                x = self.mods1(x)
+                y = x * tmp[0]
+                return y
+
+        model = M().eval()
+        qconfig_dict = {
+            "": None,
+            "object_type": [
+                (torch.nn.functional.linear, default_qconfig),
+                (torch.nn.functional.relu, default_qconfig),
+            ],
+        }
+        m = prepare_fx(model, qconfig_dict)
+        m(torch.rand(5, 5))
+        m = convert_fx(m)
+        keys = m.state_dict().keys()
+        m(torch.randn(5, 5))
+        for attr_name in [
+                "mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
+                "mods1_0_scale_0", "mods1_0_zero_point_0",
+                "mods1_1_scale_0", "mods1_1_zero_point_0"]:
+            self.assertTrue(hasattr(m, attr_name))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops


@@ -131,17 +131,25 @@ def quantize_node(quantizer, in_node, obs_module, obs_node, is_input):
     # Find the first use of the observer node, we use this to get the scope of the module.
     if is_input:
         # if the quantize function is at the input of op, then we find the first user of the observer_node
-        # to get the path
+        # to get the path. If a linear call_function is in the user list, we return the first instance
+        # of linear node to get the FQN.
         users = list(obs_node.users)
-        first_use = users[0] if users else None
+        first_linear_use_or_first_use = users[0] if users else None
+        linear_node = None
+        for n in users:
+            if n.op == "call_function" and n.target == torch.nn.functional.linear:
+                linear_node = n
+                break
+        if linear_node:
+            first_linear_use_or_first_use = linear_node
         prefix = "_input"
     else:
         # if the quantize function is at the output of the op, we use the observer input node to get the path
-        first_use = in_node
+        first_linear_use_or_first_use = in_node
         prefix = "_output"
-    if first_use:
-        module_path, _ = quantizer.node_name_to_scope[first_use.name]
+    if first_linear_use_or_first_use:
+        module_path, _ = quantizer.node_name_to_scope[first_linear_use_or_first_use.name]
     else:
         # TODO: it's not used, so actually we can skip quantization
         # but this requires changing return type of quantize_node