Revert "[Inductor][CPU] Fuse SmoothQuant int8 linear pattern (#139595)"

This reverts commit d72a308e77.

Reverted https://github.com/pytorch/pytorch/pull/139595 on behalf of https://github.com/ZainRizvi due to Sorry but the newly added tests in test_mkldnn_pattern_matcher.py fail internally. See D65661038 for more details ([comment](https://github.com/pytorch/pytorch/pull/139595#issuecomment-2465797016))
This commit is contained in:
PyTorch MergeBot 2024-11-08 21:45:52 +00:00
parent 80d0356b11
commit a7724518c0
3 changed files with 10 additions and 327 deletions

View File

@ -932,8 +932,8 @@ static at::Tensor linear_int8_with_onednn_weight(
c10::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of weight should be int8 (char).");
TORCH_CHECK(
@ -1022,8 +1022,7 @@ static at::Tensor linear_int8_with_onednn_weight(
empty_tensor;
// Create onednn primitive
auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
auto weights_desc = packed_weight.get_desc();
auto dst_dtype = dst.get_data_type();
auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
@ -1120,14 +1119,12 @@ namespace at::native {
torch::List<std::optional<at::Scalar>> post_op_args,
c10::string_view post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
// act_zero_point.numel() == 0 for symmetric quantization
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
"onednn int8 linear: act scale/zp size should be 1");
static std::optional<at::Tensor> other = std::nullopt;
static const c10::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
onednn_weight, weight_scales, weight_zero_points,
bias, output_scale, output_zero_point, output_dtype,
other, /*other scale*/1.0, /*other zp*/0,
@ -1158,12 +1155,10 @@ namespace at::native {
torch::List<std::optional<at::Scalar>> unary_post_op_args,
c10::string_view unary_post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
// act_zero_point.numel() == 0 for symmetric quantization
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
"onednn int8 linear: act scale/zp size should be 1");
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
onednn_weight, weight_scales, weight_zero_points,
bias, output_scale, output_zero_point, output_dtype,
other, other_scale, other_zero_point,

View File

@ -145,7 +145,6 @@ class TestPatternMatcherBase(TestCase):
dtype=None,
is_dynamic=False,
quantizer=None,
compile_options={}, # noqa: B006
):
counters.clear()
torch._dynamo.reset()
@ -189,7 +188,7 @@ class TestPatternMatcherBase(TestCase):
with torch.no_grad(), maybe_autocast:
clone_inputs = self._clone_inputs(inputs)
expected = mod(*inputs)
actual = torch.compile(mod, **compile_options)(*clone_inputs)
actual = torch.compile(mod)(*clone_inputs)
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
if matcher_count is not None:
self.assertEqual(
@ -2825,94 +2824,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
rtol=0.07,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_smooth_quant_with_int_mm(self):
r"""
This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao.
The pattern is:
(no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
or
(with bias) pattern_no_bias -> add -> reshape -> reshape
"""
M = 16
in_feature = 32
out_feature = 64
q_min, q_max = -32, 31
class Mod(torch.nn.Module):
def __init__(
self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
):
super().__init__()
self.dtype = dtype
self.has_bias = has_bias
self.b = torch.randint(
q_min, q_max, [in_feature, out_feature], dtype=torch.int8
)
self.per_channel_quant = per_channel_quant
a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
self.a_scale = (
a_scale_per_channel
if self.per_channel_quant
else a_scale_per_tensor
)
self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
self.b_scale = self.b_scale.to(dtype)
self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
def forward(self, a):
out_shape = a.shape[:-1] + (self.b.size(-1),)
a_reshaped = a.reshape(-1, a.size(-1))
c = torch._int_mm(a_reshaped, self.b)
c = c.to(self.dtype)
c_shape = c.shape
a_scale = self.a_scale.expand(c.shape)
c = c * a_scale
c = c * self.b_scale
if self.has_bias:
c = c.reshape([1, *list(c_shape)])
c = c + self.bias
c = c.reshape(c_shape)
c = c.reshape(out_shape)
return c
has_bias_list = [True, False]
dype_list = (
[torch.float, torch.bfloat16]
if torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float]
)
per_channel_list = [True, False]
dynamic_list = [True, False]
for has_bias, dtype, per_channel_quant, dynamic in itertools.product(
has_bias_list, dype_list, per_channel_list, dynamic_list
):
mod = Mod(dtype, has_bias, per_channel_quant).eval()
a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
)
if dynamic:
nodes_count = 10 if has_bias else 7
else:
nodes_count = 7 if has_bias else 6
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
nodes_count,
)
self._test_common(
mod,
(a,),
matcher_check_fn=matcher_check_fn,
check_autocast=dtype,
compile_options={"dynamic": dynamic},
)
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class TestDynamicPatternMatcher(TestPatternMatcherBase):

View File

