[Inductor] [Quant] Fix QConv Binary Inplace Layout Issue (#115613)

This pull request primarily addresses two issues to resolve the `QConvPointWiseBinaryPT2E` layout problem:

- Following the changes made in 611a7457ca, `QConvPointWiseBinaryPT2E` with post-op `sum` should also use `NoneLayout` and return `accum` instead of the `QConvPointWiseBinaryPT2E` buffer.

- Additionally, this pull request fixes an issue in the `_quantized_convolution_onednn` implementation. Since `accum` is expected to be changed in place, the kernel must not create a copy of it by converting its memory format or data type internally. Instead, the necessary memory-format and data-type conversions have been moved to the lowering of `QConvPointWiseBinaryPT2E` (a minimal sketch of this inplace contract follows the list).
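For context, here is a minimal eager-mode sketch (not the PR's code; the shapes and module are made up) of the inplace post-op `sum` contract these changes rely on: the convolution result is accumulated directly into `accum`, so `accum` must already have the expected memory format and dtype before the kernel runs, and the fused op returns `accum` itself rather than a fresh buffer.

```python
import torch

# Minimal sketch of the inplace post-op "sum" contract (assumed shapes/dtypes).
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last)

# accum must already be channels-last and in the output dtype *before* the op runs,
# because the kernel writes the result straight into this buffer.
accum = torch.randn(1, 8, 16, 16).to(memory_format=torch.channels_last)

accum += conv(x)      # in-place: accum now holds conv(x) + old accum
result = accum        # the fused op returns accum itself, not a new tensor
assert result.data_ptr() == accum.data_ptr()
```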

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115613
Approved by: https://github.com/jgong5, https://github.com/oulgen
ghstack dependencies: #116172
Author: leslie-fang-intel, 2023-12-24 09:23:31 +08:00, committed by PyTorch MergeBot
Parent: dfb6815170
Commit: 81cebca3d2
4 changed files with 55 additions and 30 deletions


@@ -1435,10 +1435,21 @@ static at::Tensor _quantized_convolution_onednn(
   // has_accum_postop_sum: extra input besides the conv to do conv post op sum fusion.
   bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "sum";
-  if (has_accum_postop_sum && (fp32_output || bfloat16_output)) {
-    TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
-    TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
-    TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be KFloat or KBFloat.");
+  if (has_accum_postop_sum) {
+    TORCH_CHECK(accum.has_value(), "For post op sum, accum tensor should not be empty.");
+    TORCH_CHECK(
+      accum.value().is_contiguous(
+        kSpatialDim == 2
+          ? c10::MemoryFormat::ChannelsLast
+          : c10::MemoryFormat::ChannelsLast3d
+      ),
+      "For post op sum, accum tensor must be contiguous."
+    );
+    if (fp32_output || bfloat16_output) {
+      TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
+      TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
+      TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be KFloat or KBFloat.");
+    }
   }
   std::string func_name = "quantized::packed_weights_conv";
@@ -1605,23 +1616,13 @@ static at::Tensor _quantized_convolution_onednn(
     return output;
   }
   ideep::tensor dst;
-  at::Tensor accum_contig;
   if (has_accum_postop_sum) {
     auto dst_desc = ideep::tensor::desc(dst_dims, fp32_output ? ideep::tensor::data_type::f32 : (
                         bfloat16_output ? ideep::tensor::data_type::bf16 : src_data_type),
                         kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
-    accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
-    if (fp32_output || bfloat16_output) {
-      TORCH_CHECK((output.scalar_type() == c10::kFloat) || (output.scalar_type() == c10::kBFloat16), "The output tensor should be KFloat or KBFloat.");
-      if (accum_contig.scalar_type() != output.scalar_type()) {
-        // accum_contig is KFloat32 and we expect a kBFloat16 output
-        // or accum_contig is kBFloat16 and we expect a KFloat32 output
-        accum_contig = accum_contig.to(output.scalar_type());
-      }
-    }
-    TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
+    TORCH_CHECK(accum.value().dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
     // When fused with sum, the dst tensor will share the data ptr as the accum tensor.
-    dst.init(dst_desc, accum_contig.data_ptr());
+    dst.init(dst_desc, accum.value().data_ptr());
   } else {
     if (fp32_output || bfloat16_output) {
       // Conv without add: int8-in, fp32-output
@@ -1683,7 +1684,7 @@ static at::Tensor _quantized_convolution_onednn(
     return output;
   }
   if (has_accum_postop_sum) {
-    return accum_contig;
+    return accum.value();
   } else {
     return output;
   }


@@ -6600,14 +6600,16 @@ class TestQuantizedConv(TestCase):
                 )
                 if post_op.binary_attr == "sum":
                     X2_q_cpu_tensor = X2_q.int_repr()
+                    X2_cpu_tensor = (
+                        X2_q.int_repr()
+                        if qconv_output_dtype is None
+                        else X2_q.dequantize().to(qconv_x2_dtype)
+                    ).contiguous(memory_format=torch.channels_last)
                     Y_q_cpu_tensor = qconv(
                         X_q_cpu_tensor,
                         X_scale,
                         X_zero_point,
-                        X2_q_cpu_tensor
-                        if qconv_output_dtype is None
-                        else X2_q.dequantize().to(qconv_x2_dtype),
+                        X2_cpu_tensor,
                         X2_scale,
                         X2_zero_point,
                         packed_weight,
@@ -7074,14 +7076,14 @@ class TestQuantizedConv(TestCase):
        W_zero_point = [-3]
        use_bias_list = [False, True]
        use_channelwise = True
-       qconv_x2_dtype_list = [torch.float32, torch.bfloat16]
        output_dtype_list = [torch.float32, torch.bfloat16]
        X2_zero_point = 0
        use_relu_list = [True, False]
        options = itertools.product(
-           use_bias_list, output_dtype_list, qconv_x2_dtype_list, use_relu_list
+           use_bias_list, output_dtype_list, use_relu_list
        )
-       for use_bias, output_dtype, qconv_x2_dtype, use_relu in options:
+       for use_bias, output_dtype, use_relu in options:
+           qconv_x2_dtype = output_dtype
            qconv = torch.ops.onednn.qconv2d_pointwise.binary
            qconv_prepack = torch.ops.onednn.qconv_prepack
            conv_op = torch.nn.Conv2d(


@@ -5932,6 +5932,7 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
        accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
        """
        self.has_bias = len(inputs) == 6
+       self.idx_for_inplace_sum = 3 if self.has_bias else 2
        super().__init__(
            layout,
            inputs,
@@ -6030,6 +6031,12 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

+   def get_mutation_names(self):
+       return [self.inputs[self.idx_for_inplace_sum].get_name()]
+
+   def get_unbacked_symbol_defs(self):
+       return {}
+
    @classmethod
    def create(
        cls,
@@ -6102,16 +6109,20 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
-       if output_dtype is not None:
-           # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
-           # if output_dtype is not None, the output buf should be dtype output_dtype instead of uint8.
-           kernel_layout.dtype = output_dtype
-       return QConvPointWiseBinaryPT2E(
-           layout=kernel_layout,
+       assert (
+           binary_attr == "sum"
+       ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
+       packed = QConvPointWiseBinaryPT2E(
+           layout=NoneLayout(accum.get_device()),
            inputs=inputs,
            constant_args=constant_args,
        )
+       mark_node_as_mutating(packed, accum)
+       # Return accum since it has been inplace changed.
+       return packed.inputs[packed.idx_for_inplace_sum]


class QLinearPointwisePT2E(ExternKernelAlloc):
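To make the IR change above easier to follow, here is a simplified, standalone sketch of the pattern it adopts; `InplaceSumNode` and its members are illustrative stand-ins, not Inductor's real classes. A node that mutates one of its inputs allocates no output buffer of its own (the role `NoneLayout` plays), advertises the mutated buffer via `get_mutation_names()`, and its `create` returns the mutated input so downstream users read `accum` directly.

```python
# Simplified sketch of the in-place mutation pattern (hypothetical names).
class InplaceSumNode:
    def __init__(self, inputs, has_bias):
        self.inputs = inputs
        # Index of the accumulator input the kernel overwrites in place.
        self.idx_for_inplace_sum = 3 if has_bias else 2
        self.layout = None  # stands in for NoneLayout: no output buffer is allocated

    def get_mutation_names(self):
        # Tell the scheduler which existing buffer this node writes into.
        return [self.inputs[self.idx_for_inplace_sum]]

    @classmethod
    def create(cls, inputs, has_bias):
        packed = cls(inputs, has_bias)
        # Return the accumulator itself: consumers read the mutated input,
        # mirroring "return packed.inputs[packed.idx_for_inplace_sum]" in the diff.
        return packed.inputs[packed.idx_for_inplace_sum]


# Usage: with bias present, inputs = [x, weight, bias, accum] -> accum is index 3.
out = InplaceSumNode.create(["x", "w", "b", "accum"], has_bias=True)
assert out == "accum"
```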


@@ -1523,6 +1523,17 @@ def register_onednn_fusion_ops():
            unary_scalars,
            unary_algorithmm,
        ):
+           if (
+               binary_attr == "sum"
+               and output_dtype in [torch.float32, torch.bfloat16]
+               and accum.get_dtype() in [torch.float32, torch.bfloat16]
+               and accum.get_dtype() != output_dtype
+           ):
+               # For int8-mixed-bf16 quantization and inplace add,
+               # there is case when accum dtype is float32 but output dtype is bfloat16.
+               # Since the accum will be inplaced changed with post op sum,
+               # we will do accum dtype convertion here.
+               accum = to_dtype(accum, output_dtype)
            return TensorBox.create(
                ir.QConvPointWiseBinaryPT2E.create(
                    x,