pytorch/torch/ao/quantization/backend_config/x86.py
Xia, Weiwen 6fa84fdea2 [FX][Quant] Enable FX quant for patterns like x.view(x.size(...), ...) (#90001)
**Summary**
This work continues https://github.com/pytorch/pytorch/pull/83784 by @vkuzo and includes all the changes from that PR.
Quote from https://github.com/pytorch/pytorch/pull/83784:
> Issue #83658 reports that ops followed by a certain pattern of `view` and `size` ops were not quantized correctly by FX graph mode quantization.
Before this PR, the "size" op was in the "op shares qparams with input" category, and the code assumed that the input of this op has the same dtype as its output. This led to incorrectly propagating the `int` dtype as the output of whichever op was preceding the `view` op, which in turn made that op blocklisted from quantization.

> The fix is to create a new category of ops which work on different dtypes of tensors but are not observed. This PR does so for `size`, and also for `shape` since it works the same way.

**Note**: This PR requires https://github.com/pytorch/pytorch/pull/91297 to land first; otherwise a unit test (UT) fails.

**Test plan**
```
python test/test_quantization.py -k test_linear_size_view
python test/test_quantization.py -k test_linear_shape_view
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90001
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
2023-01-27 07:56:29 +00:00

114 lines
3.7 KiB
Python

import torch
from ._common_operator_config_utils import (
_get_binary_op_configs,
_get_bn_configs,
_get_cat_config,
_get_conv_configs,
_get_default_op_configs,
_get_embedding_op_configs,
_get_fixed_qparams_op_configs,
_get_linear_configs,
_get_rnn_op_configs,
_get_share_qparams_op_configs,
_get_tensor_info_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig
# Public API of this module.
__all__ = ["get_x86_backend_config"]
# ===================
# | DTYPE CONFIGS |
# ===================
# X86 aligns with FBGEMM for now
# quint8 input/output with a qint8 weight and a float bias.
x86_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# quint8 input/output; no weight or bias dtype is specified.
x86_default_op_quint8_dtype_config = DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8)

# float16 for input, output, weight and bias alike.
x86_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic (is_dynamic=True): quint8 input, float output, qint8 weight, float bias.
x86_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic (is_dynamic=True): float16 input and weight, float output and bias.
x86_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Float input/output with a quint8 weight.
x86_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# Float input/output with a quint4x2 weight.
x86_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
# =====================
# | BACKEND CONFIGS |
# =====================
def get_x86_backend_config() -> BackendConfig:
    """
    Build and return the `BackendConfig` for PyTorch's native x86 backend.

    Each family of operator patterns is registered together with the list of
    dtype configs that is handed to its pattern-config helper.
    """
    # Families that accept only the weighted-op int8 config.
    conv_dtypes = [x86_weighted_op_int8_dtype_config]
    binary_op_dtypes = [x86_weighted_op_int8_dtype_config]
    fixed_qparams_dtypes = [x86_weighted_op_int8_dtype_config]

    # Linear accepts the weighted-op int8 config plus both dynamic configs.
    linear_dtypes = [
        x86_weighted_op_int8_dtype_config,
        x86_default_dynamic_int8_dtype_config,
        x86_default_dynamic_float16_dtype_config,
    ]

    # Families that accept the default quint8 config. `default_op_dtypes` is
    # deliberately the same list object for the cat, default-op and bn helpers.
    default_op_dtypes = [x86_default_op_quint8_dtype_config]
    share_qparams_dtypes = [x86_default_op_quint8_dtype_config]
    tensor_info_dtypes = [x86_default_op_quint8_dtype_config]

    # RNN ops accept only the dynamic configs.
    rnn_dtypes = [
        x86_default_dynamic_int8_dtype_config,
        x86_default_dynamic_float16_dtype_config,
    ]

    # Embedding ops accept only the weight-only configs.
    embedding_dtypes = [
        x86_weight_only_quint8_dtype_config,
        x86_weight_only_quint4x2_dtype_config,
    ]

    # Registration order is kept identical to the original chain.
    return (
        BackendConfig("x86")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtypes))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtypes))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtypes))
        .set_backend_pattern_config(_get_cat_config(default_op_dtypes))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtypes))
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_dtypes))
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_dtypes))
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_dtypes))
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtypes))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_dtypes))
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_dtypes))
    )