mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This improves the documentation page for backend_config_dict to render
the configurations in a human-readable format, such as
```
{
'pattern': torch.nn.modules.pooling.AdaptiveAvgPool1d,
'dtype_configs': [
{
'input_dtype': torch.quint8,
'output_dtype': torch.quint8,
},
{
'input_dtype': torch.float16,
'weight_dtype': torch.float16,
'bias_dtype': torch.float16,
'output_dtype': torch.float16,
},
],
'observation_type': ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
},
```
The results are also now sorted alphabetically by the normalized name of
the root op in the pattern.
A couple of utility functions were added to help with this. If
backend_config_dict is converted to use typed objects in the future, this
logic can be moved onto those objects at that time.
Test plan:
```
cd docs
make html
cd build
python -m http.server
# renders correctly, example: https://gist.github.com/vkuzo/76adfc7c89e119c59813a733fa2cd56f
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77535
Approved by: https://github.com/andrewor14
63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
"""
|
|
This script generates a human-readable listing of the default (native)
quantization backend configs. The listing is used in the documentation.
|
|
"""
|
|
|
|
import torch
|
|
from torch.ao.quantization.backend_config import get_native_backend_config_dict
|
|
from torch.ao.quantization.backend_config.utils import (
|
|
entry_to_pretty_str,
|
|
remove_boolean_dispatch_from_name,
|
|
)
|
|
import os.path
|
|
|
|
|
|
# Directory (next to this script) that receives the generated backend config
# listing. NOTE: despite "IMAGE" in the constant's name, this script writes a
# plain-text file into it; the name is kept for compatibility.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "quantization_backend_configs",
)

# ``exist_ok=True`` replaces the original ``os.path.exists`` + ``os.mkdir``
# pair, which could race (e.g. two doc builds starting at once) and could not
# create missing parent directories.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Destination file for the rendered native backend config entries.
output_path = os.path.join(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt")
|
|
|
|
def _sort_key_func(entry):
    """Return the sort key for one backend config entry.

    The key is derived from the root op of ``entry['pattern']``. While the
    pattern is a tuple, its last element is taken, until a non-tuple op
    remains. That op is passed through ``remove_boolean_dispatch_from_name``.
    If the result is not already a string, it is converted with
    ``torch.typename``; method patterns are already strings.

    The name is then lowercased and stripped of underscores, and only its last
    dotted component is kept. This makes, for example::

        torch.nn.modules.pooling.AdaptiveAvgPool1d
        torch._VariableFunctionsClass.adaptive_avg_pool1d

    both map to ``"adaptiveavgpool1d"``, so they sort next to each other.
    """
    pattern = entry['pattern']
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    # NOTE(review): presumably maps boolean-dispatch wrappers to a stable,
    # readable name; see torch.ao.quantization.backend_config.utils.
    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    pattern_str_normalized = pattern.lower().replace('_', '')
    return pattern_str_normalized.split('.')[-1]


native_backend_config_dict = get_native_backend_config_dict()

# Sort a copy instead of mutating the list owned by the config dict in place.
configs = sorted(native_backend_config_dict['configs'], key=_sort_key_func)

# Render the whole listing before touching the output file, so a failure while
# building or formatting the configs cannot leave a truncated/empty file behind.
rendered = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

with open(output_path, "w", encoding="utf-8") as f:
    # Terminate with a newline so the output is a well-formed text file.
    f.write(rendered + "\n")
|