Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Add Matmul recipe into x86_inductor_quantizer (#122776)
**Summary**
Add `matmul` to the quantization recipes. Note that this is not a general recipe: it is tailored to meet the accuracy criteria of specific models, so the `matmul` recipe is disabled by default and must be enabled explicitly.

**Test Plan**
```
python -m pytest quantization/pt2e/test_x86inductor_quantizer.py -k test_attention_block
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122776
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #122775
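Since the recipe is off by default, enabling it takes an explicit qconfig for `torch.matmul`. Below is a minimal sketch that mirrors the new test in this PR; the quantizer import path (`torch.ao.quantization.quantizer.x86_inductor_quantizer`) is the stock one assumed to be available in this PyTorch version:

```
# Minimal sketch (mirrors the test added in this PR): opt in to quantizing
# torch.matmul, which the default X86InductorQuantizer recipe leaves alone.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# Reuse the global config for torch.matmul; without this call the matmul
# recipe stays disabled and only the conv/linear patterns are annotated.
quantizer.set_function_type_qconfig(
    torch.matmul, quantizer.get_global_quantization_config()
)
```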
This commit is contained in:
parent: 8798f5bf0d
commit: e8e9261b90
@@ -289,21 +289,42 @@ class TestHelperModules:
             return tmp + self.bn2(self.conv2(tmp))
 
     class SelfAttnLikeModule(torch.nn.Module):
-        def __init__(self, input_dim) -> None:
+        def __init__(
+            self,
+            input_dim,
+            transpose_for_score=False,
+            num_attention_heads=None,
+            attention_head_size=None,
+        ) -> None:
             super().__init__()
             self.input_dim = input_dim
             self.q_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.k_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.v_proj = nn.Linear(input_dim, input_dim, bias=False)
             self.softmax = nn.Softmax(dim=-1)
+            self.transpose_for_score = transpose_for_score
+            if self.transpose_for_score:
+                assert num_attention_heads is not None
+                assert attention_head_size is not None
+                self.num_attention_heads = num_attention_heads
+                self.attention_head_size = attention_head_size
+
+        def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+            new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+            x = x.view(new_x_shape)
+            return x.permute(0, 2, 1, 3)
 
         def forward(self, x):
             q = self.q_proj(x)
             k = self.k_proj(x)
             v = self.v_proj(x)
-            scores = torch.bmm(q, k.transpose(1, 2)) / (self.input_dim ** 0.5)
+            if self.transpose_for_score:
+                q = self.transpose_for_scores(q)
+                k = self.transpose_for_scores(k)
+                v = self.transpose_for_scores(v)
+            scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim ** 0.5)
             attention = self.softmax(scores)
-            weighted = torch.bmm(attention, v)
+            weighted = torch.matmul(attention, v)
             return weighted
 
 class X86InductorQuantTestCase(QuantizationTestCase):
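As a quick illustration (not part of the diff), the `transpose_for_scores` helper added above reshapes the projected activations from `(batch, seq_len, hidden)` to `(batch, heads, seq_len, head_size)`; the numbers below are taken from the new test added in the next hunk:

```
# Shape check only: hidden size 1024 = 16 heads * 64 per head.
import torch

x = torch.randn(2, 384, 1024)                  # (batch, seq_len, hidden)
new_x_shape = x.size()[:-1] + (16, 64)         # (2, 384, 16, 64)
y = x.view(new_x_shape).permute(0, 2, 1, 3)    # (batch, heads, seq_len, head_size)
print(y.shape)                                 # torch.Size([2, 16, 384, 64])
```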
@@ -1448,3 +1469,68 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             node_occurrence,
             node_list,
         )
+
+    @skipIfNoX86
+    def test_attention_block(self):
+        """
+        Test pattern of Attention like Block with X86InductorQuantizer.
+        """
+        for annotate_matmul in [False, True]:
+            with override_quantized_engine("x86"), torch.no_grad():
+                m = TestHelperModules.SelfAttnLikeModule(
+                    input_dim=64 * 16,
+                    transpose_for_score=True,
+                    num_attention_heads=16,
+                    attention_head_size=64,
+                ).eval()
+                example_inputs = (torch.randn(2, 384, 1024),)
+
+                m(*example_inputs)
+
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+
+                if annotate_matmul:
+                    quantizer.set_function_type_qconfig(torch.matmul, quantizer.get_global_quantization_config())
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 5 if annotate_matmul else 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7 if annotate_matmul else 3,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
+                }
+                if annotate_matmul:
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                        torch.ops.aten.linear.default,
+                        torch.ops.aten.view.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.aten.matmul.default,
+                        torch.ops.aten.div.Tensor,
+                        torch.ops.aten.softmax.int,
+                    ]
+                else:
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                        torch.ops.aten.linear.default,
+                        torch.ops.aten.view.default,
+                        torch.ops.aten.permute.default,
+                        torch.ops.aten.matmul.default,
+                        torch.ops.aten.div.Tensor,
+                        torch.ops.aten.softmax.int,
+                    ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
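For context, a hedged sketch of what this test drives end to end. It assumes `self._test_quantizer` wraps the standard PT2E capture/prepare/calibrate/convert steps, reuses `SelfAttnLikeModule` as defined in the first hunk above, and uses the stock export and PT2E quantization entry points of this PyTorch generation:

```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

m = SelfAttnLikeModule(                     # test helper from the first hunk above
    input_dim=64 * 16,
    transpose_for_score=True,
    num_attention_heads=16,
    attention_head_size=64,
).eval()
example_inputs = (torch.randn(2, 384, 1024),)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
quantizer.set_function_type_qconfig(        # the annotate_matmul=True branch
    torch.matmul, quantizer.get_global_quantization_config()
)

exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)                   # calibration pass
converted = convert_pt2e(prepared)          # Q/DQ pairs now also wrap the matmuls
```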
@@ -82,7 +82,9 @@ default_quantizable_ops = propagation_quantizable_ops | {
 
 # A superset of default_quantizable_ops includes operators support the int8 data type
 # but not enabled by default recipe of X86InductorQuantizer.
-quantizable_ops = default_quantizable_ops
+quantizable_ops = default_quantizable_ops | {
+    torch.ops.aten.matmul.default,
+}
 
 QUANT_ANNOTATION_KEY = "quantization_annotation"
 
@@ -110,6 +112,12 @@ def _map_module_function_to_aten_operator_type():
             ],
             torch.ops.aten.flatten.using_ints,
         ),
+        (
+            [
+                torch.matmul,
+            ],
+            torch.ops.aten.matmul.default,
+        ),
     )
     for map_item in map_list:
         module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[call-overload]
@@ -310,6 +318,14 @@ class X86InductorQuantizer(Quantizer):
         self.global_config = quantization_config
         return self
 
+    def get_global_quantization_config(self):
+        if not isinstance(self.global_config, QuantizationConfig):
+            warnings.warn(
+                "The global_config for X86InductorQuantizer is currently invalid. \
+                Please ensure that you use set_global to establish the global quantization configuration."
+            )
+        return self.global_config
+
     def set_function_type_qconfig(
         self,
         function_type: Callable,
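A small sketch of the new accessor's behavior, assuming `global_config` defaults to `None` until `set_global` is called (that default is an assumption, not shown in this diff):

```
import warnings
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

q = X86InductorQuantizer()                  # global config not set yet (assumed None)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = q.get_global_quantization_config()
print(cfg, len(caught))                     # expected: None 1  (one warning emitted)
```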
@@ -499,6 +515,7 @@ class X86InductorQuantizer(Quantizer):
         # Step1: Recipe of fusion patterns like conv/linear.
         self._annotate_conv2d_fusion_pattern(model)
         self._annotate_linear_fusion_pattern(model)
+        self._annotate_matmul(model)
 
         # Step2: Recipe to propagate annotation for patterns beside conv/linear.
         # Go through all the nodes from start to end.
@@ -752,6 +769,24 @@ class X86InductorQuantizer(Quantizer):
         self._annotate_linear_unary(model, config)
         self._annotate_linear(model, config)
 
+    def _annotate_matmul(self, model: torch.fx.GraphModule):
+        if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default):
+            for node in model.graph.nodes:
+                if node.target == torch.ops.aten.matmul.default and not _is_annotated(
+                    [node]
+                ):
+                    input_qspec_map = {}
+                    matmul_node = node
+                    for input_node in matmul_node.args:
+                        input_qspec_map[input_node] = get_input_act_qspec(config)
+                    matmul_node.meta[
+                        QUANT_ANNOTATION_KEY
+                    ] = _X86InductorQuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+
     def _annotate_conv2d_binary_unary(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None: