Summary:
This improves the documentation page for backend_config_dict to render
the configurations in a human-readable format, such as:
```
{
  'pattern': torch.nn.modules.pooling.AdaptiveAvgPool1d,
  'dtype_configs': [
    {
      'input_dtype': torch.quint8,
      'output_dtype': torch.quint8,
    },
    {
      'input_dtype': torch.float16,
      'weight_dtype': torch.float16,
      'bias_dtype': torch.float16,
      'output_dtype': torch.float16,
    },
  ],
  'observation_type': ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
},
```
The results are also now sorted alphabetically by the normalized name of
the root op in the pattern.
A couple of utility functions are created to help with this. If in the future
we convert backend_config_dict to use typed objects, we can move this logic
to the objects at that time.
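For reference, a rough sketch of how the docs build could drive the helpers in
the file below to produce that output (the sort key shown here is a
simplification; the actual code sorts by the normalized name of the pattern's
root op):
```
# rough sketch, not the exact docs code; assumes `backend_config_dict`
# has a "configs" list
entries = backend_config_dict.get("configs", [])
entries = sorted(
    entries, key=lambda e: str(pattern_to_human_readable(e["pattern"])))
print(",\n".join(entry_to_pretty_str(e) for e in entries))
```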
Test plan:
```
cd docs
make html
cd build
python -m http.server
# renders correctly, example: https://gist.github.com/vkuzo/76adfc7c89e119c59813a733fa2cd56f
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77535
Approved by: https://github.com/andrewor14
from typing import Dict, Any, List, Callable, Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..quantization_types import Pattern

def get_pattern_to_dtype_configs(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, List[Dict[str, Any]]]:
    pattern_to_dtype_configs: Dict[Pattern, List[Dict[str, Any]]] = dict()
    for config in backend_config_dict.get("configs", []):
        pattern = config["pattern"]
        dtype_configs = config["dtype_configs"]
        pattern_to_dtype_configs[pattern] = dtype_configs
    return pattern_to_dtype_configs

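# A minimal usage sketch (the backend_config_dict below is hypothetical, for
# illustration only):
#
#   example_config = {
#       "configs": [
#           {
#               "pattern": torch.nn.Conv2d,
#               "dtype_configs": [
#                   {"input_dtype": torch.quint8, "output_dtype": torch.quint8},
#               ],
#           },
#       ],
#   }
#   get_pattern_to_dtype_configs(example_config)
#   # -> {torch.nn.Conv2d: [{'input_dtype': torch.quint8, 'output_dtype': torch.quint8}]}
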
def get_qat_module_classes(
        backend_config_dict: Dict[str, Any]) -> Tuple[type, ...]:
    qat_module_classes = []
    for config in backend_config_dict.get("configs", []):
        qat_module = config.get("qat_module", None)
        if qat_module is not None:
            qat_module_classes.append(qat_module)
    return tuple(set(qat_module_classes))

def get_fused_module_classes(
        backend_config_dict: Dict[str, Any]) -> Tuple[type, ...]:
    fused_module_classes = []
    for config in backend_config_dict.get("configs", []):
        fused_module = config.get("fused_module", None)
        if fused_module is not None:
            fused_module_classes.append(fused_module)
    return tuple(set(fused_module_classes))

def get_pattern_to_input_type_to_index(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Dict[str, int]]:
    pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = dict()
    for config in backend_config_dict.get("configs", []):
        pattern = config["pattern"]
        input_type_to_index = config.get("input_type_to_index", {})
        pattern_to_input_type_to_index[pattern] = input_type_to_index
    return pattern_to_input_type_to_index

def get_root_module_to_quantized_reference_module(
        backend_config_dict: Dict[str, Any]) -> Dict[Callable, Callable]:
    mapping: Dict[Callable, Callable] = dict()
    for config in backend_config_dict.get("configs", []):
        if "root_module" in config and "reference_quantized_module_for_root" in config:
            mapping[config["root_module"]] = config["reference_quantized_module_for_root"]
    return mapping

def get_fuser_method_mapping(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]] = dict()
    for config in backend_config_dict.get("configs", []):
        if "fuser_method" in config:
            pattern = config["pattern"]
            fuser_method = config["fuser_method"]
            fuser_method_mapping[pattern] = fuser_method
    return fuser_method_mapping

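# For reference, a "fuser_method" is a callable that builds a fused module from
# the matched modules. A hypothetical sketch (illustrative only; the exact
# signature expected by a given PyTorch version may differ):
#
#   def fuse_conv_bn(is_qat, conv, bn):
#       # eval-mode fusion: fold the bn statistics into the conv weights
#       return torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
#
#   config = {
#       "pattern": (torch.nn.BatchNorm2d, torch.nn.Conv2d),
#       "fuser_method": fuse_conv_bn,
#   }
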
def get_module_to_qat_module(
        backend_config_dict: Dict[str, Any]) -> Dict[Callable, Callable]:
    module_to_qat_module: Dict[Callable, Callable] = dict()
    for config in backend_config_dict.get("configs", []):
        if "pattern" in config and "qat_module" in config:
            pattern = config["pattern"]
            qat_module = config["qat_module"]
            module_to_qat_module[pattern] = qat_module
    return module_to_qat_module

def get_fusion_pattern_to_root_node_getter(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
    """ Get a map from fusion pattern to a function that returns the root node
    from the fusion pattern, e.g. the most common one is:

    def get_root_node(node_pattern):
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    This can work for all patterns whose root node is the "last node" in the
    pattern, e.g. (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d))
    """
    root_node_getter_mapping: Dict[Pattern, Callable] = dict()
    for config in backend_config_dict.get("configs", []):
        if "root_node_getter" in config:
            pattern = config["pattern"]
            root_node_getter = config["root_node_getter"]
            root_node_getter_mapping[pattern] = root_node_getter
    return root_node_getter_mapping

def get_fusion_pattern_to_extra_inputs_getter(
        backend_config_dict: Dict[str, Any]) -> Dict[Pattern, Callable]:
    """ Get a map from fusion pattern to a function that returns extra input nodes
    from the fusion pattern, in the order required by the root node. This is
    optional; if not specified, we will not copy over any extra inputs for the
    root node.

    Example:
    # Let's say we have the pattern
    # (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)),
    # the root node is torch.nn.Conv2d, and the node matched by MatchAllNode
    # is an extra argument to the fused module. We can unpack the pattern and
    # return the node at MatchAllNode, implementing extra_inputs_getter as:

    def extra_inputs_getter(pattern) -> List[Any]:
        add, extra_input, conv_pattern = pattern
        return [extra_input]
    """
    extra_inputs_getter_mapping: Dict[Pattern, Callable] = dict()
    for config in backend_config_dict.get("configs", []):
        if "extra_inputs_getter" in config:
            pattern = config["pattern"]
            extra_inputs_getter = config["extra_inputs_getter"]
            extra_inputs_getter_mapping[pattern] = extra_inputs_getter
    return extra_inputs_getter_mapping

def remove_boolean_dispatch_from_name(p) -> Any:
    """
    Some ops have a default string representation such as
    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>',
    this function replaces them with the hardcoded function names.
    """
    if p is F.fractional_max_pool2d:
        return "torch.nn.functional.fractional_max_pool2d"
    elif p is F.fractional_max_pool3d:
        return "torch.nn.functional.fractional_max_pool3d"
    elif p is F.max_pool1d:
        return "torch.nn.functional.max_pool1d"
    elif p is F.max_pool2d:
        return "torch.nn.functional.max_pool2d"
    elif p is F.max_pool3d:
        return "torch.nn.functional.max_pool3d"
    elif p is F.adaptive_max_pool1d:
        return "torch.nn.functional.adaptive_max_pool1d"
    elif p is F.adaptive_max_pool2d:
        return "torch.nn.functional.adaptive_max_pool2d"
    elif p is F.adaptive_max_pool3d:
        return "torch.nn.functional.adaptive_max_pool3d"
    assert "boolean_dispatch" not in str(p), \
        f"{p} does not have a human readable representation in " + \
        "quantization documentation"
    return p

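# For example:
#
#   remove_boolean_dispatch_from_name(F.max_pool2d)
#   # -> 'torch.nn.functional.max_pool2d'
#   remove_boolean_dispatch_from_name(torch.nn.Conv2d)
#   # -> torch.nn.Conv2d, returned unchanged (not a boolean_dispatch op)
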
def pattern_to_human_readable(p) -> Any:
    if isinstance(p, tuple):
        # nested patterns, recurse
        return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
    elif isinstance(p, str):
        # method names are already human readable
        return p
    else:
        p = remove_boolean_dispatch_from_name(p)
        return p

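# Nested tuple patterns are converted recursively, e.g.:
#
#   pattern_to_human_readable((torch.nn.ReLU, F.max_pool2d))
#   # -> (torch.nn.ReLU, 'torch.nn.functional.max_pool2d')
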
# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
# the corresponding __str__ function
def entry_to_pretty_str(entry) -> str:
    """
    Given a backend_config_dict entry, returns a string with the human readable
    representation of it.
    """
    s = "{\n"

    # always output the pattern first
    if "pattern" in entry:
        pattern_str = pattern_to_human_readable(entry["pattern"])

        s += f"  'pattern': {pattern_str},\n"

    # custom output for dtype_configs to make it look nice
    if "dtype_configs" in entry:
        s += "  'dtype_configs': [\n"
        for dtype_config in entry["dtype_configs"]:
            s += "    {\n"
            for k, v in dtype_config.items():
                s += f"      '{k}': {v},\n"
            s += "    },\n"
        s += "  ],\n"

    # custom output for num_tensor_args_to_observation_type to make it look nice
    if "num_tensor_args_to_observation_type" in entry:
        s += "  'num_tensor_args_to_observation_type': {\n"
        for k, v in entry["num_tensor_args_to_observation_type"].items():
            s += f"    {k}: {v},\n"
        s += "  },\n"

    # output all the other fields
    custom_handled_fields = [
        "pattern",
        "dtype_configs",
        "num_tensor_args_to_observation_type",
    ]
    for field_name in entry:
        if field_name in custom_handled_fields:
            continue
        s += f"  '{field_name}': {entry[field_name]},\n"

    s += "}"
    return s

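# A minimal usage sketch (hypothetical entry, mirroring the format shown in the
# PR summary); the output is approximately:
#
#   entry = {
#       "pattern": torch.nn.AdaptiveAvgPool1d,
#       "dtype_configs": [
#           {"input_dtype": torch.quint8, "output_dtype": torch.quint8},
#       ],
#   }
#   print(entry_to_pretty_str(entry))
#   # {
#   #   'pattern': <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>,
#   #   'dtype_configs': [
#   #     {
#   #       'input_dtype': torch.quint8,
#   #       'output_dtype': torch.quint8,
#   #     },
#   #   ],
#   # }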