pytorch/torch/quantization/fx/utils.py
Charles David Hernandez c1d070d0f0 [ao] Fixing obs insertion through dtype propagation (#73274)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73274

As noticed in https://discuss.pytorch.org/t/calibration-of-model-in-post-training-static-quantization-using-fx-api/143661/6
and related to https://github.com/pytorch/pytorch/issues/72698: when
using fx quantization, if an op like view was used in a model and its
index parameters were passed to the op as variables rather than
hard-coded values, fx would mistakenly insert observers for them,
leading to an error when the observer tried to perform tensor-only
operations on a non-tensor. To fix this, an API was added for
specifying the non-tensor arguments of various ops, enabling better
dtype propagation.
NON_TENSOR_ARG_DICT is a nested dict whose outer keys are named tuples
containing the matching parameters for ops with non-tensor args. The
inner dict's keys are dtypes, and its values are lists of the arg
indices that take such dtypes. Alternatively, instead of a list, an
inner-dict value can be a function that takes the node as an argument
and returns the list of arg indices.
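
A minimal sketch of the failing pattern, in the spirit of the linked
forum thread (the module and variable names here are hypothetical):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            # bs is an int computed at runtime rather than a hard-coded
            # literal; before this fix, fx quantization could insert an
            # observer for it and then fail when the observer ran
            # tensor-only operations on the non-tensor value.
            bs = x.size(0)
            return x.view(bs, -1)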

Theoretically this API can support arbitrary functions, but the
current implementation is limited to simpler ones, since the
particular issue this fixes appears to be rare.
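
For illustration, the table might look roughly like this; the named
tuple's fields and the specific entries are assumptions for the
sketch, not the exact contents of torch/ao/quantization/fx/utils.py:

    from collections import namedtuple
    import torch

    # hypothetical pattern key; the real named tuple may differ
    ArgPattern = namedtuple("ArgPattern", ["op"])

    NON_TENSOR_ARG_DICT = {
        # every positional arg of view after the tensor itself is an int,
        # so a function computes the indices from the node
        ArgPattern("view"): {
            torch.int: lambda node: list(range(1, len(node.args))),
        },
        # ops with fixed non-tensor arg positions can use a plain list
        ArgPattern("transpose"): {
            torch.int: [1, 2],
        },
    }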

Note: although torch.unsqueeze and torch.transpose are listed in
quantization_patterns.py, those ops appear to be untraceable by fx.
I've included tests for their cases, but fixing that issue is beyond
the scope of this PR.

Test Plan:
python test/test_quantization.py test_non_reference_size
...
python test/test_quantization.py test_non_reference_<op>

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D34410122

fbshipit-source-id: fc09949ca8a2d6473876a4b6c214eb91e9a9dae2
(cherry picked from commit 3a1375d677b7c98d62b1f5c839645698c39b32b9)
2022-03-16 01:41:17 +00:00

# flake8: noqa: F401
r"""
This file is in the process of migration to `torch/ao/quantization`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please add it to the
appropriate files under `torch/ao/quantization/fx/`, while adding an import statement
here.
"""
from torch.ao.quantization.fx.utils import (
graph_pretty_str,
get_per_tensor_qparams,
quantize_node,
get_custom_module_class_keys,
get_linear_prepack_op_for_dtype,
get_qconv_prepack_op,
get_qconv_op,
get_new_attr_name_with_prefix,
graph_module_from_producer_nodes,
assert_and_get_unique_device,
create_getattr_from_value,
create_qparam_nodes,
all_node_args_have_no_tensors,
node_return_type_is_int,
get_non_observable_arg_indexes_and_types,
is_get_tensor_info_node,
maybe_get_next_module
)