from typing import List

import torch
from torch.ao.quantization.pt2e.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from torch.fx import Node

__all__ = [
    "get_input_act_qspec",
    "get_output_act_qspec",
    "get_weight_qspec",
    "get_bias_qspec",
]


def get_input_act_qspec(quantization_config: QuantizationConfig):
    if quantization_config is None:
        return None
    if quantization_config.input_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.input_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec


def get_output_act_qspec(quantization_config: QuantizationConfig):
    if quantization_config is None:
        return None
    if quantization_config.output_activation is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.output_activation
    assert quantization_spec.qscheme in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]
    return quantization_spec
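

# Illustrative usage sketch for the two activation getters above. The
# QuantizationSpec constructor arguments beyond dtype/qscheme (quant_min,
# quant_max, observer_or_fake_quant_ctr) and the positional field order of
# QuantizationConfig are assumptions for illustration, not guaranteed by
# this module:
#
#     from torch.ao.quantization.observer import HistogramObserver
#
#     act_qspec = QuantizationSpec(
#         dtype=torch.int8,
#         quant_min=-128,
#         quant_max=127,
#         qscheme=torch.per_tensor_affine,
#         observer_or_fake_quant_ctr=HistogramObserver,
#     )
#     config = QuantizationConfig(act_qspec, act_qspec, weight_qspec, bias_qspec)
#     assert get_input_act_qspec(config) is act_qspec
#     assert get_output_act_qspec(config) is act_qspec
#     assert get_input_act_qspec(None) is None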


def get_weight_qspec(quantization_config: QuantizationConfig):
    if quantization_config is None:
        return None
    if quantization_config.weight is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.weight
    if quantization_spec.qscheme not in [
        torch.per_tensor_symmetric,
        torch.per_channel_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for weight"
        )
    return quantization_spec
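

# A per-channel weight spec that get_weight_qspec would accept; ch_axis and
# the observer argument are assumed QuantizationSpec fields here. Any other
# qscheme (e.g. torch.per_tensor_affine) makes the getter raise ValueError:
#
#     from torch.ao.quantization.observer import PerChannelMinMaxObserver
#
#     weight_qspec = QuantizationSpec(
#         dtype=torch.int8,
#         qscheme=torch.per_channel_symmetric,
#         ch_axis=0,
#         observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
#     )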


def get_bias_qspec(quantization_config: QuantizationConfig):
    if quantization_config is None:
        return None
    if quantization_config.bias is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.bias
    assert (
        quantization_spec.dtype == torch.float
    ), "Only float dtype is supported for bias right now"
    return quantization_spec
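

# get_bias_qspec only passes float specs through, so a minimal bias spec
# could look like this (the constructor arguments are assumptions):
#
#     from torch.ao.quantization.observer import PlaceholderObserver
#
#     bias_qspec = QuantizationSpec(
#         dtype=torch.float,
#         observer_or_fake_quant_ctr=PlaceholderObserver,
#     )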


def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    if quantization_annotation.input_qspec_map is None:
        quantization_annotation.input_qspec_map = {}
    quantization_annotation.input_qspec_map[input_node] = qspec
    node.meta["quantization_annotation"] = quantization_annotation


def _annotate_output_qspec(node: Node, qspec):
    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    quantization_annotation.output_qspec = qspec
    node.meta["quantization_annotation"] = quantization_annotation
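

# Sketch of how a quantizer's annotation pass might combine the helpers
# above; the conv2d matching and `config` are hypothetical:
#
#     for node in model.graph.nodes:
#         if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
#             input_act, weight = node.args[0], node.args[1]
#             _annotate_input_qspec_map(node, input_act, get_input_act_qspec(config))
#             _annotate_input_qspec_map(node, weight, get_weight_qspec(config))
#             _annotate_output_qspec(node, get_output_act_qspec(config))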


def _is_sym_size_node(node: Node):
    return node.op == "call_function" and node.target == torch.ops.aten.sym_size


def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
    """
    This utility is used to handle cases when dynamic_shape=True tracing leads
    to symint nodes in the pattern of a linear module. In those cases, we need
    to distinguish between the nodes that are inputs only for extracting the
    value of some dimensions (and symint nodes) vs. the one that is the
    activation.
    For example:
    graph(x, y, weight):
       size_0 = torch.ops.aten.sym_size([x], [0])
       size_1 = torch.ops.aten.sym_size([y], [1])
       view_size = size_0 * size_1
       size_3 = torch.ops.aten.sym_size([x], [2])
       view_out = torch.ops.aten.view(x, [view_size, size_3])
       return mm(view_out, weight)
    In the example above, the y node is not an actual input. It exists only to
    extract size_1.
    """
    if _is_sym_size_node(node):
        return True

    return all(
        ((user not in partition_nodes) or _is_sym_size_node(user))
        for user in node.users
    )
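

# Sketch: a partition-based annotator could use this helper to drop inputs
# that exist only to feed sym_size extraction (`partition` is hypothetical):
#
#     act_inputs = [
#         n
#         for n in partition.input_nodes
#         if not _node_only_used_for_sym_size(n, partition.nodes)
#     ]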