pytorch/torch/ao/quantization/__init__.py
Andrew Or c7b4eec233 [Quant][fx][bc-breaking] Replace qconfig_dict with a config object (#78452)
**Summary:** Previously, FX graph mode quantization configurations
were specified through a dictionary of qconfigs. However, this
API was not in line with other core APIs in PyTorch. This commit
replaces this dictionary with a config object that users will
create and pass to prepare and convert. This leads to better
type safety and better user experience in notebook settings
due to improved auto completion.

The new API is as follows:

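```python
import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx

# `model` (a torch.nn.Module) and `qconfig` (a QConfig) are assumed to be
# defined already. The chain is wrapped in parentheses so the builder
# calls parse as a single expression.
qconfig_mapping = (
    QConfigMapping()
    .set_global(qconfig)
    .set_object_type(torch.nn.Linear, qconfig)
    .set_module_name_regex("foo.*bar", qconfig)
    .set_module_name("mod", qconfig)
)

prepare_fx(model, qconfig_mapping)
```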
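To make the builder style concrete without depending on torch, here is a minimal pure-Python sketch. `SketchQConfigMapping` and its `lookup` method are hypothetical and are not the torch implementation. The lookup order (exact module name, then regex, then object type, then global) is an assumption modeled on the precedence documented for `QConfigMapping`, and placeholder strings stand in for real `QConfig` objects.

```python
import re
from typing import Any, Dict, List, Tuple


class SketchQConfigMapping:
    """Hypothetical toy stand-in for QConfigMapping (not the torch class)."""

    def __init__(self) -> None:
        self.global_qconfig: Any = None
        self.object_type_qconfigs: Dict[type, Any] = {}
        self.module_name_regex_qconfigs: List[Tuple[str, Any]] = []
        self.module_name_qconfigs: Dict[str, Any] = {}

    # Every setter returns self, which is what makes chaining possible.
    def set_global(self, qconfig: Any) -> "SketchQConfigMapping":
        self.global_qconfig = qconfig
        return self

    def set_object_type(self, object_type: type, qconfig: Any) -> "SketchQConfigMapping":
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(self, pattern: str, qconfig: Any) -> "SketchQConfigMapping":
        self.module_name_regex_qconfigs.append((pattern, qconfig))
        return self

    def set_module_name(self, name: str, qconfig: Any) -> "SketchQConfigMapping":
        self.module_name_qconfigs[name] = qconfig
        return self

    def lookup(self, module_name: str, module_type: type) -> Any:
        # Assumed precedence: exact module name > regex > object type > global.
        if module_name in self.module_name_qconfigs:
            return self.module_name_qconfigs[module_name]
        for pattern, qconfig in self.module_name_regex_qconfigs:
            if re.match(pattern, module_name):  # first matching pattern wins
                return qconfig
        if module_type in self.object_type_qconfigs:
            return self.object_type_qconfigs[module_type]
        return self.global_qconfig


# Hypothetical module types used only for the demo.
class Linear:
    pass


class Conv:
    pass


mapping = (
    SketchQConfigMapping()
    .set_global("global_qconfig")
    .set_object_type(Linear, "linear_qconfig")
    .set_module_name_regex("foo.*bar", "regex_qconfig")
    .set_module_name("mod", "mod_qconfig")
)

print(mapping.lookup("mod", Linear))          # mod_qconfig
print(mapping.lookup("foo.sub.bar", Linear))  # regex_qconfig
print(mapping.lookup("other", Linear))        # linear_qconfig
print(mapping.lookup("other", Conv))          # global_qconfig
```

Because each entry lives in a named, typed attribute instead of a string-keyed dictionary, editors can autocomplete the setters, which is the usability gain the summary describes.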

For backwards compatibility, `prepare_fx`, `prepare_qat_fx`,
and `convert_fx` will continue to accept `qconfig_dict`s, which
are converted to `QConfigMapping` objects internally.

Note that this commit does not modify existing tests to use the
new API; they will continue to pass in qconfig_dict as before,
which still works but triggers a deprecation warning. This will
be handled in a future commit.
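The backwards-compatible path can be pictured as a small normalization step at the API boundary. The sketch below is an assumption, not the PyTorch code: `SketchConfig` and `normalize_qconfig_arg` are hypothetical names. It assumes the legacy `qconfig_dict` layout (`""` for the global qconfig plus `"object_type"`, `"module_name_regex"` and `"module_name"` lists of pairs) and ignores the other legacy keys for brevity.

```python
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class SketchConfig:
    """Hypothetical config object standing in for QConfigMapping."""
    global_qconfig: Any = None
    object_type_qconfigs: Dict[Any, Any] = field(default_factory=dict)
    module_name_regex_qconfigs: List[Tuple[str, Any]] = field(default_factory=list)
    module_name_qconfigs: Dict[str, Any] = field(default_factory=dict)


def normalize_qconfig_arg(arg: Any) -> SketchConfig:
    """Accept either the new config object or a legacy qconfig_dict."""
    if isinstance(arg, SketchConfig):
        return arg  # new-style argument: use it unchanged
    if isinstance(arg, dict):
        # Legacy input still works, but the caller is told it is deprecated.
        warnings.warn(
            "Passing a qconfig_dict is deprecated; pass a config object instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return SketchConfig(
            global_qconfig=arg.get(""),
            object_type_qconfigs=dict(arg.get("object_type", [])),
            module_name_regex_qconfigs=list(arg.get("module_name_regex", [])),
            module_name_qconfigs=dict(arg.get("module_name", [])),
        )
    raise TypeError(f"expected a dict or SketchConfig, got {type(arg).__name__}")


legacy = {
    "": "global_qconfig",
    "object_type": [(int, "int_qconfig")],
    "module_name": [("mod", "mod_qconfig")],
}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning even if seen before
    config = normalize_qconfig_arg(legacy)

print(config.module_name_qconfigs)   # {'mod': 'mod_qconfig'}
print(caught[0].category.__name__)   # DeprecationWarning
```

Converting once at the entry point lets everything downstream handle a single typed representation, while old call sites keep working until they migrate.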

**Test Plan:**
```
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
```

**Reviewers:** jerryzh168, vkuzo

**Subscribers:** jerryzh168, vkuzo

Differential Revision: D36747998

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78452
Approved by: https://github.com/jerryzh168
2022-05-30 18:30:07 +00:00


```python
# flake8: noqa: F403
from .fake_quantize import *  # noqa: F403
from .fuse_modules import fuse_modules  # noqa: F403
from .fuse_modules import fuse_modules_qat  # noqa: F403
from .fuser_method_mappings import *  # noqa: F403
from .observer import *  # noqa: F403
from .qconfig import *  # noqa: F403
from .qconfig_mapping import *  # noqa: F403
from .quant_type import *  # noqa: F403
from .quantization_mappings import *  # noqa: F403
from .quantize import *  # noqa: F403
from .quantize_jit import *  # noqa: F403
from .stubs import *  # noqa: F403


def default_eval_fn(model, calib_data):
    r"""
    Default evaluation function: takes a torch.utils.data.Dataset (or a list)
    of (input, target) pairs and runs the model on each input.
    """
    # Only the inputs are fed to the model; targets are unused during calibration.
    for data, target in calib_data:
        model(data)
```
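During post-training quantization, a function like `default_eval_fn` is used to push calibration data through a model whose observers record activation statistics. The pure-Python sketch below illustrates that contract under stated assumptions: `MinMaxRecorder` is a hypothetical stand-in for an observed model (it tracks a running min/max, as a min/max observer would), and plain lists of floats stand in for tensors. Note that each element of `calib_data` must unpack into an `(input, target)` pair.

```python
from typing import List, Tuple


def default_eval_fn(model, calib_data):
    # Same loop as in the module: only the input reaches the model.
    for data, target in calib_data:
        model(data)


class MinMaxRecorder:
    """Hypothetical stand-in for an observed model: records the running
    min and max of every value it sees, like a min/max observer."""

    def __init__(self) -> None:
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def __call__(self, data: List[float]) -> List[float]:
        self.min_val = min(self.min_val, min(data))
        self.max_val = max(self.max_val, max(data))
        return data  # identity "forward pass"


# (input, target) pairs; the integer targets are ignored by default_eval_fn.
calib_data: List[Tuple[List[float], int]] = [
    ([0.5, -1.25, 2.0], 0),
    ([3.5, 0.0], 1),
]

recorder = MinMaxRecorder()
default_eval_fn(recorder, calib_data)
print(recorder.min_val, recorder.max_val)  # -1.25 3.5
```

In real eager-mode or FX calibration, the statistics gathered this way are what the observers later use to compute quantization parameters such as scale and zero point.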