pytorch/torch/ao/quantization/backend_config/utils.py
Vasiliy Kuznetsov c15fca1137 quant doc: improve rendered documentation for backend_config_dict
Summary:

This improves the documentation page for backend_config_dict to render
the configurations in a human readable format, such as

```
{
  'pattern': torch.nn.modules.pooling.AdaptiveAvgPool1d,
  'dtype_configs': [
    {
      'input_dtype': torch.quint8,
      'output_dtype': torch.quint8,
    },
    {
      'input_dtype': torch.float16,
      'weight_dtype': torch.float16,
      'bias_dtype': torch.float16,
      'output_dtype': torch.float16,
    },
  ],
  'observation_type': ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
},
```

The results are also now sorted alphabetically by the normalized name of
the root op in the pattern.

A couple of utility functions are created to help with this. If in the future
we convert backend_config_dict to use typed objects, we can move this logic
to the objects at that time.

Test plan:

```
cd docs
make html
cd build
python -m http.server
// renders correctly, example: https://gist.github.com/vkuzo/76adfc7c89e119c59813a733fa2cd56f
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77535

Approved by: https://github.com/andrewor14
2022-05-18 11:46:07 +00:00

203 lines
8.2 KiB
Python

from typing import Dict, Any, List, Callable, Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..quantization_types import Pattern
def get_pattern_to_dtype_configs(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, List[Dict[str, Any]]]:
    """ Map every configured pattern to its list of dtype configs.

    Walks the "configs" list of ``backend_config_dict`` (treated as empty
    when the key is absent). Each entry must provide both "pattern" and
    "dtype_configs"; if a pattern appears more than once, the last entry wins.
    """
    return {
        entry["pattern"]: entry["dtype_configs"]
        for entry in backend_config_dict.get("configs", [])
    }
def get_qat_module_classes(
        backend_config_dict: Dict[str, Any]) -> Tuple[type, ...]:
    """ Return the distinct QAT module classes listed in the backend config.

    Scans the "configs" list of ``backend_config_dict`` (treated as empty
    when the key is absent) and collects every "qat_module" value that is
    present and not None. Entries are not required to define "pattern".

    Returns:
        A tuple of unique classes, in the order they first appear in the
        config list.
    """
    qat_modules = (
        config.get("qat_module")
        for config in backend_config_dict.get("configs", [])
    )
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic from run to run (a set has arbitrary order).
    return tuple(dict.fromkeys(m for m in qat_modules if m is not None))
def get_fused_module_classes(
        backend_config_dict: Dict[str, Any]) -> Tuple[type, ...]:
    """ Return the distinct fused module classes listed in the backend config.

    Scans the "configs" list of ``backend_config_dict`` (treated as empty
    when the key is absent) and collects every "fused_module" value that is
    present and not None. Entries are not required to define "pattern".

    Returns:
        A tuple of unique classes, in the order they first appear in the
        config list.
    """
    fused_modules = (
        config.get("fused_module")
        for config in backend_config_dict.get("configs", [])
    )
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic from run to run (a set has arbitrary order).
    return tuple(dict.fromkeys(m for m in fused_modules if m is not None))
def get_pattern_to_input_type_to_index(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Dict[str, int]]:
    """ Map every configured pattern to its "input_type_to_index" dict.

    Each entry of the "configs" list (treated as empty when the key is
    absent) must define "pattern". Entries without "input_type_to_index"
    are mapped to an empty dict. If a pattern repeats, the last entry wins.
    """
    result: Dict[Pattern, Dict[str, int]] = {}
    for entry in backend_config_dict.get("configs", []):
        result[entry["pattern"]] = entry.get("input_type_to_index", {})
    return result
def get_root_module_to_quantized_reference_module(
        backend_config_dict: Dict[str, Any]) -> Dict[Callable, Callable]:
    """ Map each "root_module" to its "reference_quantized_module_for_root".

    Only entries of the "configs" list (treated as empty when the key is
    absent) that define both keys contribute; all other entries are skipped.
    """
    return {
        entry["root_module"]: entry["reference_quantized_module_for_root"]
        for entry in backend_config_dict.get("configs", [])
        if "root_module" in entry and "reference_quantized_module_for_root" in entry
    }
def get_fuser_method_mapping(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
    """ Map each pattern to the "fuser_method" configured for it.

    Entries of the "configs" list (treated as empty when the key is absent)
    that do not define "fuser_method" are skipped; entries that do define
    it must also define "pattern".
    """
    return {
        entry["pattern"]: entry["fuser_method"]
        for entry in backend_config_dict.get("configs", [])
        if "fuser_method" in entry
    }
def get_module_to_qat_module(
        backend_config_dict: Dict[str, Any]) -> Dict[Callable, Callable]:
    """ Map each pattern to the "qat_module" configured for it.

    Only entries of the "configs" list (treated as empty when the key is
    absent) that define both "pattern" and "qat_module" are included.
    """
    result: Dict[Callable, Callable] = {}
    for entry in backend_config_dict.get("configs", []):
        if "pattern" not in entry or "qat_module" not in entry:
            continue
        result[entry["pattern"]] = entry["qat_module"]
    return result
def get_fusion_pattern_to_root_node_getter(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
    """ Build a map from fusion pattern to the function that returns the
    root node of that fusion pattern.

    Only config entries that define "root_node_getter" contribute to the
    map. The most common getter looks like:

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

    which works for all patterns whose root node is the "last node" in the
    pattern, e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d))
    """
    return {
        entry["pattern"]: entry["root_node_getter"]
        for entry in backend_config_dict.get("configs", [])
        if "root_node_getter" in entry
    }
def get_fusion_pattern_to_extra_inputs_getter(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
    """ Build a map from fusion pattern to the function that returns the
    extra input nodes of that pattern, in the order required by the root node.

    The getter is optional: config entries without "extra_inputs_getter" are
    left out of the map, and no extra inputs are copied over for their root
    node.

    Example:
        # Take the pattern
        #   (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
        # with torch.nn.Conv2d as the root node. If the node matched by
        # MatchAllNode should become an extra argument of the fused module,
        # the getter unpacks the pattern and returns that node:
        def extra_inputs_getter(pattern) -> List[Any]:
            add, extra_input, conv_pattern = pattern
            return [extra_input]
    """
    getters: Dict[Pattern, Callable] = {}
    for entry in backend_config_dict.get("configs", []):
        if "extra_inputs_getter" not in entry:
            continue
        getters[entry["pattern"]] = entry["extra_inputs_getter"]
    return getters
def remove_boolean_dispatch_from_name(p) -> Any:
    """
    Replace ops whose default string representation is unreadable, such as
    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>',
    with a hardcoded function name.

    Returns the hardcoded name when `p` is one of the known
    `torch.nn.functional` pooling ops listed below; otherwise returns `p`
    unchanged. Asserts that an unrecognized `p` does not mention
    "boolean_dispatch" in its string form.
    """
    # (attribute on torch.nn.functional, name shown in the docs); checked in
    # order, by identity.
    known_ops = (
        ("fractional_max_pool2d", "torch.nn.functional.fractional_max_pool2d"),
        ("fractional_max_pool3d", "torch.nn.functional.fractional_max_pool3d"),
        ("max_pool1d", "torch.nn.functional.max_pool1d"),
        ("max_pool2d", "torch.nn.functional.max_pool2d"),
        ("max_pool3d", "torch.nn.functional.max_pool3d"),
        ("adaptive_max_pool1d", "torch.nn.functional.adaptive_max_pool1d"),
        ("adaptive_max_pool2d", "torch.nn.functional.adaptive_max_pool2d"),
        ("adaptive_max_pool3d", "torch.nn.functional.adaptive_max_pool3d"),
    )
    for attr_name, readable_name in known_ops:
        if p is getattr(F, attr_name):
            return readable_name
    assert "boolean_dispatch" not in str(p), \
        f"{p} does not have a human readable representation in " + \
        "quantization documentation"
    return p
def pattern_to_human_readable(p) -> Any:
    """
    Convert a pattern into a form suitable for display.

    Strings (method names) are returned unchanged, tuples (nested patterns)
    are converted element by element into a new tuple, and anything else is
    passed through `remove_boolean_dispatch_from_name`.
    """
    if isinstance(p, str):
        # method names are already human readable
        return p
    if isinstance(p, tuple):
        # nested patterns, recurse
        return tuple(map(pattern_to_human_readable, p))
    return remove_boolean_dispatch_from_name(p)
# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
# the corresponding __str__ function
def entry_to_pretty_str(entry) -> str:
    """
    Given a backend_config_dict entry, returns a string with the human readable
    representation of it.

    The output is a brace-delimited block with one field per line, using
    two-space nested indentation. "pattern" is always emitted first, followed
    by "dtype_configs" and "num_tensor_args_to_observation_type" (each with a
    custom multi-line layout), then every remaining field in the entry's own
    iteration order. An empty entry renders as "{" and "}" on two lines.
    """
    # Fields with a custom layout below; the generic loop at the end skips them.
    custom_handled_fields = {
        "pattern",
        "dtype_configs",
        "num_tensor_args_to_observation_type",
    }

    # Collect lines and join once at the end instead of repeated `+=`.
    lines = ["{"]

    # always output the pattern first
    if "pattern" in entry:
        pattern_str = pattern_to_human_readable(entry["pattern"])
        lines.append(f"  'pattern': {pattern_str},")

    # custom output for dtype_configs to make it look nice
    if "dtype_configs" in entry:
        lines.append("  'dtype_configs': [")
        for dtype_config in entry["dtype_configs"]:
            lines.append("    {")
            lines.extend(f"      '{k}': {v}," for k, v in dtype_config.items())
            lines.append("    },")
        lines.append("  ],")

    # custom output for num_tensor_args_to_observation_type to make it look nice
    if "num_tensor_args_to_observation_type" in entry:
        lines.append("  'num_tensor_args_to_observation_type': {")
        lines.extend(
            f"    {k}: {v},"
            for k, v in entry["num_tensor_args_to_observation_type"].items())
        lines.append("  },")

    # output all the other fields
    for field_name, value in entry.items():
        if field_name not in custom_handled_fields:
            lines.append(f"  '{field_name}': {value},")

    lines.append("}")
    return "\n".join(lines)