mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][pt2] Add prepare QAT test for mobilenetv2 (#104068)
Summary: Prepare QAT for mobilenetv2 has matching numerics with FX. There were two changes needed to achieve this, however. First, this commit adds observer sharing for ReLU6, which is used extensively throughout this model. Second, in the tests we have to use the same manual seed every time we call the models in order to get the same results between FX and PT2. This is because there is a dropout at the end of the model. Test Plan: python test/test_quantization.py TestQuantizePT2EModels.test_qat_mobilenet_v2 Reviewed By: kimishpatel Differential Revision: D46707786 Pull Request resolved: https://github.com/pytorch/pytorch/pull/104068 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
fd40abb706
commit
7320ef5651
|
|
@ -283,6 +283,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||||
Helper method to verify that the QAT numerics for PT2E quantization match those of
|
Helper method to verify that the QAT numerics for PT2E quantization match those of
|
||||||
FX graph mode quantization for symmetric qnnpack.
|
FX graph mode quantization for symmetric qnnpack.
|
||||||
"""
|
"""
|
||||||
|
MANUAL_SEED = 100
|
||||||
|
|
||||||
# PT2 export
|
# PT2 export
|
||||||
|
|
||||||
model_pt2e = copy.deepcopy(model)
|
model_pt2e = copy.deepcopy(model)
|
||||||
|
|
@ -298,6 +300,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||||
aten_graph=True,
|
aten_graph=True,
|
||||||
)
|
)
|
||||||
model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
|
model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
|
||||||
|
torch.manual_seed(MANUAL_SEED)
|
||||||
after_prepare_result_pt2e = model_pt2e(*example_inputs)
|
after_prepare_result_pt2e = model_pt2e(*example_inputs)
|
||||||
|
|
||||||
# FX
|
# FX
|
||||||
|
|
@ -317,6 +320,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
|
||||||
model_fx = prepare_qat_fx(
|
model_fx = prepare_qat_fx(
|
||||||
model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
|
model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
|
||||||
)
|
)
|
||||||
|
torch.manual_seed(MANUAL_SEED)
|
||||||
after_prepare_result_fx = model_fx(*example_inputs)
|
after_prepare_result_fx = model_fx(*example_inputs)
|
||||||
|
|
||||||
# Verify that numerics match
|
# Verify that numerics match
|
||||||
|
|
@ -1883,3 +1887,17 @@ class TestQuantizePT2EModels(PT2EQuantizationTestCase):
|
||||||
self._verify_symmetric_qnnpack_qat_numerics(
|
self._verify_symmetric_qnnpack_qat_numerics(
|
||||||
m, example_inputs, is_per_channel=True, verify_convert=True,
|
m, example_inputs, is_per_channel=True, verify_convert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@skip_if_no_torchvision
|
||||||
|
@skipIfNoQNNPACK
|
||||||
|
def test_qat_mobilenet_v2(self):
|
||||||
|
import torchvision
|
||||||
|
with override_quantized_engine("qnnpack"):
|
||||||
|
example_inputs = (torch.randn(1, 3, 224, 224),)
|
||||||
|
m = torchvision.models.mobilenet_v2()
|
||||||
|
self._verify_symmetric_qnnpack_qat_numerics(
|
||||||
|
m, example_inputs, is_per_channel=False, verify_convert=False,
|
||||||
|
)
|
||||||
|
self._verify_symmetric_qnnpack_qat_numerics(
|
||||||
|
m, example_inputs, is_per_channel=True, verify_convert=False,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -685,6 +685,9 @@ class QNNPackQuantizer(Quantizer):
|
||||||
self._annotate_input_out_obs_sharing_op(
|
self._annotate_input_out_obs_sharing_op(
|
||||||
torch.nn.modules.Hardtanh, gm, quantization_config
|
torch.nn.modules.Hardtanh, gm, quantization_config
|
||||||
)
|
)
|
||||||
|
self._annotate_input_out_obs_sharing_op(
|
||||||
|
torch.nn.modules.ReLU6, gm, quantization_config
|
||||||
|
)
|
||||||
|
|
||||||
def _annotate_mean(
|
def _annotate_mean(
|
||||||
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
|
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user