**Summary**

This PR continues the work in https://github.com/pytorch/pytorch/pull/83784 by @vkuzo and includes all the changes from that PR. Quoting https://github.com/pytorch/pytorch/pull/83784:

> Issue #83658 reports that ops followed by a certain pattern of `view` and `size` ops were not quantized correctly by FX graph mode quantization. Before this PR, the `size` op was in the "op shares qparams with input" category, and the code assumed that the input of this op has the same dtype as its output. This led to incorrectly propagating the `int` dtype as the output of whichever op was preceding the `view` op, which in turn made that op blocklisted from quantization.
>
> The fix is to create a new category of ops which work on different dtypes of tensors but are not observed. This PR does so for `size`, and also for `shape`, since it works the same way.

**Note**: This PR needs https://github.com/pytorch/pytorch/pull/91297 to land first, otherwise there is a UT failure.

**Test plan**

```
python test/test_quantization.py -k test_linear_size_view
python test/test_quantization.py -k test_linear_shape_view
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90001
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
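For context, the problematic pattern from issue #83658 looks roughly like the sketch below. The module is illustrative (names and shapes are not taken from the actual test), but it shows the shape of the problem: an op whose output feeds both `size` and `view`, which previously caused the preceding `linear` to be skipped by quantization.

```
import torch
import torch.nn as nn

class LinearSizeView(nn.Module):
    """Minimal sketch of the pattern from issue #83658."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 32)

    def forward(self, x):
        x = self.linear(x)
        # `size` returns ints, not a quantized tensor; before this fix the
        # int dtype was propagated back as the output dtype of `linear`,
        # which blocklisted `linear` from quantization
        return x.view(x.size(0), -1)
```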
205 lines · 7.9 KiB · Python
import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
    _get_tensor_info_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig

__all__ = [
    "get_test_only_legacy_native_backend_config",
    "default_op_quint8_dtype_config",
    "default_op_fp16_dtype_config",
    "default_dynamic_int8_dtype_config",
    "default_dynamic_float16_dtype_config",
    "input_output_only_quint8_dtype_config",
    "weight_only_quint8_dtype_config",
    "weight_only_quint4x2_dtype_config",
    "get_native_backend_config",
    "get_native_backend_config_dict",
    "get_test_only_legacy_native_backend_config_dict",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# weighted op int8 dtype config
# this is the config for ops that have quantized weights, like linear and conv
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

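# default op int8 dtype config: quint8 activations in and out, no quantized weights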
default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

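# default op fp16 dtype config: fp16 for activations, weight, and bias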
default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

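# dynamic int8 dtype config: activations are quantized to quint8 at runtime
# and the op produces a float output, as in dynamic quantization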
default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    # dtype checks are not yet enabled, so we provide the dtype_configs here
    # even though they are not actually used yet; we will enable the checks
    # once everything has moved to backend_config_dict
    is_dynamic=True,
)

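# dynamic fp16 dtype config: the weight is stored in fp16 and the op produces
# a float output, mirroring the dynamic int8 config above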
default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    # dtype checks are not yet enabled, so we provide the dtype_configs here
    # even though they are not actually used yet; we will enable the checks
    # once everything has moved to backend_config_dict
    is_dynamic=True,
)

# Needed for LayerNorm and F.layer_norm, since currently the kernel only supports float weights
input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

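# weight only int8 dtype config: only the weight is quantized, activations
# stay float (used below for embedding ops)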
weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

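# weight only 4-bit dtype config: torch.quint4x2 packs two 4-bit quantized
# values per byte (used below for embedding ops)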
weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)


# =====================
# |  BACKEND CONFIGS  |
# =====================

def get_test_only_legacy_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
    """
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
        default_op_fp16_dtype_config,
    ]
    binary_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
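    # tensor info ops such as `size` and `shape` work on tensors of any dtype
    # and are not observed (the new op category introduced by this PR)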
    tensor_info_op_dtype_configs = [
        default_op_quint8_dtype_config,
    ]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return BackendConfig("_native_and_fp16") \
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))

def get_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
    """
    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [default_op_quint8_dtype_config]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
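    # tensor info ops (`size`, `shape`) get their own category here as well,
    # so that dtypes do not propagate through them incorrectly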
    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return BackendConfig("native") \
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))

def get_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
    """
    return get_native_backend_config().to_dict()

def get_test_only_legacy_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various
    additional fp16 ops in dictionary form.
    """
    return get_test_only_legacy_native_backend_config().to_dict()
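
# example usage (sketch; `model`, `qconfig_mapping`, and `example_inputs` are
# placeholders): these configs plug into FX graph mode quantization via the
# `backend_config` argument, e.g.
#
#     from torch.ao.quantization.quantize_fx import prepare_fx
#     prepared = prepare_fx(model, qconfig_mapping, example_inputs,
#                           backend_config=get_native_backend_config())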