pytorch/test/test_quantization.py
Vasiliy Kuznetsov b57c8b720e [wip] Make quantization modules work with DataParallel (#37032)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37032

DataParallel requires all params and buffers of child modules to be updated
in place because of how it implements model replication during the
forward pass (see https://github.com/pytorch/pytorch/pull/12671 for
context). Any params or buffers not updated in place are lost and not
propagated back to the master.

This diff updates some quantized modules (TBD: all quantized modules? determine a good cut
point) to perform their parameter updates in place. This will enable static
quant and QAT to work correctly with DataParallel.

TODO: https://github.com/pytorch/pytorch/pull/32684 needs to land before we can fix the graph mode test failures on this PR.

Test Plan:
The script failed before the diff and passes after it:
https://gist.github.com/vkuzo/78b06c01f23f98ee2aaaeb37e55f8d40

TODO before land: add integration testing

Imported from OSS

Differential Revision: D21206454

fbshipit-source-id: df6b4b04d0ae0f7ef582c82d81418163019e96f7
2020-05-05 13:06:43 -07:00

63 lines
3.0 KiB
Python

# -*- coding: utf-8 -*-
from torch.testing._internal.common_utils import run_tests
# Quantized Tensor
from quantization.test_quantized_tensor import TestQuantizedTensor # noqa: F401
# Quantized Op
# TODO: merge test cases in quantization.test_quantized
from quantization.test_quantized_op import TestQuantizedOps # noqa: F401
from quantization.test_quantized_op import TestQNNPackOps # noqa: F401
from quantization.test_quantized_op import TestQuantizedLinear # noqa: F401
from quantization.test_quantized_op import TestQuantizedConv # noqa: F401
from quantization.test_quantized_op import TestDynamicQuantizedLinear # noqa: F401
from quantization.test_quantized_op import TestComparatorOps # noqa: F401
from quantization.test_quantized_op import TestPadding # noqa: F401
# Quantized Functional
from quantization.test_quantized_functional import TestQuantizedFunctional # noqa: F401
# Quantized Module
from quantization.test_quantized_module import TestStaticQuantizedModule # noqa: F401
from quantization.test_quantized_module import TestDynamicQuantizedModule # noqa: F401
# Quantization Aware Training
from quantization.test_qat_module import TestQATModule # noqa: F401
# Workflow Module
# TODO: merge the fake quant per tensor and per channel test cases
# TODO: some of the tests are actually operator tests, e.g. test_forward_per_tensor, and
# should be moved to test_quantized_op
from quantization.test_workflow_module import TestFakeQuantizePerTensor # noqa: F401
from quantization.test_workflow_module import TestFakeQuantizePerChannel # noqa: F401
from quantization.test_workflow_module import TestObserver # noqa: F401
# TODO: merge with TestObserver
# TODO: some tests belong to test_quantize.py, e.g. test_record_observer
from quantization.test_workflow_module import TestRecordHistogramObserver # noqa: F401
from quantization.test_workflow_module import TestDistributed # noqa: F401
# Workflow
# 1. Eager mode quantization
from quantization.test_quantize import TestPostTrainingStatic # noqa: F401
from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401
from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401
# TODO: move to test_quantize_script
from quantization.test_quantize import TestGraphModePostTrainingStatic # noqa: F401
# TODO: merge with other tests in test_quantize.py?
from quantization.test_quantize import TestFunctionalModule # noqa: F401
from quantization.test_quantize import TestFusion # noqa: F401
from quantization.test_quantize import TestModelNumerics # noqa: F401
# 2. Graph mode quantization
from quantization.test_quantize_script import TestQuantizeScriptJitPasses # noqa: F401
from quantization.test_quantize_script import TestQuantizeScriptPTSQOps # noqa: F401
from quantization.test_quantize_script import TestQuantizeDynamicScript # noqa: F401
# Tooling: numeric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401
# Backward Compatibility
from quantization.test_backward_compatibility import TestSerialization # noqa: F401
if __name__ == '__main__':
    # Script entry point: delegate to PyTorch's shared test runner.
    # Every TestCase class above is imported only so that it lives in this
    # module's namespace (hence the `noqa: F401` "unused import" markers);
    # presumably run_tests() collects them from here -- NOTE(review): confirm
    # against torch.testing._internal.common_utils.
    run_tests()