[quant][pt2] Add prepare QAT test for mobilenetv2 (#104068)

Summary:
Prepare QAT for mobilenetv2 has matching numerics with
FX. There were two changes needed to achieve this, however.
First, this commit adds observer sharing for ReLU6, which is
used extensively throughout this model. Second, in the tests we
have to use the same manual seed every time we call the models
in order to get the same results between FX and PT2. This is
because there is a dropout at the end of the model.

Test Plan: python test/test_quantization.py TestQuantizePT2EModels.test_qat_mobilenet_v2

Reviewed By: kimishpatel

Differential Revision: D46707786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104068
Approved by: https://github.com/jerryzh168
This commit is contained in:
Andrew Or 2023-06-23 16:34:25 +00:00 committed by PyTorch MergeBot
parent fd40abb706
commit 7320ef5651
2 changed files with 21 additions and 0 deletions

View File

@@ -283,6 +283,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
MANUAL_SEED = 100
        # PT2 export
        model_pt2e = copy.deepcopy(model)
@@ -298,6 +300,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
            aten_graph=True,
        )
        model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
torch.manual_seed(MANUAL_SEED)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)
        # FX
@@ -317,6 +320,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
        model_fx = prepare_qat_fx(
            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
        )
torch.manual_seed(MANUAL_SEED)
        after_prepare_result_fx = model_fx(*example_inputs)
        # Verify that numerics match
@@ -1883,3 +1887,17 @@ class TestQuantizePT2EModels(PT2EQuantizationTestCase):
        self._verify_symmetric_qnnpack_qat_numerics(
            m, example_inputs, is_per_channel=True, verify_convert=True,
        )
@skip_if_no_torchvision
@skipIfNoQNNPACK
def test_qat_mobilenet_v2(self):
    # Verify that prepare-QAT numerics for torchvision's mobilenetv2 under
    # PT2E match FX graph mode quantization with the symmetric qnnpack
    # config, for both per-tensor and per-channel quantization.
    # NOTE(review): verify_convert=False — only the prepare step is
    # compared here; convert numerics are not yet checked for this model.
    import torchvision
    with override_quantized_engine("qnnpack"):
        # Standard ImageNet-sized input: batch 1, 3 channels, 224x224.
        example_inputs = (torch.randn(1, 3, 224, 224),)
        m = torchvision.models.mobilenet_v2()
        self._verify_symmetric_qnnpack_qat_numerics(
            m, example_inputs, is_per_channel=False, verify_convert=False,
        )
        self._verify_symmetric_qnnpack_qat_numerics(
            m, example_inputs, is_per_channel=True, verify_convert=False,
        )

View File

@@ -685,6 +685,9 @@ class QNNPackQuantizer(Quantizer):
        self._annotate_input_out_obs_sharing_op(
            torch.nn.modules.Hardtanh, gm, quantization_config
        )
self._annotate_input_out_obs_sharing_op(
torch.nn.modules.ReLU6, gm, quantization_config
)
    def _annotate_mean(
        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig