pytorch/torch/ao/quantization/quantizer/utils.py
Jerry Zhang 92a22a8098 [quant][pt2e][quantizer] Support set_module_name in XNNPACKQuantizer (#106087)
Summary:
Added support for setting quantization configurations based on module name in XNNPACKQuantizer; this can also serve as an example
for implementing new quantizers.
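
For context, a minimal usage sketch (set_module_name is the API added here; set_global and get_symmetric_quantization_config are assumed from the existing XNNPACKQuantizer helpers):

    quantizer = XNNPACKQuantizer()
    # global fallback config
    quantizer.set_global(get_symmetric_quantization_config())
    # modules named "sub" get their own (e.g. per-channel) config instead
    quantizer.set_module_name("sub", get_symmetric_quantization_config(is_per_channel=True))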

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_xnnpack_quantizer_set_module_name

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106087
Approved by: https://github.com/andrewor14
2023-08-02 01:19:23 +00:00

from typing import List, Optional

import torch
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from torch.fx import Node

__all__ = [
    "get_input_act_qspec",
    "get_output_act_qspec",
    "get_weight_qspec",
    "get_bias_qspec",
]


def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.input_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.input_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.output_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.output_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.weight is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.weight
    if quantization_spec.qscheme not in [
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for weight"
        )
    return quantization_spec


def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
    if quantization_config is None:
        return None
    if quantization_config.bias is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.bias
    assert (
        quantization_spec.dtype == torch.float
    ), "Only float dtype is supported for bias right now"
    return quantization_spec
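

# Example (sketch): a quantizer's annotate() pass would typically pull the
# per-argument specs out of a resolved QuantizationConfig with these helpers,
# e.g.
#
#   input_act_qspec = get_input_act_qspec(quantization_config)
#   weight_qspec = get_weight_qspec(quantization_config)
#   bias_qspec = get_bias_qspec(quantization_config)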


def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    if quantization_annotation.input_qspec_map is None:
        quantization_annotation.input_qspec_map = {}
    quantization_annotation.input_qspec_map[input_node] = qspec
    node.meta["quantization_annotation"] = quantization_annotation


def _annotate_output_qspec(node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    quantization_annotation.output_qspec = qspec
    node.meta["quantization_annotation"] = quantization_annotation


def _is_sym_size_node(node: Node):
    # A node counts as a sym_size node only if it is a call_function over one
    # of the symbolic size/numel ops.
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )


def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
    """
    This utility is used to handle cases when dynamic_shape=True tracing leads
    to symint nodes in the pattern of a linear module. In those cases, we need
    to distinguish between input nodes that are used only to extract the values
    of some dimensions (and symint nodes) and the input node that is the actual
    activation.
    For example:
        graph(x, y, weight):
            size_0 = torch.ops.aten.sym_size([x], [0])
            size_1 = torch.ops.aten.sym_size([y], [1])
            view_size = size_0 * size_1
            size_3 = torch.ops.aten.sym_size([x], [2])
            view_out = torch.ops.aten.view(x, [view_size, size_3])
            return mm(view_out, weight)
    In the example above, the y node is not an actual input. It exists only to
    extract size_1.
    """
    if _is_sym_size_node(node):
        return True
    return all(
        ((user not in partition_nodes) or _is_sym_size_node(user))
        for user in node.users
    )
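

# Example (sketch): when collecting the inputs of a traced linear partition,
# nodes that only feed sym_size/sym_numel computations can be filtered out so
# that only real activations remain. `partition_input_nodes` and
# `partition_nodes` below are illustrative names:
#
#   act_inputs = [
#       n
#       for n in partition_input_nodes
#       if not _node_only_used_for_sym_size(n, partition_nodes)
#   ]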