pytorch/torch/ao/ns/fx/mappings.py
zaf efccb6401c [quant][ao_migration] nn.intrinsic.qat migration to ao (#86171)
All quantization-related modules are being migrated to `torch.ao`. This PR migrates `nn.intrinsic.qat`. Please see the [tracker](https://github.com/pytorch/pytorch/issues/81667) for the timeline.

```
python test/test_quantization.py TestAOMigrationNNIntrinsic
```

Differential Revision: [D39419993](https://our.internmc.facebook.com/intern/diff/D39419993/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39419993/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86171
Approved by: https://github.com/jerryzh168
2022-10-07 17:29:42 +00:00

726 lines
17 KiB
Python

import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
from torch.ao.quantization.backend_config import get_native_backend_config_dict
import torch.ao.quantization.fx._lower_to_native_backend as \
_lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings
from .ns_types import NSNodeTargetType
from typing import Set, Dict, List, Optional
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """Return a mapping from base name to a set of related ops.

    Ops in the same set are considered related, for example ``nn.Conv2d``
    together with the fused / QAT / reference / quantized counterparts that
    get connected to it below. The hand-written seed sets are expanded with
    connections taken from:

    * the native backend config (``fused_module``, ``qat_module`` and
      ``reference_quantized_module_for_root`` entries),
    * the default lowering maps in ``_lower_to_native_backend``,
    * the default mappings in ``quantization_mappings``.

    Returns:
        A dict keyed by the stringified position (``"0"``, ``"1"``, ...) of
        each seed set, so the order of the seed list determines base names.
    """
    # note: these sets are extended below by connections derived from the
    # backend config and the default lowering/quantization mappings
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {nn.Conv1d},
        {nn.Conv2d},
        {nn.Conv3d},
        # conv functionals
        {F.conv1d},
        {F.conv2d},
        {F.conv3d},
        # linear modules
        {nn.Linear},
        # linear functionals
        {F.linear},
        # average pool
        {nn.AvgPool1d, torch.avg_pool1d},
        {nn.AvgPool2d, torch._C._nn.avg_pool2d},
        {nn.AvgPool3d, torch._C._nn.avg_pool3d},
        # adaptive average pool
        {nn.AdaptiveAvgPool1d, F.adaptive_avg_pool1d},
        {nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d},
        {nn.AdaptiveAvgPool3d, F.adaptive_avg_pool3d},
        # LSTM
        {nn.LSTM},
        # add (operator.add covers `x + y`)
        {torch.add, operator.add},
        # cat
        {torch.cat},
        # mul
        {torch.mul, operator.mul},
        # relu
        {F.relu, nn.ReLU, 'relu', 'relu_', torch.relu},
        # maxpool
        {nn.MaxPool1d, F.max_pool1d},
        {nn.MaxPool2d, F.max_pool2d},
        {nn.MaxPool3d, F.max_pool3d},
        # sigmoid
        {torch.sigmoid, 'sigmoid', 'sigmoid_', nn.Sigmoid, F.sigmoid},
        # BatchNorm
        {nn.BatchNorm2d},
        {nn.BatchNorm3d},
        # ConvTranspose
        {nn.ConvTranspose1d},
        {nn.ConvTranspose2d},
        {nn.ConvTranspose3d},
        # ELU
        {nn.ELU},
        # Embedding
        {nn.Embedding},
        # EmbeddingBag
        {nn.EmbeddingBag},
        # GroupNorm
        {nn.GroupNorm},
        # Hardswish
        {nn.Hardswish},
        # InstanceNorm
        {nn.InstanceNorm1d},
        {nn.InstanceNorm2d},
        {nn.InstanceNorm3d},
        # LayerNorm
        {nn.LayerNorm},
        # LeakyReLU
        {nn.LeakyReLU},
        # ReLU6
        {nn.ReLU6, F.relu6},
        # F.elu
        {F.elu},
        # F.hardswish
        {F.hardswish},
        # F.group_norm
        {F.group_norm},
        # F.instance_norm
        {F.instance_norm},
        # F.layer_norm
        {F.layer_norm},
        # F.leaky_relu
        {F.leaky_relu},
        # F.silu
        {nn.SiLU, F.silu},
        # F.mish
        {nn.Mish, F.mish},
        # F.tanh
        {nn.Tanh, F.tanh, torch.tanh, 'tanh_', 'tanh'},
        # F.hardsigmoid
        {'hardsigmoid_', 'hardsigmoid', F.hardsigmoid, nn.Hardsigmoid},
        # F.hardtanh
        {nn.Hardtanh, F.hardtanh, F.hardtanh_},
        # floordiv
        {operator.floordiv},
        # unsqueeze
        {torch.unsqueeze},
        # stack
        {torch.stack},
        # squeeze
        {torch.squeeze},
        # sort
        {torch.sort},
        # repeat_interleave
        {torch.repeat_interleave},
        # min
        {torch.min},
        # mean
        {torch.mean},
        # max
        {torch.max},
        # transpose
        {torch.transpose},
        # flatten
        {torch.flatten},
        # clamp
        {torch.clamp},
        # chunk
        {torch.chunk},
        # interpolate
        {torch.nn.functional.interpolate},
        # dropout
        {nn.Dropout},
        # F.dropout
        {F.dropout},
        # matmul
        {torch.matmul},
        # Softmax
        {nn.Softmax},
        # PReLU
        {nn.PReLU, nnq.PReLU},
        # F.prelu
        {F.prelu, toq.prelu},
    ]

    # (op, related_op) pairs, merged into the seed sets at the end.
    new_connections = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    # Connections from the native backend config. Patterns are stored in
    # reverse order, e.g. (c, (b, a)), so repeatedly taking the last element
    # reaches the first op of the pattern (a). The keys below are checked
    # in this order:
    #   'fused_module': a pattern of ops fused into one op,
    #       e.g. nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
    #   'qat_module': a module swapped into a QAT module,
    #       e.g. nni.ConvReLU1d swapped into nniqat.ConvReLU1d
    #   'reference_quantized_module_for_root': reference version of a
    #       floating point module, e.g. nn.Conv2d and nnqr.Conv2d
    related_module_keys = (
        'fused_module',
        'qat_module',
        'reference_quantized_module_for_root',
    )
    backend_config_dict = get_native_backend_config_dict()
    for config in backend_config_dict['configs']:
        if 'pattern' not in config:
            continue
        first_element = config['pattern']
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]
        new_connections.extend(
            (first_element, config[key])
            for key in related_module_keys
            if key in config
        )

    # Reference module swaps from the default lowering path.
    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        new_connections.extend(source_to_target.items())  # type: ignore[attr-defined]
    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.extend([(source, target1), (source, target2)])

    # Function swaps from the default lowering path.
    for source, (target1, target2) in \
            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.extend([(source, target1), (source, target2)])
    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        new_connections.extend(source_to_target.items())

    # Other swaps; ideally these can be removed once the lowering code
    # stops using them.
    new_connections.extend(
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.items())

    # Merge each connection into the first seed set that already contains
    # either endpoint. Connections touching no seed set are dropped, and a
    # connection whose endpoints sit in different seed sets only extends the
    # first of them. Order matters: later connections can match ops added
    # by earlier ones.
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.update((item1, item2))
                break

    return {
        str(index): set_of_related_ops
        for index, set_of_related_ops in enumerate(sets_of_related_ops)
    }
def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """Look up the base name of the set of related ops containing ``op``.

    Args:
        base_name_to_sets_of_related_ops: mapping from base name to a set of
            related ops.
        op: the op to look up.

    Returns:
        The first base name (in dict iteration order) whose set contains
        ``op``, or ``None`` if no set contains it.
    """
    candidates = (
        name
        for name, related_ops in base_name_to_sets_of_related_ops.items()
        if op in related_ops
    )
    return next(candidates, None)
def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """Register ``op`` in ``base_name_to_sets_of_related_ops`` (in place).

    Args:
        base_name_to_sets_of_related_ops: mapping from base name to a set of
            related ops; mutated by this function.
        op: the op to add.
        related_op: if not ``None``, ``op`` is added to the first set that
            contains ``related_op``. If ``None``, ``op`` gets a new set of
            its own under the smallest non-negative integer base name
            (as a string) not already in use.

    Raises:
        AssertionError: if ``related_op`` is given but not found in any set.
    """
    if related_op is None:
        # pick the smallest unused stringified integer as the new base name
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}
        return

    for set_of_related_ops in base_name_to_sets_of_related_ops.values():
        if related_op in set_of_related_ops:
            set_of_related_ops.add(op)
            return
    # if we got here, related_op was not found
    raise AssertionError(f"{related_op} was not found")
# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return node targets grouped by their (input/output) type category.

    Keys have the form ``'<kind>_io_type_<type>'``. ``<kind>`` is ``funs``
    (callables), ``mods`` (module types) or ``meths`` (method names as
    strings). ``<type>`` is ``fp32``, ``fp16``, ``int8`` or ``fp32_or_int8``.

    NOTE(review): several targets appear in both an ``fp32`` set and the
    matching ``fp32_or_int8`` set (e.g. ``torch.cat``, ``operator.add``,
    ``F.dropout``, ``nn.Dropout``, ``nn.ReLU6``). The effective
    classification presumably depends on the order in which callers test
    these sets; confirm that before removing either entry.
    """
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    # `set()` rather than `{}`, which would be an empty dict
    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and add
        # toq.add and toq.mul here
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        # note: nnqd.Linear is a subclass of nnq.Linear, so (per the
        # original note) the check against this fp32 set has to happen
        # before the int8 module check in the caller
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.ReLU6,
    }

    # method targets are matched by name, hence plain strings
    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return the node targets treated as unmatchable, grouped by kind.

    The result has three entries: ``'funs_unmatchable'`` (callables),
    ``'mods_unmatchable'`` (module types) and ``'meths_unmatchable'``
    (method names as strings). A fresh set is built for each call.
    NOTE(review): presumably the graph matcher skips nodes whose targets
    appear here; confirm against the callers.
    """
    unmatchable_functions: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }
    unmatchable_modules: Set[NSNodeTargetType] = {nn.Identity}
    unmatchable_methods: Set[NSNodeTargetType] = {
        'to', 'dequantize', 'reshape', 'view',
        'unsqueeze_', 'unsqueeze', 'transpose', 'squeeze_', 'squeeze',
        'size', 'shape', 'resize_',
        'repeat_interleave', 'repeat', 'permute', 'numel', 'mean',
        'detach_', 'detach', 'contiguous', 'clamp', 'chunk',
    }
    return dict(
        funs_unmatchable=unmatchable_functions,
        mods_unmatchable=unmatchable_modules,
        meths_unmatchable=unmatchable_methods,
    )