mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quant][X86] add an op to compute uint8 pointwise mul (#151112)
**Summary** Add a new op, `onednn.qmul.tensor`, for int8 elementwise mul, which accepts inputs on CPU device (instead of QuantizedCPU). The new op is implemented by AVX512 instructions and it provides similar or better performance, depending on shape, than its counterpart for QuantizedCPU device `quantized.mul`. The new op supports output dtypes other than uint8 (fp32, fp16 and bf16 are supported). **Test plan** ``` pytest test/quantization/core/test_quantized_op.py -k test_int8_mul_onednn ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/151112 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
parent
ad81eeb7c7
commit
c1c8c1f8d6
|
|
@ -216,6 +216,17 @@ using qnormalize_nhwc_fn = void (*)(
|
|||
using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
|
||||
const Tensor& /*qw*/);
|
||||
|
||||
using qmul_tensor_cpu_fn = void (*)(
|
||||
Tensor& /*out*/,
|
||||
const Tensor& /*qx*/,
|
||||
double /*qx_scale*/,
|
||||
int64_t /*qx_zero_point*/,
|
||||
const Tensor& /*qy*/,
|
||||
double /*qy_scale*/,
|
||||
int64_t /*qy_zero_point*/,
|
||||
double /*output_scale*/,
|
||||
int64_t /*output_zero_point*/);
|
||||
|
||||
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub)
|
||||
DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub)
|
||||
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub)
|
||||
|
|
@ -252,5 +263,6 @@ DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub)
|
|||
DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub)
|
||||
DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub)
|
||||
DECLARE_DISPATCH(qprelu_fn, qprelu_stub)
|
||||
DECLARE_DISPATCH(qmul_tensor_cpu_fn, qmul_tensor_cpu_stub)
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <ATen/native/quantized/cpu/QuantizedOps.h>
|
||||
#include <ATen/native/cpu/utils.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/Unroll.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
|
|
@ -4245,6 +4246,150 @@ void index_put_kernel_quantized_cpu(TensorIterator& iter, IntArrayRef index_size
|
|||
}, /*serial_execution=*/is_deterministic);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void _qmul_tensor_cpu_impl(
|
||||
T* out_ptr,
|
||||
int64_t size,
|
||||
const uint8_t* x_ptr,
|
||||
double x_scale,
|
||||
int64_t x_zero_point,
|
||||
const uint8_t* y_ptr,
|
||||
double y_scale,
|
||||
int64_t y_zero_point,
|
||||
double output_scale,
|
||||
int64_t output_zero_point) {
|
||||
float multiplier = x_scale * y_scale / output_scale;
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
int64_t size_rem = size % 16;
|
||||
int64_t size_com = size - size_rem;
|
||||
int64_t steps = size_com / 16;
|
||||
__m512 vs = _mm512_set1_ps(multiplier);
|
||||
__m512i vza = _mm512_set1_epi32(x_zero_point);
|
||||
__m512i vzb = _mm512_set1_epi32(y_zero_point);
|
||||
__m512i vzc = _mm512_set1_epi32(output_zero_point);
|
||||
__m512i v255 = _mm512_set1_epi32(255);
|
||||
__m512i v0 = _mm512_set1_epi32(0);
|
||||
at::parallel_for(0, steps, 1, [&](int64_t start, int64_t end) {
|
||||
for (const auto d : c10::irange(start, end)) {
|
||||
auto x_data = x_ptr + d * 16;
|
||||
auto y_data = y_ptr + d * 16;
|
||||
auto out_data = out_ptr + d * 16;
|
||||
__m128i va = _mm_loadu_si128((__m128i*)x_data);
|
||||
__m128i vb = _mm_loadu_si128((__m128i*)y_data);
|
||||
__m512i va_i32 = _mm512_cvtepi8_epi32(va);
|
||||
__m512i vb_i32 = _mm512_cvtepi8_epi32(vb);
|
||||
va_i32 = _mm512_sub_epi32(va_i32, vza);
|
||||
vb_i32 = _mm512_sub_epi32(vb_i32, vzb);
|
||||
__m512i vc = _mm512_mullo_epi32(va_i32, vb_i32);
|
||||
__m512 vc_f = _mm512_cvtepi32_ps(vc);
|
||||
vc_f = _mm512_mul_ps(vc_f, vs);
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
_mm512_storeu_ps(out_data, vc_f);
|
||||
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
|
||||
__m256i vc_bf16 = cvtfp32_bf16(vc_f);
|
||||
_mm256_storeu_si256((__m256i*)out_data, vc_bf16);
|
||||
} else if constexpr (std::is_same<T, at::Half>::value) {
|
||||
__m256i vc_f16 = cvtfp32_fp16(vc_f);
|
||||
_mm256_storeu_si256((__m256i*)out_data, vc_f16);
|
||||
} else { // T == uint8, requantization needed
|
||||
__m512i vc_i32 = _mm512_cvtps_epi32(vc_f);
|
||||
vc_i32 = _mm512_add_epi32(vc_i32, vzc);
|
||||
vc_i32 = _mm512_min_epi32(vc_i32, v255);
|
||||
vc_i32 = _mm512_max_epi32(vc_i32, v0);
|
||||
__m128i vc_i8 = _mm512_cvtepi32_epi8(vc_i32);
|
||||
_mm_storeu_si128((__m128i*)out_data, vc_i8);
|
||||
}
|
||||
}
|
||||
});
|
||||
if (size_rem > 0) {
|
||||
for (const auto d : c10::irange(size_rem)) {
|
||||
uint8_t x_data = *(x_ptr + size_com + d);
|
||||
uint8_t y_data = *(y_ptr + size_com + d);
|
||||
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
|
||||
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
|
||||
int32_t out_val = static_cast<int32_t>(x_val * y_val);
|
||||
float out_val_f = (float)out_val * multiplier;
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
*(out_ptr + size_com + d) = out_val_f;
|
||||
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
|
||||
*(out_ptr + size_com + d) = at::BFloat16(out_val_f);
|
||||
} else if constexpr (std::is_same<T, at::Half>::value) {
|
||||
*(out_ptr + size_com + d) = at::Half(out_val_f);
|
||||
} else { // T == uint8, requantization needed
|
||||
out_val_f = std::round(out_val_f);
|
||||
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
|
||||
out_val_i32 = std::min(255, std::max(0, out_val_i32));
|
||||
*(out_ptr + size_com + d) = static_cast<uint8_t>(out_val_i32);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
at::parallel_for(0, size, 1, [&](int64_t start, int64_t end) {
|
||||
for (const auto d : c10::irange(start, end)) {
|
||||
uint8_t x_data = *(x_ptr + d);
|
||||
uint8_t y_data = *(y_ptr + d);
|
||||
int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
|
||||
int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
|
||||
int32_t out_val = static_cast<int32_t>(x_val * y_val);
|
||||
float out_val_f = (float)out_val * multiplier;
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
*(out_ptr + d) = out_val_f;
|
||||
} else if constexpr (std::is_same<T, at::BFloat16>::value) {
|
||||
*(out_ptr + d) = at::BFloat16(out_val_f);
|
||||
} else if constexpr (std::is_same<T, at::Half>::value) {
|
||||
*(out_ptr + d) = at::Half(out_val_f);
|
||||
} else { // T == uint8, requantization needed
|
||||
out_val_f = std::round(out_val_f);
|
||||
int32_t out_val_i32 = (int32_t)out_val_f + output_zero_point;
|
||||
out_val_i32 = std::min(255, std::max(0, out_val_i32));
|
||||
*(out_ptr + d) = static_cast<uint8_t>(out_val_i32);
|
||||
}
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
void qmul_tensor_cpu_kernel(
|
||||
Tensor& out,
|
||||
const Tensor& qx,
|
||||
double qx_scale,
|
||||
int64_t qx_zero_point,
|
||||
const Tensor& qy,
|
||||
double qy_scale,
|
||||
int64_t qy_zero_point,
|
||||
double output_scale,
|
||||
int64_t output_zero_point) {
|
||||
auto qx_ptr = qx.const_data_ptr<uint8_t>();
|
||||
auto qy_ptr = qy.const_data_ptr<uint8_t>();
|
||||
int64_t size = qx.numel();
|
||||
TORCH_CHECK(
|
||||
size == qy.numel() && size == out.numel(),
|
||||
"qmul_cpu: Expect qx, qy and out to have the same number of elements");
|
||||
if (out.scalar_type() == c10::ScalarType::Float) {
|
||||
auto out_ptr = out.data_ptr<float>();
|
||||
_qmul_tensor_cpu_impl<float>(
|
||||
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point
|
||||
);
|
||||
} else if (out.scalar_type() == c10::ScalarType::BFloat16) {
|
||||
auto out_ptr = out.data_ptr<at::BFloat16>();
|
||||
_qmul_tensor_cpu_impl<at::BFloat16>(
|
||||
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point
|
||||
);
|
||||
} else if (out.scalar_type() == c10::ScalarType::Half) {
|
||||
auto out_ptr = out.data_ptr<at::Half>();
|
||||
_qmul_tensor_cpu_impl<at::Half>(
|
||||
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point
|
||||
);
|
||||
} else {
|
||||
TORCH_CHECK(out.scalar_type() == c10::ScalarType::Byte,
|
||||
"qmul_cpu: Unsupported output dtype: ", out.scalar_type());
|
||||
auto out_ptr = out.data_ptr<uint8_t>();
|
||||
_qmul_tensor_cpu_impl<uint8_t>(
|
||||
out_ptr, size, qx_ptr, qx_scale, qx_zero_point, qy_ptr, qy_scale, qy_zero_point, output_scale, output_zero_point
|
||||
);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
|
||||
|
|
@ -4343,5 +4488,6 @@ REGISTER_DISPATCH(
|
|||
&index_put_kernel_quantized_cpu)
|
||||
REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel)
|
||||
REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(qmul_tensor_cpu_stub, &qmul_tensor_cpu_kernel)
|
||||
} // namespace at::native
|
||||
// NOLINTEND(*-c-arrays)
|
||||
|
|
|
|||
|
|
@ -277,6 +277,41 @@ Tensor _mul_scalar_out(Tensor& out, const Tensor& self, const Scalar& other) {
|
|||
return out;
|
||||
}
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
DEFINE_DISPATCH(qmul_tensor_cpu_stub);
|
||||
Tensor int8_mul_tensor_onednn(
|
||||
const Tensor& self, double self_scale, int64_t self_zero_point,
|
||||
const Tensor& other, double other_scale, int64_t other_zero_point,
|
||||
double output_scale, int64_t output_zero_point, c10::ScalarType output_dtype) {
|
||||
// Both inputs should have the same shape and both in uint8 dtype.
|
||||
// If output_dtype is uint8, output is requantized with output scale/zero point.
|
||||
// Otherwise, output scale should be 1 and zero point 0.
|
||||
TORCH_CHECK(self.sizes() == other.sizes(),
|
||||
"Quantized mul operands should have the same size.");
|
||||
TORCH_CHECK(self.scalar_type() == at::kByte && other.scalar_type() == at::kByte,
|
||||
"Quantized mul operands should be of type uint8, but got ",
|
||||
self.scalar_type(), " and ", other.scalar_type());
|
||||
TORCH_CHECK(output_dtype == at::kByte || output_dtype == at::kFloat || output_dtype == at::kBFloat16 || output_dtype == at::kHalf,
|
||||
"Quantized mul output should be of type uint8, float, bfloat16 or float16, but got ",
|
||||
output_dtype);
|
||||
if (output_dtype != at::kByte) {
|
||||
TORCH_CHECK(output_scale == 1.0 && output_zero_point == 0,
|
||||
"Quantized mul output scale and zero point should be 1 and 0 for "
|
||||
"output_dtype ", output_dtype, ", but got scale = ",
|
||||
output_scale, " and zero point = ", output_zero_point);
|
||||
}
|
||||
at::Tensor out = at::empty_like(self, self.options().dtype(output_dtype));
|
||||
|
||||
|
||||
qmul_tensor_cpu_stub(
|
||||
self.device().type(), out, self, self_scale, self_zero_point,
|
||||
other, other_scale, other_zero_point,
|
||||
output_scale, output_zero_point);
|
||||
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <bool ReLUFused = false>
|
||||
class QMul final {
|
||||
public:
|
||||
|
|
@ -370,6 +405,24 @@ class QMulScalarTensorOut final {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
class QMulOnednn final {
|
||||
public:
|
||||
static Tensor run(
|
||||
const Tensor self, double self_scale, int64_t self_zero_point,
|
||||
const Tensor other, double other_scale, int64_t other_zero_point,
|
||||
double output_scale, int64_t output_zero_point, c10::ScalarType output_dtype
|
||||
) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
return int8_mul_tensor_onednn(
|
||||
self, self_scale, self_zero_point,
|
||||
other, other_scale, other_zero_point,
|
||||
output_scale, output_zero_point, output_dtype);
|
||||
#endif
|
||||
TORCH_CHECK(false, "Unimplemented (int8 mul tensor with onednn)");
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::mul"), TORCH_FN(QMul</*ReLUFused=*/false>::run));
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"), TORCH_FN(QMulOut</*ReLUFused=*/false>::run));
|
||||
|
|
@ -395,5 +448,9 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
|
|||
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_scalar_relu_out.Tensor"), TORCH_FN(QMulScalarTensorOut</*ReLUFused=*/true>::run));
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(onednn, CPU, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("onednn::qmul.tensor"), TORCH_FN(QMulOnednn::run));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -278,4 +278,6 @@ TORCH_LIBRARY(onednn, m) {
|
|||
// Linear with binary postop
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
// int8 mul
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qmul.tensor(Tensor self, float self_scale, int self_zero_point, Tensor other, float other_scale, int other_zero_point, float output_scale, int output_zero_point, ScalarType output_dtype) -> Tensor"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3139,6 +3139,34 @@ class TestQuantizedOps(TestCase):
|
|||
# Verify the result is scriptable
|
||||
mha_quantized_scripted = torch.jit.script(mha_quantized)
|
||||
|
||||
@skipIfNoONEDNN
|
||||
def test_int8_mul_onednn(self):
|
||||
output_dtype_list = [torch.uint8, torch.float, torch.bfloat16, torch.half]
|
||||
shape_list = [(16, 64), (15, 63)]
|
||||
cases = itertools.product(shape_list, output_dtype_list)
|
||||
for shape, output_dtype in cases:
|
||||
a = torch.randn(shape)
|
||||
b = torch.randn(shape)
|
||||
s_a, z_a = 0.1, 1
|
||||
s_b, z_b = 0.2, 2
|
||||
if output_dtype == torch.uint8:
|
||||
s_c, z_c = 0.3, 3
|
||||
else:
|
||||
s_c, z_c = 1, 0
|
||||
qa = torch.quantize_per_tensor(a, s_a, z_a, torch.quint8)
|
||||
qb = torch.quantize_per_tensor(b, s_b, z_b, torch.quint8)
|
||||
dqa = qa.dequantize()
|
||||
dqb = qb.dequantize()
|
||||
c_ref = dqa * dqb
|
||||
if output_dtype == torch.uint8:
|
||||
c_ref = torch.ops.quantized_decomposed.quantize_per_tensor.default(c_ref, s_c, z_c, 0, 255, torch.uint8)
|
||||
c_ref = c_ref.to(output_dtype)
|
||||
|
||||
a_int8 = qa.int_repr()
|
||||
b_int8 = qb.int_repr()
|
||||
c = torch.ops.onednn.qmul.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
|
||||
self.assertEqual(c, c_ref)
|
||||
|
||||
|
||||
class TestDynamicQuantizedOps(TestCase):
|
||||
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user