[Inductor] [Quant] Fix QConv Binary Inplace Layout Issue (#115613)
This pull request primarily addresses two issues to resolve the `QConvPointWiseBinaryPT2E` layout problem:
- Following the changes made in 611a7457ca, `QConvPointWiseBinaryPT2E` with post-op `sum` should also use `NoneLayout` and return `accum` instead of the `QConvPointWiseBinaryPT2E` node itself.
- Additionally, this pull request fixes an issue in the `_quantized_convolution_onednn` implementation. Since `accum` is expected to be modified in place, the kernel must not copy it by changing its memory format or data type; those conversions have been moved into the lowering of `QConvPointWiseBinaryPT2E` (see the sketch below).
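To make the in-place contract concrete, here is a minimal eager-mode sketch (plain PyTorch, not the oneDNN kernel itself) of why the kernel must not copy `accum`: with post-op `sum`, the convolution result is accumulated into `accum`'s own storage and that same buffer is returned, so a copy made inside the kernel would leave the original buffer — the one the graph believes is mutated — stale.

```python
import torch

accum = torch.zeros(4)
update = torch.ones(4)

# In-place accumulation: every holder of a reference to `accum` observes the update.
accum.add_(update)
assert accum[0].item() == 1.0

# If the kernel first copied `accum` (e.g. to change dtype or layout), the sum would land
# in the copy and the original buffer would keep its old values.
stale = torch.zeros(4)
copied = stale.to(torch.bfloat16)   # dtype conversion allocates a new buffer
copied.add_(torch.ones(4, dtype=torch.bfloat16))
assert stale[0].item() == 0.0       # the "mutated" input never changed
```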
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115613
Approved by: https://github.com/jgong5, https://github.com/oulgen
ghstack dependencies: #116172
This commit is contained in: parent dfb6815170, commit 81cebca3d2
@@ -1435,10 +1435,21 @@ static at::Tensor _quantized_convolution_onednn(
   // has_accum_postop_sum: extra input besides the conv to do conv post op sum fusion.
   bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "sum";
 
-  if (has_accum_postop_sum && (fp32_output || bfloat16_output)) {
-    TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
-    TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
-    TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be KFloat or KBFloat.");
+  if (has_accum_postop_sum) {
+    TORCH_CHECK(accum.has_value(), "For post op sum, accum tensor should not be empty.");
+    TORCH_CHECK(
+        accum.value().is_contiguous(
+            kSpatialDim == 2
+                ? c10::MemoryFormat::ChannelsLast
+                : c10::MemoryFormat::ChannelsLast3d
+        ),
+        "For post op sum, accum tensor must be contiguous."
+    );
+    if (fp32_output || bfloat16_output) {
+      TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
+      TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
+      TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be KFloat or KBFloat.");
+    }
   }
 
   std::string func_name = "quantized::packed_weights_conv";
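The checks above define the caller-side contract: for post-op `sum`, `accum` must already be channels-last contiguous and, for fp32/bf16 output, must already carry the right dtype, a scale of 1.0, and a zero point of 0, because the kernel no longer reformats or casts it. A minimal sketch of preparing such a tensor from Python (tensor names and shapes are illustrative, not taken from the PR):

```python
import torch

# Hypothetical accumulation input for a 2D conv with post-op "sum".
accum = torch.randn(1, 16, 8, 8)

# Satisfy the contiguity check: channels-last for kSpatialDim == 2
# (ChannelsLast3d would be the analogue for 3D convolution).
accum = accum.contiguous(memory_format=torch.channels_last)

# For fp32/bf16 output, the dtype must already match the requested output dtype;
# the conversion now happens before the kernel is called, not inside it.
accum = accum.to(torch.bfloat16)
```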
@@ -1605,23 +1616,13 @@ static at::Tensor _quantized_convolution_onednn(
     return output;
   }
   ideep::tensor dst;
-  at::Tensor accum_contig;
   if (has_accum_postop_sum) {
     auto dst_desc = ideep::tensor::desc(dst_dims, fp32_output ? ideep::tensor::data_type::f32 : (
                     bfloat16_output ? ideep::tensor::data_type::bf16 : src_data_type),
                     kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
-    accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
-    if (fp32_output || bfloat16_output) {
-      TORCH_CHECK((output.scalar_type() == c10::kFloat) || (output.scalar_type() == c10::kBFloat16), "The output tensor should be KFloat or KBFloat.");
-      if (accum_contig.scalar_type() != output.scalar_type()) {
-        // accum_contig is KFloat32 and we expect a kBFloat16 output
-        // or accum_contig is kBFloat16 and we expect a KFloat32 output
-        accum_contig = accum_contig.to(output.scalar_type());
-      }
-    }
-    TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
+    TORCH_CHECK(accum.value().dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
     // When fused with sum, the dst tensor will share the data ptr as the accum tensor.
-    dst.init(dst_desc, accum_contig.data_ptr());
+    dst.init(dst_desc, accum.value().data_ptr());
   } else {
     if (fp32_output || bfloat16_output) {
       // Conv without add: int8-in, fp32-output

@@ -1683,7 +1684,7 @@ static at::Tensor _quantized_convolution_onednn(
     return output;
   }
   if (has_accum_postop_sum) {
-    return accum_contig;
+    return accum.value();
   } else {
     return output;
   }
@@ -6600,14 +6600,16 @@ class TestQuantizedConv(TestCase):
         )
 
         if post_op.binary_attr == "sum":
-            X2_q_cpu_tensor = X2_q.int_repr()
+            X2_cpu_tensor = (
+                X2_q.int_repr()
+                if qconv_output_dtype is None
+                else X2_q.dequantize().to(qconv_x2_dtype)
+            ).contiguous(memory_format=torch.channels_last)
             Y_q_cpu_tensor = qconv(
                 X_q_cpu_tensor,
                 X_scale,
                 X_zero_point,
-                X2_q_cpu_tensor
-                if qconv_output_dtype is None
-                else X2_q.dequantize().to(qconv_x2_dtype),
+                X2_cpu_tensor,
                 X2_scale,
                 X2_zero_point,
                 packed_weight,

@@ -7074,14 +7076,14 @@ class TestQuantizedConv(TestCase):
         W_zero_point = [-3]
         use_bias_list = [False, True]
         use_channelwise = True
-        qconv_x2_dtype_list = [torch.float32, torch.bfloat16]
         output_dtype_list = [torch.float32, torch.bfloat16]
         X2_zero_point = 0
         use_relu_list = [True, False]
         options = itertools.product(
-            use_bias_list, output_dtype_list, qconv_x2_dtype_list, use_relu_list
+            use_bias_list, output_dtype_list, use_relu_list
         )
-        for use_bias, output_dtype, qconv_x2_dtype, use_relu in options:
+        for use_bias, output_dtype, use_relu in options:
+            qconv_x2_dtype = output_dtype
             qconv = torch.ops.onednn.qconv2d_pointwise.binary
             qconv_prepack = torch.ops.onednn.qconv_prepack
             conv_op = torch.nn.Conv2d(
@@ -5932,6 +5932,7 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
         accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
         """
         self.has_bias = len(inputs) == 6
+        self.idx_for_inplace_sum = 3 if self.has_bias else 2
         super().__init__(
             layout,
             inputs,

@@ -6030,6 +6031,12 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
         if isinstance(self.layout, Layout):
             self.codegen_size_asserts(wrapper)
 
+    def get_mutation_names(self):
+        return [self.inputs[self.idx_for_inplace_sum].get_name()]
+
+    def get_unbacked_symbol_defs(self):
+        return {}
+
     @classmethod
     def create(
         cls,

@@ -6102,16 +6109,20 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
             may_convert_to_optional(unary_scalars),
             unary_algorithm,
         ]
         if output_dtype is not None:
             # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
             # if output_dtype is not None, the output buf should be dtype output_dtype instead of uint8.
             kernel_layout.dtype = output_dtype
 
-        return QConvPointWiseBinaryPT2E(
-            layout=kernel_layout,
+        assert (
+            binary_attr == "sum"
+        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
+
+        packed = QConvPointWiseBinaryPT2E(
+            layout=NoneLayout(accum.get_device()),
             inputs=inputs,
             constant_args=constant_args,
         )
+        mark_node_as_mutating(packed, accum)
+
+        # Return accum since it has been inplace changed.
+        return packed.inputs[packed.idx_for_inplace_sum]
 
 
 class QLinearPointwisePT2E(ExternKernelAlloc):
@@ -1523,6 +1523,17 @@ def register_onednn_fusion_ops():
             unary_scalars,
             unary_algorithmm,
         ):
+            if (
+                binary_attr == "sum"
+                and output_dtype in [torch.float32, torch.bfloat16]
+                and accum.get_dtype() in [torch.float32, torch.bfloat16]
+                and accum.get_dtype() != output_dtype
+            ):
+                # For int8-mixed-bf16 quantization and inplace add,
+                # there is case when accum dtype is float32 but output dtype is bfloat16.
+                # Since the accum will be inplaced changed with post op sum,
+                # we will do accum dtype convertion here.
+                accum = to_dtype(accum, output_dtype)
             return TensorBox.create(
                 ir.QConvPointWiseBinaryPT2E.create(
                     x,
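As an illustration of the behavior this lowering enables (a plain eager-mode sketch, not Inductor code; the tensor names and shapes are made up): when `accum` is fp32 but the fused convolution produces bf16, `accum` is converted up front so that the in-place sum and the returned tensor share a single bf16 buffer.

```python
import torch

# Hypothetical accumulation input (fp32) and fused-conv result (bf16), both channels-last.
accum = torch.randn(1, 8, 4, 4).to(memory_format=torch.channels_last)
conv_out = torch.randn(1, 8, 4, 4, dtype=torch.bfloat16).to(memory_format=torch.channels_last)

output_dtype = torch.bfloat16
if accum.dtype != output_dtype:
    # Convert before the kernel runs, mirroring the to_dtype() call in the lowering above.
    accum = accum.to(output_dtype)

accum.add_(conv_out)   # post-op "sum" accumulates into accum's own storage
result = accum         # the op returns the mutated accum buffer
assert result.data_ptr() == accum.data_ptr()
```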