mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This is part 2 of the effort to replace `backend_config_dict` with a Python config object, a more formal and robust API that leads to a better user experience. This commit integrates the `BackendConfig` implemented in part 1 (https://github.com/pytorch/pytorch/pull/81469) with the existing FX graph mode quantization flow. Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps BC-breaking Notes: Before: ``` import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.backend_config import ObservationType from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx dtype_config = { "input_dtype": torch.quint8, "output_dtype": torch.quint8, "weight_dtype": torch.qint8, "bias_dtype": torch.float, } backend_config_dict = { "name": "my_backend", "configs": [{ "pattern": torch.nn.Linear, "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, "dtype_configs": [dtype_config], "root_module": torch.nn.Linear, "reference_quantized_module": torch.nn.quantized._reference.Linear, "qat_module": torch.nn.qat.Linear, }] } m = MyModel() qconfig_mapping = get_default_qconfig_mapping() example_inputs = (torch.rand(3, 3),) m = prepare_fx( m, qconfig_mapping, example_inputs, backend_config_dict=backend_config_dict) m = convert_fx(m, backend_config_dict=backend_config_dict) ``` After: ``` import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.backend_config import ( BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType, ) from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx dtype_config = DTypeConfig( input_dtype=torch.quint8, output_dtype=torch.quint8, weight_dtype=torch.qint8, bias_dtype=torch.float, ) backend_config = BackendConfig("my_backend").set_backend_pattern_config( BackendPatternConfig(torch.nn.Linear) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) 
.add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(torch.nn.quantized._reference.Linear) .set_qat_module(torch.nn.qat.Linear)) m = MyModel() qconfig_mapping = get_default_qconfig_mapping() example_inputs = (torch.rand(3, 3),) m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) m = convert_fx(m, backend_config=backend_config) ``` Reviewers: jerryzh168 Subscribers: jerryzh168, supriyar Differential Revision: [D38471932](https://our.internmc.facebook.com/intern/diff/D38471932) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82557 Approved by: https://github.com/jerryzh168 |
||
|---|---|---|
| .. | ||
| _static | ||
| _templates | ||
| community | ||
| elastic | ||
| notes | ||
| rpc | ||
| scripts | ||
| amp.rst | ||
| autograd.rst | ||
| backends.rst | ||
| benchmark_utils.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| complex_numbers.rst | ||
| conf.py | ||
| config_mod.rst | ||
| cpp_extension.rst | ||
| cpp_index.rst | ||
| cuda.rst | ||
| cudnn_persistent_rnn.rst | ||
| cudnn_rnn_determinism.rst | ||
| data.rst | ||
| ddp_comm_hooks.rst | ||
| deploy.rst | ||
| distributed.algorithms.join.rst | ||
| distributed.elastic.rst | ||
| distributed.optim.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| docutils.conf | ||
| fft.rst | ||
| fsdp.rst | ||
| futures.rst | ||
| fx.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference_v2.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit_utils.rst | ||
| jit.rst | ||
| library.rst | ||
| linalg.rst | ||
| math-quantizer-equation.png | ||
| mobile_optimizer.rst | ||
| model_zoo.rst | ||
| monitor.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nested.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx_supported_aten_ops.rst | ||
| onnx.rst | ||
| optim.rst | ||
| package.rst | ||
| pipeline.rst | ||
| profiler.rst | ||
| quantization-accuracy-debugging.rst | ||
| quantization-backend-configuration.rst | ||
| quantization-support.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| special.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensor_view.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| testing.rst | ||
| torch.ao.ns._numeric_suite_fx.rst | ||
| torch.ao.ns._numeric_suite.rst | ||
| torch.overrides.rst | ||
| torch.rst | ||
| type_info.rst | ||