pytorch/test/test_quantization.py
Supriya Rao b8386f5d72 [quant] Create FusedMovingAvgObsFakeQuantize for QAT (#61691)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61691

Create a new module for QAT that fuses the MovingAverageMinMaxObserver and FakeQuantize operations into a single op.
The module currently supports only per-tensor quantization (affine/symmetric); a follow-up PR will add per-channel support.
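
For context, a minimal usage sketch of the new module (input shape is illustrative; the constructor is assumed to mirror the existing FakeQuantize convention, with observer kwargs forwarded to the observer):
```
import torch
from torch.quantization import FusedMovingAvgObsFakeQuantize, MovingAverageMinMaxObserver

# Per-tensor affine fake-quant with a fused moving-average min/max observer.
fq = FusedMovingAvgObsFakeQuantize(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,  # per-tensor only; per-channel is follow-up work
)

x = torch.randn(4, 8)  # illustrative input
y = fq(x)              # updates the running min/max and fake-quantizes in one fused op
```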

Results from running QAT on MobileNetV2 (observer and fake_quant both enabled):
Original FakeQuantize module:
PyTorchObserver {"type": "_", "metric": "qnnpack_fp_latency_ms", "unit": "ms", "value": "242.80261993408203"}
PyTorchObserver {"type": "_", "metric": "qnnpack_qat0_latency_ms", "unit": "ms", "value": "505.7964324951172"}
PyTorchObserver {"type": "_", "metric": "fbgemm_fp_latency_ms", "unit": "ms", "value": "235.80145835876465"}
PyTorchObserver {"type": "_", "metric": "fbgemm_qat0_latency_ms", "unit": "ms", "value": "543.8144207000732"}

Fused FakeQuantize module (~50% improvement in latency):
PyTorchObserver {"type": "_", "metric": "qnnpack_fp_latency_ms", "unit": "ms", "value": "232.1624755859375"}
PyTorchObserver {"type": "_", "metric": "qnnpack_qat0_latency_ms", "unit": "ms", "value": "263.8866901397705"}
PyTorchObserver {"type": "_", "metric": "fbgemm_fp_latency_ms", "unit": "ms", "value": "236.9832992553711"}
PyTorchObserver {"type": "_", "metric": "fbgemm_qat0_latency_ms", "unit": "ms", "value": "292.1590805053711"}

Individual module benchmark results (>5x improvement in latency)
===> Baseline FakeQuantize module
```
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
              aten::fake_quantize_per_tensor_affine         0.77%       1.210ms         4.92%       7.730ms     154.596us     718.528us         0.45%       9.543ms     190.862us            50
    aten::fake_quantize_per_tensor_affine_cachemask         2.41%       3.792ms         4.15%       6.520ms     130.402us       8.825ms         5.58%       8.825ms     176.492us            50
                                     aten::_aminmax         3.25%       5.105ms         4.43%       6.955ms     139.102us       8.193ms         5.18%       8.193ms     163.868us            50
                                   aten::zeros_like         1.87%       2.939ms         6.95%      10.922ms     109.218us       5.992ms         3.79%      10.844ms     108.442us           100
                                        aten::zeros         0.97%       1.527ms         3.11%       4.885ms      97.702us       2.383ms         1.51%       4.800ms      96.010us            50
                                         aten::rsub         1.34%       2.106ms         2.94%       4.614ms      92.277us       2.063ms         1.30%       4.559ms      91.173us            50
                                        aten::clamp         2.79%       4.381ms         5.42%       8.519ms      85.190us       5.385ms         3.41%       8.438ms      84.381us           100
                                           aten::eq        11.70%      18.384ms        21.31%      33.479ms      83.280us      22.465ms        14.21%      33.310ms      82.861us           402
                                         aten::ones         1.05%       1.656ms         2.57%       4.038ms      80.751us       2.494ms         1.58%       3.951ms      79.028us            50
                                           aten::le         2.52%       3.955ms         4.84%       7.607ms      76.071us       4.998ms         3.16%       7.702ms      77.016us           100
                                          aten::min         0.69%       1.087ms         2.32%       3.641ms      72.827us       1.017ms         0.64%       3.603ms      72.055us            50
                                          aten::max         1.40%       2.195ms         4.62%       7.260ms      72.597us       2.008ms         1.27%       7.140ms      71.404us           100
                                   aten::is_nonzero         2.68%       4.207ms        11.35%      17.829ms      71.033us       4.062ms         2.57%      17.225ms      68.625us           251
                                       aten::detach         1.17%       1.831ms         3.65%       5.736ms      57.360us       1.680ms         1.06%       5.634ms      56.340us           100
                                          aten::mul         3.36%       5.278ms         3.36%       5.278ms      53.862us       5.215ms         3.30%       5.215ms      53.216us            98
                                          aten::div         3.42%       5.376ms         3.42%       5.376ms      53.759us       5.320ms         3.36%       5.320ms      53.196us           100
                                          aten::sub         6.79%      10.672ms         6.79%      10.672ms      53.901us      10.504ms         6.64%      10.504ms      53.050us           198
                                         aten::item         4.06%       6.380ms        12.02%      18.883ms      53.798us       6.127ms         3.87%      18.322ms      52.198us           351
                                          aten::add         3.28%       5.147ms         3.28%       5.147ms      52.518us       5.113ms         3.23%       5.113ms      52.171us            98
                                      aten::minimum         1.63%       2.555ms         1.63%       2.555ms      51.092us       2.585ms         1.64%       2.585ms      51.708us            50
                                      aten::maximum         3.22%       5.065ms         3.22%       5.065ms      50.646us       5.133ms         3.25%       5.133ms      51.329us           100
                                        aten::round         1.61%       2.529ms         1.61%       2.529ms      50.578us       2.528ms         1.60%       2.528ms      50.552us            50
                                        aten::zero_         1.99%       3.125ms         4.72%       7.422ms      49.481us       2.835ms         1.79%       7.269ms      48.462us           150
                                        aten::copy_         6.62%      10.394ms         6.62%      10.394ms      41.576us      10.252ms         6.48%      10.252ms      41.010us           250
                                             detach         2.49%       3.905ms         2.49%       3.905ms      39.049us       3.954ms         2.50%       3.954ms      39.539us           100
                                       aten::select         2.01%       3.154ms         2.47%       3.876ms      38.759us       3.866ms         2.44%       3.866ms      38.658us           100
                          aten::_local_scalar_dense         7.96%      12.503ms         7.96%      12.503ms      35.621us      12.195ms         7.71%      12.195ms      34.743us           351
                                           aten::to         2.31%       3.625ms         4.16%       6.530ms      32.650us       4.320ms         2.73%       6.270ms      31.348us           200
                                        aten::fill_         3.70%       5.808ms         3.70%       5.808ms      29.039us       5.892ms         3.73%       5.892ms      29.459us           200
                                   aten::as_strided         0.79%       1.244ms         0.79%       1.244ms       6.221us       0.000us         0.00%       0.000us       0.000us           200
                                        aten::empty         3.55%       5.579ms         3.55%       5.579ms      11.137us       0.000us         0.00%       0.000us       0.000us           501
                                      aten::resize_         2.36%       3.712ms         2.36%       3.712ms      12.332us       0.000us         0.00%       0.000us       0.000us           301
                                   aten::empty_like         1.45%       2.284ms         3.68%       5.776ms      28.878us       0.000us         0.00%       0.000us       0.000us           200
                                aten::empty_strided         2.80%       4.398ms         2.80%       4.398ms      17.592us       0.000us         0.00%       0.000us       0.000us           250
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 157.108ms
Self CUDA time total: 158.122ms
```

===> FusedFakeQuant
```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                   fb::fused_fake_quant        23.42%       6.408ms       100.00%      27.361ms     547.215us       7.887ms        27.20%      28.996ms     579.925us            50
                  aten::fake_quantize_per_tensor_affine         4.25%       1.162ms        27.65%       7.565ms     151.298us     686.176us         2.37%      10.217ms     204.336us            50
aten::_fake_quantize_per_tensor_affine_cachemask_ten...        14.11%       3.860ms        23.40%       6.403ms     128.068us       9.531ms        32.87%       9.531ms     190.612us            50
                                         aten::_aminmax        20.57%       5.628ms        27.47%       7.515ms     150.305us       8.218ms        28.34%       8.218ms     164.367us            50
                                             aten::item         3.65%     999.522us        10.27%       2.810ms      56.202us     931.904us         3.21%       2.674ms      53.481us            50
                              aten::_local_scalar_dense         6.62%       1.811ms         6.62%       1.811ms      36.212us       1.742ms         6.01%       1.742ms      34.843us            50
                                            aten::empty        10.85%       2.969ms        10.85%       2.969ms      14.843us       0.000us         0.00%       0.000us       0.000us           200
                                       aten::as_strided         1.92%     524.365us         1.92%     524.365us       5.244us       0.000us         0.00%       0.000us       0.000us           100
                                       aten::empty_like         6.48%       1.774ms        14.62%       4.000ms      26.670us       0.000us         0.00%       0.000us       0.000us           150
                                    aten::empty_strided         8.14%       2.226ms         8.14%       2.226ms      14.842us       0.000us         0.00%       0.000us       0.000us           150
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 27.361ms
Self CUDA time total: 28.996ms
```
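
Consistent with the table above, self CPU time drops from 157.108 ms to 27.361 ms over the same 50 calls (roughly a 5.7x speedup), as the single fused kernel replaces the long tail of small aten ops in the baseline profile. A per-op table like the two above can be produced with torch.profiler; a minimal sketch, where the module under test, input shape, and iteration count are illustrative:
```
import torch
from torch.profiler import profile, ProfilerActivity
from torch.quantization import FakeQuantize

fq = FakeQuantize()        # baseline module; swap in the fused module to compare
x = torch.randn(64, 1024)  # illustrative input

# Drop ProfilerActivity.CUDA on CPU-only builds.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(50):    # matches the "# of Calls" column above
        fq(x)

print(prof.key_averages().table(sort_by="cpu_time_total"))
```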

Test Plan:
python test/test_quantization.py TestFusedObsFakeQuantModule

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D29706889

fbshipit-source-id: ae3f9fb1fc559920459bf6e8663e8299bf7d21e1
2021-07-21 10:13:04 -07:00

# -*- coding: utf-8 -*-
from torch.testing._internal.common_utils import run_tests
# Quantization core tests. These include tests for
# - quantized kernels
# - quantized functional operators
# - quantized workflow modules
# - quantized workflow operators
# - quantized tensor
# 1. Quantized Kernels
# TODO: merge the different quantized op tests into one test class
from quantization.core.test_quantized_op import TestQuantizedOps # noqa: F401
from quantization.core.test_quantized_op import TestQNNPackOps # noqa: F401
from quantization.core.test_quantized_op import TestQuantizedLinear # noqa: F401
from quantization.core.test_quantized_op import TestQuantizedConv # noqa: F401
from quantization.core.test_quantized_op import TestDynamicQuantizedLinear # noqa: F401
from quantization.core.test_quantized_op import TestComparatorOps # noqa: F401
from quantization.core.test_quantized_op import TestPadding # noqa: F401
from quantization.core.test_quantized_op import TestQuantizedEmbeddingOps # noqa: F401
from quantization.core.test_quantized_op import TestDynamicQuantizedRNNOp # noqa: F401
# 2. Quantized Functional/Workflow Ops
from quantization.core.test_quantized_functional import TestQuantizedFunctionalOps # noqa: F401
from quantization.core.test_workflow_ops import TestFakeQuantizeOps # noqa: F401
from quantization.core.test_workflow_ops import TestFusedObsFakeQuant # noqa: F401
# 3. Quantized Tensor
from quantization.core.test_quantized_tensor import TestQuantizedTensor # noqa: F401
# 4. Modules
from quantization.core.test_workflow_module import TestFakeQuantize # noqa: F401
from quantization.core.test_workflow_module import TestObserver # noqa: F401
from quantization.core.test_quantized_module import TestStaticQuantizedModule # noqa: F401
from quantization.core.test_quantized_module import TestDynamicQuantizedModule # noqa: F401
from quantization.core.test_workflow_module import TestRecordHistogramObserver # noqa: F401
from quantization.core.test_workflow_module import TestHistogramObserver # noqa: F401
from quantization.core.test_workflow_module import TestDistributed # noqa: F401
from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401
# Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
# using eager mode.
# 1. Eager mode post training quantization
from quantization.eager.test_quantize_eager_ptq import TestPostTrainingStatic # noqa: F401
from quantization.eager.test_quantize_eager_ptq import TestPostTrainingDynamic # noqa: F401
from quantization.eager.test_quantize_eager_ptq import TestEagerModeActivationOps # noqa: F401
from quantization.eager.test_quantize_eager_ptq import TestFunctionalModule # noqa: F401
from quantization.eager.test_quantize_eager_ptq import TestQuantizeONNXExport # noqa: F401
# 2. Eager mode quantization aware training
from quantization.eager.test_quantize_eager_qat import TestQuantizationAwareTraining # noqa: F401
from quantization.eager.test_quantize_eager_qat import TestQATActivationOps # noqa: F401
from quantization.eager.test_quantize_eager_qat import TestConvBNQATModule # noqa: F401
# 3. Eager mode fusion passes
from quantization.eager.test_fusion import TestFusion # noqa: F401
# 4. Testing model numerics between quantized and FP32 models
from quantization.eager.test_model_numerics import TestModelNumericsEager # noqa: F401
# 5. Tooling: numeric_suite
from quantization.eager.test_numeric_suite_eager import TestEagerModeNumericSuite # noqa: F401
# 6. Equalization and Bias Correction
from quantization.eager.test_equalize_eager import TestEqualizeEager # noqa: F401
from quantization.eager.test_bias_correction_eager import TestBiasCorrection # noqa: F401
# FX GraphModule Graph Mode Quantization. Tests for the functionality of APIs and different features implemented
# using fx quantization.
try:
    from quantization.fx.test_quantize_fx import TestFuseFx # noqa: F401
    from quantization.fx.test_quantize_fx import TestQuantizeFx # noqa: F401
    from quantization.fx.test_quantize_fx import TestQuantizeFxOps # noqa: F401
    from quantization.fx.test_quantize_fx import TestQuantizeFxModels # noqa: F401
except ImportError:
    # In FBCode we separate FX out into a separate target for the sake of dev
    # velocity. These are covered by a separate test target `quantization_fx`
    pass
try:
    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher # noqa: F401
    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcherModels # noqa: F401
    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIs # noqa: F401
    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels # noqa: F401
except ImportError:
    pass
# Equalization for FX mode
try:
    from quantization.fx.test_equalize_fx import TestEqualizeFx # noqa: F401
except ImportError:
    pass
# Backward Compatibility. Tests serialization and BC for quantized modules.
from quantization.bc.test_backward_compatibility import TestSerialization # noqa: F401
# JIT Graph Mode Quantization
from quantization.jit.test_quantize_jit import TestQuantizeJit # noqa: F401
from quantization.jit.test_quantize_jit import TestQuantizeJitPasses # noqa: F401
from quantization.jit.test_quantize_jit import TestQuantizeJitOps # noqa: F401
from quantization.jit.test_quantize_jit import TestQuantizeDynamicJitPasses # noqa: F401
from quantization.jit.test_quantize_jit import TestQuantizeDynamicJitOps # noqa: F401
# Quantization specific fusion passes
from quantization.jit.test_fusion_passes import TestFusionPasses # noqa: F401
from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized # noqa: F401
if __name__ == '__main__':
    run_tests()