[quant][be] Add TestPT2ERepresentation test case (#108923)

Summary:
att

Test Plan:
python test/test_quantization.py TestPT2ERepresentation
Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108923
Approved by: https://github.com/andrewor14
Jerry Zhang 2023-09-13 15:09:58 -07:00 committed by PyTorch MergeBot
parent 064ae9ff33
commit c914ca7577
4 changed files with 498 additions and 456 deletions
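
For context, the round trip exercised by the new TestPT2ERepresentation case looks roughly like this (a minimal sketch distilled from the _test_representation helper added below; the toy module and input shapes are illustrative, and the APIs are the ones used in this diff):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 5),)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

# export the module and insert observers, then calibrate
model = capture_pre_autograd_graph(M().eval(), example_inputs)
model = prepare_pt2e(model, quantizer)
model(*example_inputs)

# convert with the reference representation: quantized ops are expressed with
# decomposed integer arithmetic (e.g. out_dtype) rather than the
# quantized_decomposed quantize/dequantize ops
model = convert_pt2e(model, use_reference_representation=True)
out = model(*example_inputs)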

test/quantization/pt2e/test_quantize_pt2e.py

@@ -1,7 +1,7 @@
# Owner(s): ["oncall: quantization"]
import copy
import operator
from typing import Any, List, Optional, Tuple, Dict
from typing import Any, List, Optional, Tuple
import torch
import torch._dynamo as torchdynamo
@@ -72,177 +72,14 @@ from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
    TestHelperModules,
)
from torch.ao.quantization import (
    default_dynamic_qconfig,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch._higher_order_ops.out_dtype import out_dtype  # noqa: F401
from torch._export import dynamic_dim
import unittest
# TODO: Move to common utils or use existing quant utils to fetch model instances
class TestHelperModules:
    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self):
            super(TestHelperModules.ConvMaxPool2d, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, bn=True, bias=True):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
            if bn:
                self.bn = torch.nn.BatchNorm2d(3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x

class PT2EQuantizationTestCase(QuantizationTestCase):
    """
@@ -258,7 +95,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
    }

    def _test_quantizer(
        self,
        model,
@@ -270,6 +106,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        fx_qconfig_mapping=None,
        export_with_dynamic_shape=False,
    ):
        # resetting dynamo cache
        torch._dynamo.reset()
        m_eager = model.eval()

        # program capture
@@ -342,6 +180,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
        MANUAL_SEED = 100

        # PT2 export
@@ -427,6 +267,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        with fake quantizes inserted into the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
        m = copy.deepcopy(m)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
@@ -549,60 +391,6 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        self.assertTrue("tensor_constant" in bn_running_var_node.target)
        self.assertEqual(eps, 1e-5)

    def _test_representation(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        quantizer: Quantizer,
        ref_node_occurrence: Dict[ns, int],
        non_ref_node_occurrence: Dict[ns, int],
        fixed_output_tol: float = None,
        output_scale_idx: int = 3,
    ) -> torch.nn.Module:
        """ TODO: need to implement output checking based on output_scale once
        torchdynamo issue is resolved
        """
        # program capture
        model = capture_pre_autograd_graph(
            model,
            example_inputs,
        )
        model_copy = copy.deepcopy(model)

        model = prepare_pt2e(model, quantizer)
        # Calibrate
        model(*example_inputs)
        model = convert_pt2e(model, use_reference_representation=True)
        self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
        # make sure it runs
        pt2e_quant_output = model(*example_inputs)

        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        model_copy = prepare_pt2e(model_copy, quantizer)
        # Calibrate
        model_copy(*example_inputs)
        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
        self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence)
        pt2e_quant_output_copy = model_copy(*example_inputs)

        output_tol = None
        if fixed_output_tol is not None:
            output_tol = fixed_output_tol
        else:
            idx = 0
            for n in model_copy.graph.nodes:
                if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
                    idx += 1
                    if idx == output_scale_idx:
                        output_tol = n.args[1]
            assert output_tol is not None

        # make sure the result is off by one at most in the quantized integer representation
        self.assertTrue(
            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5)
        )

@skipIfNoQNNPACK
class TestQuantizePT2E(PT2EQuantizationTestCase):
@@ -2195,242 +1983,6 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_xnnpack_qat_numerics(M(), example_inputs)

    def test_representation_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_dynamic_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
            fixed_output_tol=1e-4,
        )

    def test_representation_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv2d(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = x + y
                out = torch.nn.functional.relu(out)
                return out

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
        ref_node_occurrence = {
            ns.call_function(out_dtype): 2,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence=ref_node_occurrence,
            non_ref_node_occurrence={}
        )

    def test_representation_maxpool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvMaxPool2d().eval()

        example_inputs = (torch.randn(1, 2, 2, 2),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    @unittest.skip("will fix later")
    def test_representation_adaptive_avg_pool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()

        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_quantize_dequantize_per_channel(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        # use per channel quantization for weight
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        inputs = [
            (torch.randn(1, 5),),
            (torch.randn(1, 3, 5),),
            (torch.randn(1, 3, 3, 5),),
            (torch.randn(1, 3, 3, 3, 5),),
        ]
        for example_inputs in inputs:
            ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 0,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 0,
            }
            non_ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 1,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 1,
            }
            self._test_representation(
                M().eval(),
                example_inputs,
                quantizer,
                ref_node_occurrence,
                non_ref_node_occurrence,
                output_scale_idx=2,
            )

    def test_representation_quantize_dequantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1, 3, 3, 3),)
        ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor
            ): 0,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor
            ): 0,
        }
        non_ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 3,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 3,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence,
            non_ref_node_occurrence
        )

    def test_move_exported_model_to_eval(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -2611,6 +2163,7 @@ class TestQuantizePT2EOps(QuantizationTestCase):
        model_graph = convert_pt2e(model_graph)
        self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))


# TODO: express this using self._test_quantizer, add test for inception_v4
class TestQuantizePT2EModels(PT2EQuantizationTestCase):
    @skip_if_no_torchvision

test/quantization/pt2e/test_representation.py (new file)

@@ -0,0 +1,328 @@
# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import Any, Dict, Tuple
import torch
from torch._export import capture_pre_autograd_graph
from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skipIfNoQNNPACK,
    TestHelperModules,
)
@skipIfNoQNNPACK
class TestPT2ERepresentation(QuantizationTestCase):
    def _test_representation(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        quantizer: Quantizer,
        ref_node_occurrence: Dict[ns, int],
        non_ref_node_occurrence: Dict[ns, int],
        fixed_output_tol: float = None,
        output_scale_idx: int = 3,
    ) -> torch.nn.Module:
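        # fixed_output_tol: absolute tolerance used directly when given; otherwise
        # the scale of the output_scale_idx-th quantize_per_tensor node in the
        # non-reference model is used as the tolerance (see below)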
        # resetting dynamo cache
        torch._dynamo.reset()
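        # program capture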
        model = capture_pre_autograd_graph(
            model,
            example_inputs,
        )
        model_copy = copy.deepcopy(model)

        model = prepare_pt2e(model, quantizer)
        # Calibrate
        model(*example_inputs)
        model = convert_pt2e(model, use_reference_representation=True)
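        # reference representation expresses quantized ops with decomposed integer
        # arithmetic (e.g. out_dtype) instead of quantized_decomposed q/dq ops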
        self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
        # make sure it runs
        pt2e_quant_output = model(*example_inputs)

        # TODO: torchdynamo times out when we do this, we can enable numerical checking
        # after that is fixed
        model_copy = prepare_pt2e(model_copy, quantizer)
        # Calibrate
        model_copy(*example_inputs)
        model_copy = convert_pt2e(model_copy, use_reference_representation=False)
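        # the non-reference copy serves as the numerical baseline for comparison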
        self.checkGraphModuleNodes(
            model_copy, expected_node_occurrence=non_ref_node_occurrence
        )
        pt2e_quant_output_copy = model_copy(*example_inputs)

        output_tol = None
        if fixed_output_tol is not None:
            output_tol = fixed_output_tol
        else:
            idx = 0
            for n in model_copy.graph.nodes:
                if (
                    n.target
                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
                ):
                    idx += 1
                    if idx == output_scale_idx:
                        output_tol = n.args[1]
            assert output_tol is not None

        # make sure the result is off by one at most in the quantized integer representation
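        # each of the two models may round differently by at most one quantization
        # step of size output_tol, hence the 2 * output_tol bound (plus epsilon)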
        self.assertTrue(
            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output))
            <= (2 * output_tol + 1e-5)
        )

    def test_static_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_dynamic_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(
            is_per_channel=False, is_dynamic=True
        )
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
            fixed_output_tol=1e-4,
        )

    def test_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv2d(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = x + y
                out = torch.nn.functional.relu(out)
                return out

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )
        ref_node_occurrence = {
            ns.call_function(out_dtype): 2,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence=ref_node_occurrence,
            non_ref_node_occurrence={},
        )

    def test_maxpool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvMaxPool2d().eval()

        example_inputs = (torch.randn(1, 2, 2, 2),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    @unittest.skip("will fix later")
    def test_representation_adaptive_avg_pool2d(self):
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = TestHelperModules.ConvWithAdaptiveAvgPool2d().eval()

        example_inputs = (torch.randn(1, 3, 3, 3),)

        self._test_representation(
            m_eager,
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={},
        )

    def test_qdq_per_channel(self):
        """Test representation for quantize_per_channel and dequantize_per_channel op"""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        # use per channel quantization for weight
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        inputs = [
            (torch.randn(1, 5),),
            (torch.randn(1, 3, 5),),
            (torch.randn(1, 3, 3, 5),),
            (torch.randn(1, 3, 3, 3, 5),),
        ]
        for example_inputs in inputs:
            ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 0,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 0,
            }
            non_ref_node_occurrence = {
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel.default
                ): 1,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                ): 1,
            }
            self._test_representation(
                M().eval(),
                example_inputs,
                quantizer,
                ref_node_occurrence,
                non_ref_node_occurrence,
                output_scale_idx=2,
            )

    def test_qdq(self):
        """Test representation for quantize and dequantize op"""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        m_eager = M().eval()

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            torch.randn(1, 3, 3, 3),
        )
        ref_node_occurrence = {
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0,
        }
        non_ref_node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 3,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 3,
        }

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence,
            non_ref_node_occurrence,
        )

test/test_quantization.py

@@ -86,6 +86,7 @@ try:
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels  # noqa: F401
    from quantization.pt2e.test_representation import TestPT2ERepresentation  # noqa: F401
    from quantization.pt2e.test_x86inductor_quantizer import TestQuantizePT2EX86Inductor  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFX  # noqa: F401
    from quantization.pt2e.test_quantize_pt2e_fx import TestQuantizePT2EFXModels  # noqa: F401

torch/testing/_internal/common_quantization.py

@@ -2431,3 +2431,163 @@ class SparseNNModel(nn.Module):
        out = self.dense_top(sparse_feature, dense)
        return out


class TestHelperModules:
    class Conv2dPropAnnotaton(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = x.view(-1, 3)
            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
            x = self.linear(x)
            return x

    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = torch.mean(x)
            return x

    class Conv2dWithTwoLinearPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear2(self.linear1(permute_out))

    class Conv2dWithTwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linear1 = torch.nn.Linear(64, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            reshape_out = torch.reshape(conv_out, (2, 64))
            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, bn=True, bias=True):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, bias=bias)
            if bn:
                self.bn = torch.nn.BatchNorm2d(3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class EmbeddingModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x