[quant][fx][pyper] Get first linear use of quantize_per_tensor for FQN (#54859)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54859

This applies to the case where a call_function linear op is one of the users of a quantize op. In order to map the qparams of quantize_per_tensor to the qparams of the linear operator that consumes it, we need to use the FQN of the module containing the linear op for the qparams of quantize_per_tensor.

Test Plan: python test/test_quantization.py test_qparams_fqn

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27390505

fbshipit-source-id: a47af0e5ac016f2b2df74fbdf45afe99dc04be46
This commit is contained in:
parent c690ed0ae8
commit a7dc0ab845
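For context, here is a minimal sketch of the user-visible effect this change targets, distilled from the test added below: when torch.nn.functional.linear consumes a quantize_per_tensor node, the scale/zero_point attributes on the converted module are named after the FQN of the module containing the consuming linear call. The simplified model and the attribute names checked at the end are illustrative (the exact names follow the consuming linear's module path); the import paths assume the torch.quantization namespace of this release.

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.ones(5, 5)
        self.b = torch.zeros(5)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.w, self.b)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mods1 = torch.nn.Sequential(Linear(), Linear())

    def forward(self, x):
        return self.mods1(x)

# Only functional linear is quantized; the global qconfig is left unset.
qconfig_dict = {"": None, "object_type": [(torch.nn.functional.linear, default_qconfig)]}
m = prepare_fx(M().eval(), qconfig_dict)
m(torch.rand(5, 5))   # calibration
m = convert_fx(m)

# With this change, the qparams of the quantize_per_tensor feeding the first linear
# are registered under that linear module's FQN, e.g. mods1_0_input_scale_0.
for attr_name in ["mods1_0_input_scale_0", "mods1_0_input_zero_point_0"]:
    print(attr_name, hasattr(m, attr_name))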
@@ -2012,6 +2012,52 @@ class TestQuantizeFx(QuantizationTestCase):
         m = prepare_fx(m, qconfig_dict)
         m = convert_fx(m)
 
+    def test_qparams_fqn(self):
+        """ Test that the FQN of input_scale/zero_point is set
+        to that of first linear use. """
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(
+                    Linear(),
+                    Linear()
+                )
+
+            def forward(self, x):
+                x = torch.cat((x,), 1)
+                tmp = x.size()
+                x = self.mods1(x)
+                y = x * tmp[0]
+                return y
+
+        model = M().eval()
+        qconfig_dict = {
+            "": None,
+            "object_type": [
+                (torch.nn.functional.linear, default_qconfig),
+                (torch.nn.functional.relu, default_qconfig),
+            ],
+        }
+        m = prepare_fx(model, qconfig_dict)
+        m(torch.rand(5, 5))
+        m = convert_fx(m)
+        keys = m.state_dict().keys()
+        m(torch.randn(5, 5))
+        for attr_name in [
+                "mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
+                "mods1_0_scale_0", "mods1_0_zero_point_0",
+                "mods1_1_scale_0", "mods1_1_zero_point_0"]:
+            self.assertTrue(hasattr(m, attr_name))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -131,17 +131,25 @@ def quantize_node(quantizer, in_node, obs_module, obs_node, is_input):
     # Find the first use of the observer node, we use this to get the scope of the module.
     if is_input:
         # if the quantize function is at the input of op, then we find the first user of the observer_node
-        # to get the path
+        # to get the path. If a linear call_function is in the user list, we return the first instance
+        # of linear node to get the FQN.
         users = list(obs_node.users)
-        first_use = users[0] if users else None
+        first_linear_use_or_first_use = users[0] if users else None
+        linear_node = None
+        for n in users:
+            if n.op == "call_function" and n.target == torch.nn.functional.linear:
+                linear_node = n
+                break
+        if linear_node:
+            first_linear_use_or_first_use = linear_node
         prefix = "_input"
     else:
         # if the quantize function is at the output of the op, we use the observer input node to get the path
-        first_use = in_node
+        first_linear_use_or_first_use = in_node
         prefix = "_output"
 
-    if first_use:
-        module_path, _ = quantizer.node_name_to_scope[first_use.name]
+    if first_linear_use_or_first_use:
+        module_path, _ = quantizer.node_name_to_scope[first_linear_use_or_first_use.name]
     else:
         # TODO: it's not used, so actually we can skip quantization
         # but this requires changing return type of quantize_node
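The core of the hunk above is a "prefer the first linear user" lookup over a node's users. Below is a minimal standalone sketch of that lookup using torch.fx directly; the helper name first_linear_use_or_first_use and the Demo module are illustrative and not part of the patch.

import torch
import torch.fx

def first_linear_use_or_first_use(node: torch.fx.Node):
    """Return the first call_function linear among node's users,
    falling back to the first user (or None) otherwise."""
    users = list(node.users)
    for n in users:
        if n.op == "call_function" and n.target == torch.nn.functional.linear:
            return n
    return users[0] if users else None

class Demo(torch.nn.Module):
    def forward(self, x):
        w = torch.ones(5, 5)
        # x has two users: relu (first in graph order) and linear.
        return torch.relu(x) + torch.nn.functional.linear(x, w)

graph_module = torch.fx.symbolic_trace(Demo())
x_node = next(n for n in graph_module.graph.nodes if n.op == "placeholder")
print(first_linear_use_or_first_use(x_node))  # the linear node, even though relu is the first user

In quantize_node itself, the node examined is the observer node being replaced by quantize_per_tensor, and the chosen user's name is then looked up in quantizer.node_name_to_scope to obtain the module path that prefixes the input_scale/input_zero_point attributes, as exercised by test_qparams_fqn above.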