mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][be] Add TestPT2ERepresentation test case (#108923)
Summary: as titled (adds a dedicated TestPT2ERepresentation test case).
Test Plan: python test/test_quantization.py TestPT2ERepresentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108923
Approved by: https://github.com/andrewor14
This commit is contained in:
parent 064ae9ff33
commit c914ca7577
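For context, here is the flow the new test case exercises, as a minimal sketch assembled from the calls that appear in the diff below. The toy Linear module and its inputs are illustrative, not part of the commit:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

model = torch.nn.Linear(5, 5).eval()  # illustrative toy model
example_inputs = (torch.randn(2, 5),)

# capture the program, annotate it, calibrate, then lower with the
# reference representation (q/dq decomposed into integer arithmetic)
model = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
model = prepare_pt2e(model, quantizer)
model(*example_inputs)  # calibration pass
model = convert_pt2e(model, use_reference_representation=True)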
test/quantization/pt2e/test_quantize_pt2e.py
@@ -1,7 +1,7 @@
# Owner(s): ["oncall: quantization"]
import copy
import operator
from typing import Any, List, Optional, Tuple, Dict
from typing import Any, List, Optional, Tuple

import torch
import torch._dynamo as torchdynamo
@@ -72,177 +72,14 @@ from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
    TestHelperModules,
)
from torch.ao.quantization import (
    default_dynamic_qconfig,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch._higher_order_ops.out_dtype import out_dtype  # noqa: F401
from torch._export import dynamic_dim

import unittest


# TODO: Move to common utils or use existing quant utils to fetch model instances
class TestHelperModules:
    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self):
            super(TestHelperModules.ConvMaxPool2d, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, bn=True, bias=True):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
            if bn:
                self.bn = torch.nn.BatchNorm2d(3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x


class PT2EQuantizationTestCase(QuantizationTestCase):
    """
@@ -258,7 +95,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
    }


    def _test_quantizer(
        self,
        model,
@@ -270,6 +106,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        fx_qconfig_mapping=None,
        export_with_dynamic_shape=False,
    ):
        # resetting dynamo cache
        torch._dynamo.reset()
        m_eager = model.eval()

        # program capture
@@ -342,6 +180,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
        MANUAL_SEED = 100

        # PT2 export
@@ -427,6 +267,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        with fake quantizes inserted into the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
        m = copy.deepcopy(m)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
@@ -549,60 +391,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        self.assertTrue("tensor_constant" in bn_running_var_node.target)
        self.assertEqual(eps, 1e-5)

    def _test_representation(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        quantizer: Quantizer,
        ref_node_occurrence: Dict[ns, int],
        non_ref_node_occurrence: Dict[ns, int],
        fixed_output_tol: float = None,
        output_scale_idx: int = 3,
    ) -> torch.nn.Module:
        """ TODO: need to implement output checking based on output_scale once
        torchdynamo issue is resolved
        """
        # program capture
        model = capture_pre_autograd_graph(
            model,
            example_inputs,
        )
        model_copy = copy.deepcopy(model)

        model = prepare_pt2e(model, quantizer)
        # Calibrate
        model(*example_inputs)
        model = convert_pt2e(model, use_reference_representation=True)
        self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
        # make sure it runs
        pt2e_quant_output = model(*example_inputs)

        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        model_copy = prepare_pt2e(model_copy, quantizer)
        # Calibrate
        model_copy(*example_inputs)
        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
        self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence)
        pt2e_quant_output_copy = model_copy(*example_inputs)

        output_tol = None
        if fixed_output_tol is not None:
            output_tol = fixed_output_tol
        else:
            idx = 0
            for n in model_copy.graph.nodes:
                if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
                    idx += 1
                    if idx == output_scale_idx:
                        output_tol = n.args[1]
            assert output_tol is not None

        # make sure the result is off by one at most in the quantized integer representation
        self.assertTrue(
            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5)
        )

@skipIfNoQNNPACK
class TestQuantizePT2E(PT2EQuantizationTestCase):
@@ -2195,242 +1983,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_xnnpack_qat_numerics(M(), example_inputs)

    def test_representation_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_dynamic_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
            fixed_output_tol=1e-4,
        )

    def test_representation_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv2d(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = x + y
                out = torch.nn.functional.relu(out)
                return out

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
        ref_node_occurrence = {
            ns.call_function(out_dtype): 2,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence=ref_node_occurrence,
            non_ref_node_occurrence={}
        )

    def test_representation_maxpool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvMaxPool2d().eval()

        example_inputs = (torch.randn(1, 2, 2, 2),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    @unittest.skip("will fix later")
    def test_representation_adaptive_avg_pool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()

        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_quantize_dequantize_per_channel(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        # use per channel quantization for weight
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        inputs = [
            (torch.randn(1, 5),),
            (torch.randn(1, 3, 5),),
            (torch.randn(1, 3, 3, 5),),
            (torch.randn(1, 3, 3, 3, 5),),
        ]
        for example_inputs in inputs:
            ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 0,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 0,
            }
            non_ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 1,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 1,
            }

            self._test_representation(
                M().eval(),
                example_inputs,
                quantizer,
                ref_node_occurrence,
                non_ref_node_occurrence,
                output_scale_idx=2,
            )

    def test_representation_quantize_dequantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
        ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor
            ): 0,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor
            ): 0,
        }
        non_ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 3,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 3,
        }
        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence,
            non_ref_node_occurrence
        )

    def test_move_exported_model_to_eval(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -2611,6 +2163,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
        model_graph = convert_pt2e(model_graph)
        self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))


# TODO: express this using self._test_quantizer, add test for inception_v4
class TestQuantizePT2EModels(PT2EQuantizationTestCase):
    @skip_if_no_torchvision
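The representation tests above (and their moved copies in the new file below) count out_dtype nodes, e.g. ns.call_function(out_dtype): 2 in test_representation_add_relu. As a hedged sketch of what that higher-order op computes in the reference representation, assuming the torch._higher_order_ops.out_dtype API imported in these files (shapes and values are illustrative):

import torch
from torch._higher_order_ops.out_dtype import out_dtype

x = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
# int8 matmul accumulated into int32 instead of overflowing in int8;
# this is the kind of node the reference representation emits
acc = out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
assert acc.dtype == torch.int32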
test/quantization/pt2e/test_representation.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import Any, Dict, Tuple

import torch
from torch._export import capture_pre_autograd_graph
from torch._higher_order_ops.out_dtype import out_dtype  # noqa: F401
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skipIfNoQNNPACK,
    TestHelperModules,
)


@skipIfNoQNNPACK
class TestPT2ERepresentation(QuantizationTestCase):
    def _test_representation(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        quantizer: Quantizer,
        ref_node_occurrence: Dict[ns, int],
        non_ref_node_occurrence: Dict[ns, int],
        fixed_output_tol: float = None,
        output_scale_idx: int = 3,
    ) -> torch.nn.Module:
        # resetting dynamo cache
        torch._dynamo.reset()
        model = capture_pre_autograd_graph(
            model,
            example_inputs,
        )
        model_copy = copy.deepcopy(model)

        model = prepare_pt2e(model, quantizer)
        # Calibrate
        model(*example_inputs)
        model = convert_pt2e(model, use_reference_representation=True)
        self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
        # make sure it runs
        pt2e_quant_output = model(*example_inputs)

        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        model_copy = prepare_pt2e(model_copy, quantizer)
        # Calibrate
        model_copy(*example_inputs)
        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
        self.checkGraphModuleNodes(
            model_copy, expected_node_occurrence=non_ref_node_occurrence
        )
        pt2e_quant_output_copy = model_copy(*example_inputs)

        output_tol = None
        if fixed_output_tol is not None:
            output_tol = fixed_output_tol
        else:
            idx = 0
            for n in model_copy.graph.nodes:
                if (
                    n.target
                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
                ):
                    idx += 1
                    if idx == output_scale_idx:
                        output_tol = n.args[1]
            assert output_tol is not None

        # make sure the result is off by one at most in the quantized integer representation
        self.assertTrue(
            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output))
            <= (2 * output_tol + 1e-5)
        )

    def test_static_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_dynamic_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(
            is_per_channel=False, is_dynamic=True
        )
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
            fixed_output_tol=1e-4,
        )

    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv2d(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = x + y
                out = torch.nn.functional.relu(out)
                return out

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )
        ref_node_occurrence = {
            ns.call_function(out_dtype): 2,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence=ref_node_occurrence,
            non_ref_node_occurrence={},
        )

    def test_maxpool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvMaxPool2d().eval()

        example_inputs = (torch.randn(1, 2, 2, 2),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    @unittest.skip("will fix later")
    def test_representation_adaptive_avg_pool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()

        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_qdq_per_channel(self):
        """Test representation for quantize_per_channel and dequantize_per_channel op"""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        # use per channel quantization for weight
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        inputs = [
            (torch.randn(1, 5),),
            (torch.randn(1, 3, 5),),
            (torch.randn(1, 3, 3, 5),),
            (torch.randn(1, 3, 3, 3, 5),),
        ]
        for example_inputs in inputs:
            ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 0,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 0,
            }
            non_ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 1,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 1,
            }

            self._test_representation(
                M().eval(),
                example_inputs,
                quantizer,
                ref_node_occurrence,
                non_ref_node_occurrence,
                output_scale_idx=2,
            )

    def test_qdq(self):
        """Test representation for quantize and dequantize op"""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )
        ref_node_occurrence = {
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0,
        }
        non_ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 3,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 3,
        }
        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence,
            non_ref_node_occurrence,
        )
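A note on the tolerance used in _test_representation above: output_tol is read from n.args[1] of the output quantize_per_tensor node, i.e. the output scale, and the assert allows the two lowering paths to disagree by at most one quantized step each ("off by one at most in the quantized integer representation"). A quick numeric sketch with a hypothetical scale value:

# hypothetical output scale, as read from the quantize_per_tensor node's args[1]
output_scale = 0.05
# each path may round the same float to adjacent integer levels, so the
# dequantized outputs can differ by up to one scale step per path;
# this mirrors the bound asserted in _test_representation
max_allowed = 2 * output_scale + 1e-5
print(max_allowed)  # 0.10001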
test/test_quantization.py
@@ -86,6 +86,7 @@ try:
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels  # noqa: F401
    from quantization.pt2e.test_representation import TestPT2ERepresentation  # noqa: F401
    from quantization.pt2e.test_x86inductor_quantizer import TestQuantizePT2EX86Inductor  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFX  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFXModels  # noqa: F401
torch/testing/_internal/common_quantization.py
@@ -2431,3 +2431,163 @@ class SparseNNModel(nn.Module):
        out = self.dense_top(sparse_feature, dense)

        return out


class TestHelperModules:
    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, bn=True, bias=True):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
            if bn:
                self.bn = torch.nn.BatchNorm2d(3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x
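With this change the helper modules live in a shared location instead of being local to test_quantize_pt2e.py. A minimal usage sketch; the choice of ConvMaxPool2d and the input shape follow the tests above:

import torch
from torch.testing._internal.common_quantization import TestHelperModules

m = TestHelperModules.ConvMaxPool2d().eval()
out = m(torch.randn(1, 2, 2, 2))  # same example input shape as test_maxpool2d
print(out.shape)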