@ -2529,226 +2529,6 @@ def _register_qlinear_weight_prepack():
)
def _register_smooth_quant_int_mm_pattern():
"""
The pattern is:
(no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape
or
(with bias) pattern_no_bias -> add (-> reshape -> reshape)
"""
# When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist
# When torch.compile'ing with dynamic=False, they don't exist
def get_pattern_no_bias(expand_a_scale: bool):
return CallFunction(
aten.reshape.default,
CallFunction(
aten.mul.Tensor,
CallFunction(
aten.mul.Tensor,
CallFunction(
prims.convert_element_type.default,
CallFunction(
aten._int_mm.default,
CallFunction(
aten.reshape.default,
KeywordArg("a"),
KeywordArg("in_shape"),
),
KeywordArg("b"),
),
KeywordArg("dtype"),
),
(
CallFunction(
aten.expand.default,
KeywordArg("x_scale"),
Arg(),
)
if expand_a_scale
else KeywordArg("x_scale")
),
),
KeywordArg("w_scale"),
),
KeywordArg("out_shape_no_bias"),
)
# for torch.compile(dynamic=False)
pattern_no_bias_1 = get_pattern_no_bias(expand_a_scale=False)
pattern_with_bias_1 = CallFunction(
aten.add.Tensor,
pattern_no_bias_1,
KeywordArg("bias"),
)
# for torch.compile(dynamic=True)
pattern_no_bias_2 = get_pattern_no_bias(expand_a_scale=True)
pattern_with_bias_2 = CallFunction(
aten.reshape.default,
CallFunction(
aten.reshape.default,
CallFunction(
aten.add.Tensor,
pattern_no_bias_2,
KeywordArg("bias"),
),
Arg(),
),
KeywordArg("out_shape_with_bias"),
)
def _validate_pattern(match: Match):
if len(match.nodes) not in [6, 7, 10]:
return False
if len(match.nodes) == 10:
# Check the two tailing reshape nodes can be fused
if match.nodes[9].args[1] != match.nodes[6].args[1]:
return False
if len(match.nodes) == 10 or (
len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor
):
bias_idx = 7 if len(match.nodes) == 10 else 6
# Check bias shape
bias_node = match.nodes[bias_idx].args[1]
if not isinstance(bias_node, torch.fx.node.Node):
return False
if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr]
return False
return True
pattern_to_pass_number = {
pattern_no_bias_2: 0,
pattern_with_bias_2: 0,
pattern_no_bias_1: 1,
pattern_with_bias_1: 1,
}
for pattern, pass_number in pattern_to_pass_number.items():
@register_freezing_graph_pattern(
pattern,
extra_check=_validate_pattern,
pass_number=pass_number,
)
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
bias = kwargs.get("bias", None)
x = kwargs["a"]
weight = kwargs["b"]
dtype = kwargs["dtype"]
x_scale = kwargs["x_scale"]
w_scale = kwargs["w_scale"]
x_shape = x.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
# For dynamic shape case, we can't get activation shape ahead of runtime.
x_shape = None
out_node = match.output_node()
with match.graph.inserting_before(out_node):
transpose_node = match.graph.call_function(
aten.permute.default, args=(weight, [1, 0])
)
contig_node = match.graph.call_function(
aten.contiguous.default, args=(transpose_node,)
)
packed_weight_inputs = (
contig_node,
x_shape,
)
packed_weight_op = torch.ops.onednn.qlinear_prepack
prepack_weight_node = match.graph.call_function(
packed_weight_op, args=packed_weight_inputs
)
dummy_zp = match.graph.call_function(aten.empty, args=([0],))
w_scale = match.graph.call_function(
prims.convert_element_type.default, args=(w_scale, torch.float32)
)
x_scale_shape = x_scale.meta.get("tensor_meta").shape
x_scale_is_scalar = False
if not has_free_symbols(x_scale_shape):
prod = 1
for d in x_scale_shape:
prod *= d
x_scale_is_scalar = prod == 1
new_args: Tuple[Any, ...]
if x_scale_is_scalar:
# in this case, we can call onednn.qlinear directly
new_args = (
x,
x_scale,
dummy_zp, # x_zp
prepack_weight_node,
w_scale,
dummy_zp, # w_zp
bias,
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # post op name
[], # post op args
"", # post op algorithm
)
new_linear_node = match.graph.call_function(
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
)
out_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(out_node.meta)
else:
# onednn.qlinear does not support per-channel quantization of x
# so in this case, we have to apply x scale and add bias ourselves after qlinear
x_reshaped = match.graph.call_function(
aten.reshape.default, args=(x, kwargs["in_shape"])
)
new_args = (
x_reshaped,
1.0, # x_scale
0, # x_zp
prepack_weight_node,
w_scale,
dummy_zp, # w_zp
None, # bias
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # post op name
[], # post op args
"", # post op algorithm
)
new_linear_node = match.graph.call_function(
torch.ops.onednn.qlinear_pointwise, args=new_args
)
# apply x scale
new_out_node = match.graph.call_function(
aten.mul.Tensor, args=(new_linear_node, x_scale)
)
# Add bias and reshape
out_shape = kwargs.get(
"out_shape_with_bias", kwargs["out_shape_no_bias"]
)
if bias is not None:
new_out_node = match.graph.call_function(
aten.add.Tensor, args=(new_out_node, bias)
)
new_out_node = match.graph.call_function(
aten.reshape.default,
args=(new_out_node, out_shape),
)
else:
new_out_node = match.graph.call_function(
aten.reshape.default,
args=(new_out_node, out_shape),
)
out_node.replace_all_uses_with(new_out_node)
new_out_node.meta.update(out_node.meta)
for node in reversed(match.nodes):
match.graph.erase_node(node)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
match.nodes
)
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
@ -2760,9 +2540,6 @@ def _register_quantization_weight_pack_pass():
# Step 3: QLinear weight prepack
_register_qlinear_weight_prepack()
# Step 4: weight prepack for SmoothQuant from Torchao
_register_smooth_quant_int_mm_pattern()
def quant_lift_up(graph_module: torch.fx.GraphModule):
"""