Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
removing quantization utility functions moved to fbgemm (#14301)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14301

This diff removes the quantization utility functions that were copied to fbgemm.

Reviewed By: Maratyszcza

Differential Revision: D13159299

fbshipit-source-id: a7f3cd2af0aa241a8578d532a70a157da70d9289
This commit is contained in:
parent 8c4910b095
commit fb8c3d62fe
@@ -40,12 +40,12 @@ void FindMinMax(const T* data, float* min, float* max, int len) {
   for (int i = 0; i < len; ++i) {
     temp[i] = data[i];
   }
-  dnnlowp::FindMinMax(temp.data(), min, max, len);
+  fbgemm::FindMinMax(temp.data(), min, max, len);
 }

 template <>
 void FindMinMax<float>(const float* data, float* min, float* max, int len) {
-  dnnlowp::FindMinMax(data, min, max, len);
+  fbgemm::FindMinMax(data, min, max, len);
 }

 void OutputMinMaxObserver::Stop() {
@@ -317,7 +317,7 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
       // Adjust for the fact that B will actually use signed.
       B_qparams_[i].zero_point += signed_min;

-      Quantize<int8_t>(
+      fbgemm::Quantize<int8_t>(
           B.template data<float>() + i * B_quantized_temp.size(),
           B_quantized_temp.data(),
           B_quantized_temp.size(),
@@ -694,7 +694,7 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {

       // Requantization
       for (int j = 0; j < M * N; ++j) {
-        Y_quantized[p * Y_stride + i * M * N + j] = Requantize<T>(
+        Y_quantized[p * Y_stride + i * M * N + j] = fbgemm::Requantize<T>(
             Y_int32_[p * Y_stride + i * M * N + j],
             requantization_params_[0]);
       }
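The pattern in this hunk — accumulate products in int32, then map the accumulator back to the narrow output type — recurs in every operator touched by this diff. As a reference for what Requantize does on the float path, here is a minimal scalar sketch (a hypothetical helper mirroring the float Requantize the dnnlowp header defined before this change, not the fbgemm API):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map an int32 accumulator onto a uint8 output grid. real_multiplier is
// (product of input scales) / (output scale); zero_point is the output
// tensor's zero point. Sketch only.
inline uint8_t RequantizeScalar(int32_t acc, float real_multiplier,
                                int32_t zero_point) {
  long quantized_down = zero_point + std::lrintf(acc * real_multiplier);
  // Clamp to the representable uint8 range.
  return static_cast<uint8_t>(std::min(255l, std::max(0l, quantized_down)));
}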
@@ -67,7 +67,8 @@ TensorQuantizationParams GetInputTensorQuantizationParamsOf(
     CAFFE_ENFORCE(tensor->numel() == 0 || tensor->template data<float>());

     float min, max;
-    FindMinMax(tensor->template data<float>(), &min, &max, tensor->numel());
+    fbgemm::FindMinMax(
+        tensor->template data<float>(), &min, &max, tensor->numel());

     return qfactory->ChooseQuantizationParams(min, max, is_weight);
   }
@@ -142,7 +143,8 @@ const T* QuantizeInputIfNeeded(
     // Need to quantize
     const TensorCPU& tensor = op->Input<Tensor>(input_index, CPU);
     temp.resize(tensor.numel());
-    Quantize<T>(tensor.data<float>(), temp.data(), temp.size(), qparams);
+    fbgemm::Quantize<T>(
+        tensor.data<float>(), temp.data(), temp.size(), qparams);
     return temp.data();
   }
 }
@@ -165,7 +167,7 @@ const T* RowWiseQuantizeInputIfNeeded(
     int rowwidth = temp.size() / N;
     // quantize each row
     for (int i = 0; i < N; i++) {
-      Quantize<T>(
+      fbgemm::Quantize<T>(
           tensor.data<float>() + rowwidth * i,
           temp.data() + rowwidth * i,
           rowwidth,
@@ -139,12 +139,12 @@ bool ConcatDNNLowPOp<T>::RunOnDevice() {
     if (InputTensorCPU_(i).template IsType<T>()) {
       const T* input_data = input.template data<T>();
       for (int j = j_begin; j < j_end; ++j) {
-        input_temp[j] = dnnlowp::Requantize<T>(
+        input_temp[j] = fbgemm::Requantize<T>(
            input_data[j] - in_qparams_[i].zero_point,
            requantization_params_[i]);
       }
     } else {
-      dnnlowp::Quantize<T>(
+      fbgemm::Quantize<T>(
          input.template data<float>() + j_begin,
          input_temp.data() + j_begin,
          j_end - j_begin,
@@ -307,7 +307,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
           (uint8_t*)col_buffer_data + tid * col_buffer_size;
     } else {
       col_buffer_quantized.resize(kernel_dim * output_image_size);
-      Quantize<uint8_t>(
+      fbgemm::Quantize<uint8_t>(
           (const float*)col_buffer_data + tid * col_buffer_size,
           col_buffer_quantized.data(),
           col_buffer_quantized.size(),
@@ -605,7 +605,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
           col_buffer_quantized.size(),
           dnnlowp_get_num_threads(),
           dnnlowp_get_thread_num());
-      Quantize<uint8_t>(
+      fbgemm::Quantize<uint8_t>(
           (const float*)col_buffer_data + begin,
           col_buffer_quantized.data() + begin,
           end - begin,
@@ -233,7 +233,7 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
 #endif
       for (int i = 0; i < b_dequantized_.size(); ++i) {
         b_dequantized_[i] =
-            Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
+            fbgemm::Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
       }
       b_dequantized_data_ = b_dequantized_.data();
     }
@@ -245,7 +245,7 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
       int i_begin = g * (M / filter_qparams_.size());
       int i_end = i_begin + (M / filter_qparams_.size());
       for (int i = i_begin; i < i_end; ++i) {
-        b_quantized_[i] = Quantize<int32_t>(
+        b_quantized_[i] = fbgemm::Quantize<int32_t>(
             b_dequantized_data_[i],
             0,
             in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
@@ -333,7 +333,7 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
       // Adjust for the fact that weight will actually use signed.
       FilterQuantizationParams(g).zero_point -= signed_min;

-      Quantize<T_signed>(
+      fbgemm::Quantize<T_signed>(
           filter.template data<float>() + offset,
           W_quantized_.data() + offset,
           (M / filter_qparams_.size()) * kernel_dim,
@@ -517,7 +517,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
           raw = std::max(0, raw);
         }
         Y_data[i * Y_HxW + j] =
-            dnnlowp::Requantize<T>(raw, RequantizationParams(group_id));
+            fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
       }
     }
   } // !dequantize_output_
@@ -659,7 +659,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
           (T*)col_buffer_data + tid * col_buffer_size;
     } else {
       col_buffer_quantized.resize(kernel_dim * Y_HxW);
-      Quantize<T>(
+      fbgemm::Quantize<T>(
           (const float*)col_buffer_data + tid * col_buffer_size,
           col_buffer_quantized.data(),
           col_buffer_quantized.size(),
@@ -894,7 +894,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
         }

         Ydata[i * M + j] =
-            dnnlowp::Requantize<T>(raw, RequantizationParams(group_id));
+            fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
         if (ReluFused) { // static if
           Ydata[i * M + j] =
               std::max<int32_t>(C_zero_point, Ydata[i * M + j]);
@@ -904,7 +904,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
     } // for each row i
   } // !__AVX2__

-  PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+  dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
 }
 }

@@ -1409,7 +1409,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
         reinterpret_cast<const T*>(col_buffer_data);
   } else {
     col_buffer_quantized.resize(G * kernel_dim * Y_HxW * N);
-    Quantize<T>(
+    fbgemm::Quantize<T>(
         reinterpret_cast<const float*>(col_buffer_data),
         col_buffer_quantized.data(),
         col_buffer_quantized.size(),
@@ -6,7 +6,6 @@

 #include "caffe2/core/tensor_int8.h"
 #include "caffe2/operators/conv_pool_op_base.h"
-#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
 #include "caffe2/quantization/server/op_wrapper.h"

 #ifdef _OPENMP
@@ -107,7 +106,7 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
           OutputTensorCPU_(0)->size(),
           out_qparams_);
     } else {
-      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+      dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
     }

     MeasureQuantizationError_();
@@ -70,375 +70,6 @@ namespace dnnlowp {

 using namespace std;

-float TensorQuantizationParams::Min() const {
-  return Dequantize(0, *this);
-}
-
-float TensorQuantizationParams::Max() const {
-  return Dequantize((1 << precision) - 1, *this);
-}
-
-int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) {
-  int64_t a_64(a);
-  int64_t b_64(b);
-  int64_t ab_64 = a_64 * b_64;
-
-  int64_t nudge = 1ll << (right_shift - 1);
-  return (ab_64 + nudge) >> right_shift;
-}
-
-#ifdef __AVX2__
-void RequantizeFixedPointAvx2(
-    const int32_t* src,
-    uint8_t* dst,
-    int len,
-    const RequantizationParams& params) {
-  constexpr int VLEN = 8;
-
-  __m256i b = _mm256_set1_epi32(params.multiplier);
-
-  // AVX2 doesn't support arithmetic right shift.
-  // As a work around, we convert 64-bit multiplied results to uint64_t by
-  // adding 0x8000000000000000ULL, logical right shift, and subtract by
-  // (0x8000000000000000ULL >> right_shift).
-  __m256i pre_shift_nudge = _mm256_set1_epi64x(
-      (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL);
-  __m256i post_shift_nudge = _mm256_set1_epi64x(
-      params.target_qparams.zero_point -
-      (0x8000000000000000ULL >> params.right_shift));
-
-  __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
-  __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());
-
-  __m256i shuffle_mask_v = _mm256_set_epi8(
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0x0c,
-      0x08,
-      0x04,
-      0x00,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0x0c,
-      0x08,
-      0x04,
-      0x00);
-  __m256i permute_mask_v =
-      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
-
-  int i = 0;
-  for (; i < len / VLEN * VLEN; i += VLEN) {
-    __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));
-
-    // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
-    // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7
-    __m256i a_even_v = a_v;
-    __m256i a_odd_v = _mm256_srli_si256(a_v, 4);
-
-    __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
-    __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);
-
-    __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
-    __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);
-
-    __m256i even_result_v = _mm256_add_epi64(
-        _mm256_srli_epi64(even_rounded_v, params.right_shift),
-        post_shift_nudge);
-    __m256i odd_result_v = _mm256_add_epi64(
-        _mm256_srli_epi64(odd_rounded_v, params.right_shift), post_shift_nudge);
-    odd_result_v = _mm256_slli_si256(odd_result_v, 4);
-
-    // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
-    // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
-    // Use blend to combine them.
-    __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
-    __m256i clipped_v =
-        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));
-
-    clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
-    clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
-    *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
-  }
-
-  for (; i < len; ++i) {
-    dst[i] = RequantizeFixedPoint<uint8_t>(src[i], params);
-  }
-}
-
-void RequantizeAvx2(
-    const int32_t* src,
-    uint8_t* dst,
-    int len,
-    const RequantizationParams& params) {
-  // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
-  // using AVX2 instructions
-  constexpr int VLEN = 8;
-
-  __m256 multiplier_v = _mm256_set1_ps(params.real_multiplier);
-  __m256i zero_point_v = _mm256_set1_epi16(params.target_qparams.zero_point);
-
-  __m256i min_v = _mm256_set1_epi8(numeric_limits<uint8_t>::min());
-  __m256i max_v = _mm256_set1_epi8(numeric_limits<uint8_t>::max());
-
-  __m256i permute_mask_v =
-      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
-
-  int i = 0;
-  for (; i < len / (VLEN * 4) * (VLEN * 4); i += (VLEN * 4)) {
-    __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
-    __m256i y_v =
-        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i + VLEN));
-    __m256i z_v = _mm256_loadu_si256(
-        reinterpret_cast<const __m256i*>(src + i + 2 * VLEN));
-    __m256i w_v = _mm256_loadu_si256(
-        reinterpret_cast<const __m256i*>(src + i + 3 * VLEN));
-
-    /*
-     * Convert int32_t input to FP32 and multiply by FP32 scale.
-     * Both operations involve statistically unbiased roundings (with default
-     * MXCSR rounding mode):
-     * - Large int32_t values can't be exactly represented as FP32. CVTDQ2PS
-     * instruction on x86 would round it according to nearest FP32 value with
-     * ties to even (assuming default MXCSR rounding mode).
-     * - Product of two FP32 values is generally not exactly representation as
-     * an FP32 value, and will be rounded to nearest FP32 value with ties to
-     * even with default MXCSR rounding mode.
-     */
-    __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-    __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-    __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-    __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
-
-    /*
-     * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
-     * CVTPS2DQ instruction rounds result according to nearest FP32 value with
-     * ties to even (assuming default MXCSR rounding mode). However, when
-     * conversion overflows, it produces INT32_MIN as a result. For large
-     * positive inputs the result of conversion can become negative, which
-     * affects the final requantization result. Note that on x86 SSE2 we have
-     * e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This happens because
-     * float(INT32_MAX) rounds to 2**31, which overflows int32_t when it is
-     * converted back to integer.
-     *
-     * Thankfully, we can prove that overflow never happens in this
-     * requantization scheme. The largest positive input is INT32_MAX (2**31 -
-     * 1), which turns into 2**31 when converted to float. The largest scale
-     * value is 0x1.FFFFFEp-1. When multiplied together, the result is
-     * 2147483520 (compare to INT32_MAX = 2147483647), which fits into int32_t
-     * without overflow.
-     */
-    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-    __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
-    __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
-    __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
-
-    /*
-     * Standard final sequence on x86 AVX2:
-     * - Pack to int16_t and saturate
-     * - Add zero point
-     * - Pack to uint8_t and saturate
-     * - Clamp between qmin and qmax
-     */
-    __m256i xy_packed_v = _mm256_adds_epi16(
-        _mm256_packs_epi32(x_rounded_v, y_rounded_v), zero_point_v);
-    __m256i zw_packed_v = _mm256_adds_epi16(
-        _mm256_packs_epi32(z_rounded_v, w_rounded_v), zero_point_v);
-    __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
-    __m256i xyzw_clamped_v =
-        _mm256_max_epu8(min_v, _mm256_min_epu8(xyzw_packed_v, max_v));
-
-    /*
-     * xyzw_clamped_v has results in the following layout so we need to permute:
-     * x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
-     */
-    xyzw_clamped_v =
-        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
-
-    /*
-     * 4x CVTDQ2PS
-     * 4x MULPS
-     * 4x CVTPS2DQ
-     * 2x PACKSSDW
-     * 1x PACKUSWB
-     * 2x PADDW
-     * 1x PMAXUB
-     * 1x PMINUB
-     * 1x PERMD
-     * ---------------------
-     * 20 instructions total
-     */
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
-  } // i loop vectorized and unrolled 4x
-
-  for (; i < len / VLEN * VLEN; i += VLEN) {
-    __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
-    __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-    __m256i x_packed_v = _mm256_adds_epi16(
-        _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), zero_point_v);
-    x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
-    __m256i x_clamped_v =
-        _mm256_max_epu8(min_v, _mm256_min_epu8(x_packed_v, max_v));
-
-    /*
-     * x_clamped_v has results in the following layout so we need to permute:
-     * x0-3 garbage0-11 x4-7 garbage12-23
-     */
-    x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
-
-    /*
-     * 1x CVTDQ2PS
-     * 1x MULPS
-     * 1x CVTPS2DQ
-     * 1x PACKSSDW
-     * 1x PACKUSWB
-     * 1x PADDW
-     * 1x PMAXUB
-     * 1x PMINUB
-     * 1x PERMD
-     * ---------------------
-     * 9 instructions total
-     */
-    _mm_storel_epi64(
-        reinterpret_cast<__m128i*>(dst + i),
-        _mm256_castsi256_si128(x_clamped_v));
-  } // i loop vectorized
-
-  for (; i < len; ++i) {
-    dst[i] = Requantize<uint8_t>(src[i], params);
-  } // i loop remainder
-}
-#endif
-
-#define DNNLOWP_SPECIALIZED_REQUANTIZE(T)       \
-  template <>                                   \
-  void Requantize<T>(                           \
-      const int32_t* src,                       \
-      T* dst,                                   \
-      const int len,                            \
-      const RequantizationParams& params) {     \
-    for (int i = 0; i < len; ++i) {             \
-      dst[i] = Requantize<T>(src[i], params);   \
-    }                                           \
-  }
-DNNLOWP_SPECIALIZED_REQUANTIZE(uint16_t)
-DNNLOWP_SPECIALIZED_REQUANTIZE(int32_t)
-#undef DNNLOWP_SPECIALIZED_REQUANTIZE
-
-template <>
-void Requantize<uint8_t>(
-    const int32_t* src,
-    uint8_t* dst,
-    const int len,
-    const RequantizationParams& params) {
-  if (params.target_qparams.precision == 8 && caffe2::GetCpuId().avx2()) {
-    RequantizeAvx2(src, dst, len, params);
-  } else {
-    for (int i = 0; i < len; ++i) {
-      dst[i] = Requantize<uint8_t>(src[i], params);
-    }
-  }
-}
-
-template <typename T>
-void Quantize(
-    const float* src,
-    T* dst,
-    int len,
-    const TensorQuantizationParams& qparams) {
-#if defined(__AVX2__) && defined(__FMA__)
-  caffe2::CpuId cpuid = caffe2::GetCpuId();
-  bool avx2_support = cpuid.avx2();
-  bool fma_support = cpuid.fma();
-  if (avx2_support && fma_support && qparams.precision == 8 &&
-      std::is_same<T, uint8_t>::value &&
-      !FLAGS_caffe2_dnnlowp_force_slow_path) {
-    // fast path
-    constexpr int VLEN = 8;
-    std::size_t i = 0;
-    __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
-    for (; i < len / VLEN * VLEN; i += VLEN) {
-      __m256 src_v = _mm256_loadu_ps(src + i);
-      __m256 transformed_v = _mm256_fmadd_ps(
-          src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
-      __m256 clipped_v = _mm256_min_ps(
-          _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
-          _mm256_set1_ps(255.f));
-      __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
-      alignas(64) std::int32_t temp_int32[VLEN];
-      _mm256_store_si256((__m256i*)temp_int32, rounded_v);
-      for (int j = 0; j < VLEN; ++j) {
-        dst[i + j] = temp_int32[j];
-      }
-    }
-
-    for (; i < len; ++i) {
-      float transformed = qparams.zero_point + src[i] / qparams.scale;
-      float clipped = std::min(std::max(transformed, 0.f), 255.f);
-      // Not exactly the same behavior as the vectorized code.
-      // The vectorized code above always rounds to even in halfway cases
-      // (https://software.intel.com/en-us/node/523819), but std::nearbyint
-      // does the same only when the current rounding mode is FE_TONEAREST.
-      // However, in practice, this should not be a problem because most cases
-      // use the default rounding mode FE_TONEAREST.
-      // Note that we cannot implement the same behavior as the vectorized code
-      // using std::round because it does rounding away from zero in halfway
-      // cases.
-      dst[i] = nearbyint(clipped);
-    }
-  } else
-#endif
-  {
-    for (std::size_t i = 0; i < len; ++i) {
-      dst[i] = dnnlowp::Quantize<T>(src[i], qparams);
-    }
-  }
-}
-
-template void Quantize<uint8_t>(
-    const float* src,
-    uint8_t* dst,
-    int len,
-    const TensorQuantizationParams& qparams);
-
-template void Quantize<int8_t>(
-    const float* src,
-    int8_t* dst,
-    int len,
-    const TensorQuantizationParams& qparams);
-
-template void Quantize<uint16_t>(
-    const float* src,
-    uint16_t* dst,
-    int len,
-    const TensorQuantizationParams& qparams);
-
-template void Quantize<int16_t>(
-    const float* src,
-    int16_t* dst,
-    int len,
-    const TensorQuantizationParams& qparams);
-
 QuantizationFactory::QuantizationKind StringToKind(const string& s) {
   string s_lower(s);
   transform(s_lower.begin(), s_lower.end(), s_lower.begin(), ::tolower);
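The fixed-point path deleted above now lives in fbgemm. Its scalar core is a 64-bit multiply followed by a right shift with a rounding nudge; a condensed sketch of the same math as SaturatingRoundingMulWithShift, assuming right_shift >= 1:

#include <cstdint>

// Multiply in 64 bits and shift right with round-to-nearest: adding
// 1 << (right_shift - 1) before the shift rounds halfway cases upward.
int64_t MulWithRoundedShift(int32_t a, int32_t b, int right_shift) {
  int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  int64_t nudge = 1ll << (right_shift - 1);
  return (ab + nudge) >> right_shift;
}

The AVX2 version above implements this same rounding with a pre-shift nudge of 0x8000000000000000ULL plus a logical shift, precisely because AVX2 has no 64-bit arithmetic right shift.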
@@ -575,88 +206,6 @@ TensorQuantizationParams QuantizationFactory::ChooseQuantizationParams(
   }
 }

-TensorQuantizationParams QuantizationFactory::ChooseQuantizationParams_(
-    float min,
-    float max,
-    int32_t qmin,
-    int32_t qmax,
-    bool preserve_sparsity) const {
-  if (min < 0 && max > 0 && preserve_sparsity) {
-    int symmetric_qmin = -((qmax - qmin) / 2 + 1);
-    int symmetric_qmax = (qmax - qmin) / 2;
-    double max_scale =
-        std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
-    min = max_scale * symmetric_qmin;
-    max = max_scale * symmetric_qmax;
-  }
-
-  double scale =
-      (std::max(max, 0.f) - std::min(min, 0.f)) / ((double)qmax - qmin);
-  if (scale == 0) {
-    scale = 0.1;
-  }
-  // If scale is 0, we arbitrary adjust the scale to 0.1
-  assert(scale > 0);
-
-  // We extend the [min, max] interval to ensure that it contains 0.
-  // Otherwise, we would not meet the requirement that 0 be an exactly
-  // representable value.
-  min = std::min(min, 0.f);
-  max = std::max(max, 0.f);
-
-  if (force_scale_power_of_two_) {
-    if (scale < 1) {
-      scale = 1. / (1 << (int)floor(log2(1 / scale)));
-    } else {
-      scale = 1 << (int)ceil(log2(scale));
-    }
-  }
-
-  // Zero-point computation.
-  // First the initial floating-point computation. The zero-point can be
-  // determined from solving an affine equation for any known pair
-  // (real value, corresponding quantized value).
-  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
-  // The arithmetic error on the zero point computed from either pair
-  // will be roughly machine_epsilon * (sum of absolute values of terms)
-  // so we want to use the variant that adds the smaller terms.
-  double zero_point_from_min = qmin - min / scale;
-  double zero_point_from_max = qmax - max / scale;
-  double zero_point_from_min_error = std::abs(qmin) + std::abs(min / scale);
-  double zero_point_from_max_error = std::abs(qmax) + std::abs(max / scale);
-  double initial_zero_point =
-      zero_point_from_min_error < zero_point_from_max_error
-      ? zero_point_from_min
-      : zero_point_from_max;
-
-  // for symmetric quantization (min == -max), we force zero_point to 128
-  // to model signed integer (FIXME: this is a workaround that gemmlowp
-  // doesn't support signed int AFAIK. Once we have an (efficient) gemm for
-  // signed as well, we can just use signed int with zero_point = 0
-  if (min < 0 && max > 0 && preserve_sparsity) {
-    initial_zero_point = (qmin + qmax) / 2 + 1;
-  }
-
-  // Now we need to nudge the zero point to be an integer
-  // (our zero points are integer, and this is motivated by the requirement
-  // to be able to represent the real value "0" exactly as a quantized value,
-  // which is required in multiple places, for example in Im2col with SAME
-  // padding).
-  int32_t nudged_zero_point = 0;
-  if (initial_zero_point < qmin) {
-    nudged_zero_point = qmin;
-  } else if (initial_zero_point > qmax) {
-    nudged_zero_point = qmax;
-  } else {
-    nudged_zero_point = nearbyint(initial_zero_point);
-  }
-
-  TensorQuantizationParams result;
-  result.scale = scale;
-  result.zero_point = nudged_zero_point;
-  return result;
-}
-
 TensorQuantizationParams QuantizationFactory::ChooseQuantizationParams(
     const float* values,
     int len,
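For a concrete feel for the parameter selection deleted above (now provided by fbgemm::ChooseQuantizationParams): with min = -1.0, max = 3.0 and an 8-bit range [0, 255], scale = 4.0 / 255 ≈ 0.0157, and the zero point computed from the min side is 0 - (-1.0 / 0.0157) ≈ 63.75, which nudges to the integer 64. A condensed sketch of that computation, omitting the sparsity and power-of-two options (illustration only):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Pick scale from the zero-extended [min, max] range, then nudge the zero
// point to an integer inside [qmin, qmax]. Sketch of the deleted logic.
void ChooseParamsSketch(float min, float max, int32_t qmin, int32_t qmax,
                        double* scale, int32_t* zero_point) {
  min = std::min(min, 0.f);  // ensure 0 is exactly representable
  max = std::max(max, 0.f);
  *scale = (static_cast<double>(max) - min) /
      (static_cast<double>(qmax) - qmin);
  if (*scale == 0) {
    *scale = 0.1;  // arbitrary fallback, as in the original
  }
  double initial = qmin - min / *scale;
  double nudged = std::min<double>(
      qmax, std::max<double>(qmin, std::nearbyint(initial)));
  *zero_point = static_cast<int32_t>(nudged);
}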
@@ -664,7 +213,7 @@ TensorQuantizationParams QuantizationFactory::ChooseQuantizationParams(
     int precision,
     bool preserve_sparsity) const {
   float min = 0, max = 0;
-  FindMinMax(values, &min, &max, len);
+  fbgemm::FindMinMax(values, &min, &max, len);

   if (MIN_MAX_QUANTIZATION == kind) {
     return ChooseQuantizationParams(min, max, precision, preserve_sparsity);
@@ -703,57 +252,6 @@ TensorQuantizationParams QuantizationFactory::ChooseQuantizationParams(
   }
 }

-void QuantizationFactory::ChooseRequantizationMultiplier_(
-    float real_multiplier,
-    int32_t* quantized_multiplier,
-    int* right_shift) const {
-  assert(real_multiplier != 0.f);
-
-  // Assuming requantization_multiplier_precision_ = 31,
-  // the default right shift is 31 when the real multiplier is already
-  // in interval [1/2, 1).
-  // Multiplying a 32-bit signed integer with all 31 bits except the sign bit
-  // is used followed by 31-bit right shift implements multiplying with a real
-  // number in [1/2, 1).
-  // We want to utilize all 31 bits except the sign bit in the 32-bit signed
-  // integer to get the best accuracy.
-  int s = 31;
-
-  // We want to bring the real multiplier into the interval [1/2, 1).
-  // We can do so by multiplying it by two, and recording how many times
-  // we multiplied by two so that we can compensate that by a right
-  // shift by the same amount.
-  if (real_multiplier > 0.f) {
-    while (real_multiplier < 0.5f) {
-      real_multiplier *= 2.f;
-      s++;
-    }
-    while (real_multiplier > 1.f) {
-      real_multiplier /= 2.f;
-      s--;
-    }
-  }
-  // Now that the real multiplier is in [1/2, 1), we convert it
-  // into a fixed-point number.
-  int64_t q = nearbyint(
-      real_multiplier * (1ll << (requantization_multiplier_precision_ - 1)));
-  assert(q <= (1ll << (requantization_multiplier_precision_ - 1)));
-  // Handle the special case when the real multiplier was so close to 1
-  // that its fixed-point approximation was undistinguishable from 1.
-  // We handle this by dividing it by two, and remembering to decrement
-  // the right shift amount.
-  if (q == (1ll << (requantization_multiplier_precision_ - 1))) {
-    q /= 2;
-    s--;
-  }
-  assert(s >= 0);
-  assert(q >= 0);
-  assert(q <= numeric_limits<int32_t>::max());
-  *quantized_multiplier = static_cast<int32_t>(q);
-  *right_shift = s;
-  assert(s < 64);
-}
-
 RequantizationParams QuantizationFactory::ChooseRequantizationMultiplier(
     float real_multiplier,
     TensorQuantizationParams target_qparams) const {
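The multiplier selection deleted above survives as fbgemm::ChooseRequantizationMultiplier, called a few lines below. The idea: double (or halve) the real multiplier until it lands in [1/2, 1), count each doubling as one extra right shift, then round to fixed point. A compact sketch with the precision fixed at the default 31 bits (illustration, not the fbgemm signature):

#include <cassert>
#include <cmath>
#include <cstdint>

// Normalize m into [1/2, 1), convert it to fixed point, and return the
// compensating right shift. Sketch of the deleted logic; assumes m > 0.
void ChooseMultiplierSketch(float m, int32_t* quantized_multiplier,
                            int* right_shift) {
  assert(m > 0.f);
  int s = 31;
  while (m < 0.5f) { m *= 2.f; ++s; }  // each doubling adds one shift
  while (m > 1.f) { m /= 2.f; --s; }
  int64_t q = std::nearbyint(m * (1ll << 30));
  if (q == (1ll << 30)) {  // m rounded up to exactly 1.0; halve and adjust
    q /= 2;
    --s;
  }
  *quantized_multiplier = static_cast<int32_t>(q);
  *right_shift = s;
}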
@@ -761,47 +259,13 @@ RequantizationParams QuantizationFactory::ChooseRequantizationMultiplier(
   params.target_qparams = target_qparams;
   params.real_multiplier = real_multiplier;

-  ChooseRequantizationMultiplier_(
-      real_multiplier, &params.multiplier, &params.right_shift);
+  fbgemm::ChooseRequantizationMultiplier(
+      real_multiplier,
+      &params.multiplier,
+      &params.right_shift,
+      requantization_multiplier_precision_);

   return params;
 }

-void FindMinMax(const float* a, float* min, float* max, int len) {
-  if (len <= 0) {
-    *min = 0.0f;
-    *max = 0.0f;
-    return;
-  }
-
-  float temp_min = *a, temp_max = *a;
-  int i = 0;
-
-#ifdef __AVX__
-  __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
-  constexpr int VLEN = 8;
-  if (len >= VLEN) {
-    for (; i < len / VLEN * VLEN; i += VLEN) {
-      min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
-      max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
-    }
-
-    float min_buf[VLEN], max_buf[VLEN];
-    _mm256_storeu_ps(min_buf, min_v);
-    _mm256_storeu_ps(max_buf, max_v);
-    for (int j = 0; j < VLEN; ++j) {
-      temp_min = std::min(temp_min, min_buf[j]);
-      temp_max = std::max(temp_max, max_buf[j]);
-    }
-  }
-#endif
-
-  for (; i < len; i++) {
-    temp_min = std::min(temp_min, a[i]);
-    temp_max = std::max(temp_max, a[i]);
-  }
-  *min = temp_min;
-  *max = temp_max;
-}
-
 } // namespace dnnlowp
@@ -8,200 +8,15 @@

 #include <x86intrin.h>

+#include <fbgemm/QuantUtils.h>
+
 #include "caffe2/quantization/server/dynamic_histogram.h"
 #include "caffe2/utils/cpuid.h"

 namespace dnnlowp {

-// Copied from gemmlowp
-//
-// A structure to hold quantization parameters 'scale' and 'zero_point'
-// as discussed in doc/quantization.md. As explained there, the meaning
-// of these values is as the constants in the quantization equation
-//
-// real_value = scale * (quantized_value - zero_point)
-//
-// In other words, 'zero_point' is the quantized value that corresponds
-// to the real value 0, and 'scale' is the difference of real values
-// corresponding to consecutive quantized values.
-struct TensorQuantizationParams {
-  float scale;
-  std::int32_t zero_point;
-  int precision;
-  float Min() const;
-  float Max() const;
-};
-
-// Parameters when we scale from one quantization parameter to another
-struct RequantizationParams {
-  float real_multiplier;
-  std::int32_t multiplier;
-  int right_shift;
-  TensorQuantizationParams target_qparams;
-};
-
-////////////////////////////////////////////////////////////////////////////////
-// Utility functions
-
-/// Clamp src in T1 to the desired precision and convert it to T2
-template <typename T1, typename T2 = std::uint8_t>
-T2 clamp(T1 src, int precision, bool is_signed = false)
-// TODO: T26263653 fix signed-integer-overflow undefined behavior
-#if defined(__has_feature)
-#if __has_feature(__address_sanitizer__)
-__attribute__((__no_sanitize__("signed-integer-overflow")))
-#endif
-#endif
-{
-  std::int32_t min = is_signed ? -(1LL << (precision - 1)) : 0;
-  std::int32_t max =
-      is_signed ? ((1LL << (precision - 1)) - 1) : (1LL << precision) - 1;
-
-  // Make sure T1 and T2 can represent the precision
-  assert(min >= std::numeric_limits<T1>::lowest());
-  assert(min >= std::numeric_limits<T2>::lowest());
-  assert(max <= std::numeric_limits<T1>::max());
-  assert(max <= std::numeric_limits<T2>::max());
-
-  return std::min<T1>(std::max<T1>(src, min), max);
-}
-
-/// Quantize src using zero_point and scale, clamp to the specified precision,
-/// and convert it to type T
-template <typename T>
-T Quantize(
-    float src,
-    std::int32_t zero_point,
-    float scale,
-    int result_precision,
-    bool result_is_signed = std::is_signed<T>::value) {
-  const float transformed_val = zero_point + src / scale;
-  return clamp<std::int64_t, T>(
-      static_cast<std::int64_t>(std::nearbyint(transformed_val)),
-      result_precision,
-      result_is_signed);
-}
-
-template <typename T>
-T Quantize(float src, const TensorQuantizationParams& qparams) {
-  return dnnlowp::Quantize<T>(
-      src, qparams.zero_point, qparams.scale, qparams.precision);
-}
-
-template <typename T>
-void Quantize(
-    const float* src,
-    T* dst,
-    int len,
-    const TensorQuantizationParams& qparams);
-
-template <typename T>
-float Dequantize(T src, const TensorQuantizationParams& qparams) {
-  return qparams.scale * (src - qparams.zero_point);
-}
-
-template <typename T>
-void Dequantize(
-    const T* src,
-    float* dst,
-    int len,
-    const TensorQuantizationParams& qparams) {
-  for (std::size_t i = 0; i < len; i++) {
-    dst[i] = Dequantize(src[i], qparams);
-  }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Requantization (pure-integer)
-
-std::int64_t
-SaturatingRoundingMulWithShift(std::int32_t a, std::int32_t b, int right_shift);
-
-template <typename T>
-T Requantize(
-    std::int32_t src, // int32 input before requantization
-    std::int32_t zero_point,
-    std::int32_t multiplier,
-    int right_shift,
-    int result_precision,
-    bool result_is_signed = false) {
-  std::int64_t quantized_down =
-      zero_point + SaturatingRoundingMulWithShift(src, multiplier, right_shift);
-  return clamp<std::int64_t, T>(
-      quantized_down, result_precision, result_is_signed);
-}
-
-template <typename T>
-T RequantizeFixedPoint(
-    std::int32_t src, // int32 input before requantization
-    const RequantizationParams& params) {
-  return Requantize<T>(
-      src,
-      params.target_qparams.zero_point,
-      params.multiplier,
-      params.right_shift,
-      params.target_qparams.precision);
-}
-
-void RequantizeFixedPointAvx2(
-    const std::int32_t* src,
-    std::uint8_t* dst,
-    int len,
-    const RequantizationParams& params);
-
-template <typename T>
-void RequantizeFixedPoint(
-    const std::int32_t* src,
-    T* dst,
-    int len,
-    const RequantizationParams& params) {
-  if (std::is_same<T, uint8_t>::value && params.target_qparams.precision == 8 &&
-      caffe2::GetCpuId().avx2()) {
-    RequantizeFixedPointAvx2(src, dst, len, params);
-  } else {
-    for (int i = 0; i < len; ++i) {
-      dst[i] = RequantizeFixedPoint<T>(src[i], params);
-    }
-  }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Requantization (with floats)
-
-template <typename T>
-T Requantize(
-    std::int32_t src, // int32 input before requantization
-    std::int32_t zero_point,
-    float multiplier,
-    int result_precision,
-    bool result_is_signed = false) {
-  long quantized_down = zero_point + std::lrintf(src * multiplier);
-  return clamp<long, T>(quantized_down, result_precision, result_is_signed);
-}
-
-template <typename T>
-T Requantize(
-    std::int32_t src, // int32 input before requantization
-    const RequantizationParams& params) {
-  return Requantize<T>(
-      src,
-      params.target_qparams.zero_point,
-      params.real_multiplier,
-      params.target_qparams.precision);
-}
-
-void RequantizeAvx2(
-    const std::int32_t* src,
-    std::uint8_t* dst,
-    int len,
-    const RequantizationParams& params);
-
-template <typename T>
-void Requantize(
-    const std::int32_t* src,
-    T* dst,
-    int len,
-    const RequantizationParams& params);
+using fbgemm::RequantizationParams;
+using fbgemm::TensorQuantizationParams;

 // Represents a quantization scheme that provides quantization parameter based
 // on distribution of data to be quantized.
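The only documentation the header loses is the gemmlowp comment above, whose equation real_value = scale * (quantized_value - zero_point) still governs the fbgemm types now pulled in through the using declarations. A self-contained roundtrip sketch of that equation (standalone helpers, not the fbgemm signatures):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize and dequantize one value per the affine equation above.
uint8_t QuantizeScalar(float x, float scale, int32_t zero_point) {
  long rounded = std::lrintf(zero_point + x / scale);
  return static_cast<uint8_t>(std::min(255l, std::max(0l, rounded)));
}

float DequantizeScalar(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}

// With scale = 0.1f and zero_point = 128: QuantizeScalar(1.0f) == 138, and
// DequantizeScalar(138) == 1.0f, up to float rounding.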
@@ -234,12 +49,13 @@ class QuantizationFactory {
       int precision,
       bool preserve_sparsity,
       bool is_signed = false) const {
-    TensorQuantizationParams qparams = ChooseQuantizationParams_(
+    TensorQuantizationParams qparams = fbgemm::ChooseQuantizationParams(
         min,
         max,
         is_signed ? -(1 << (precision - 1)) : 0,
         is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
-        preserve_sparsity);
+        preserve_sparsity,
+        force_scale_power_of_two_);
     qparams.precision = precision;
     return qparams;
   }
@@ -349,20 +165,6 @@ class QuantizationFactory {
       QuantizationKind weight_kind = MIN_MAX_QUANTIZATION);

  private:
-  /// Choose quantization scale and zero_point that maps
-  /// floating-point range [min, max] to integer range [qmin, qmax]
-  TensorQuantizationParams ChooseQuantizationParams_(
-      float min,
-      float max,
-      std::int32_t qmin,
-      std::int32_t qmax,
-      bool preserve_sparsity) const;
-
-  void ChooseRequantizationMultiplier_(
-      float real_multiplier,
-      std::int32_t* quantized_multiplier,
-      int* right_shift) const;
-
   int activation_precision_;
   int weight_precision_;
   int requantization_multiplier_precision_;
@@ -378,9 +180,4 @@ class QuantizationFactory {
  */
 QuantizationFactory::QuantizationKind StringToKind(const std::string& s);

-/**
- * Find the min and max value in a float matrix.
- */
-void FindMinMax(const float* m, float* min, float* max, int len);
-
 } // namespace dnnlowp
@@ -152,7 +152,7 @@ class DNNLowPOp : public Operator<CPUContext> {
           OutputTensorCPU_(0)->numel(),
           out_qparams_);
     } else {
-      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+      dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
     }

     MeasureQuantizationError_();
@@ -52,7 +52,7 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
 #pragma omp parallel for
 #endif
         for (int j = 0; j < InputTensorCPU_(i).numel(); ++j) {
-          quantized_in[j] = Requantize<int32_t>(
+          quantized_in[j] = fbgemm::Requantize<int32_t>(
               input_data[j] - in_qparams_[i].zero_point,
               in_requantization_params);
         }
@@ -63,7 +63,7 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
 #pragma omp parallel for
 #endif
         for (int j = 0; j < InputTensorCPU_(i).numel(); ++j) {
-          quantized_in[j] = Quantize<uint32_t>(
+          quantized_in[j] = fbgemm::Quantize<uint32_t>(
               input_data[j],
               intermediate_qparams_.zero_point,
               intermediate_qparams_.scale,
@@ -87,7 +87,7 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
 #endif
       for (int i = 0; i < C->numel(); ++i) {
         int32_t raw = A_quantized[i] + B_quantized[i] - intermediate_zero_point;
-        C_quantized[i] = Requantize<T>(raw, requantization_params_);
+        C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
       }
     } else if (B.numel() == 1) {
 #ifdef _OPENMP
@@ -95,7 +95,7 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
 #endif
       for (int i = 0; i < C->numel(); ++i) {
         int32_t raw = A_quantized[i] + B_quantized[0] - intermediate_zero_point;
-        C_quantized[i] = Requantize<T>(raw, requantization_params_);
+        C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
       }
     } else {
       size_t pre, n, post;
@@ -110,7 +110,7 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
           int32_t raw = A_quantized[((i * n) + j) * post + k] +
               B_quantized[j] - intermediate_zero_point;
           C_quantized[((i * n) + j) * post + k] =
-              Requantize<T>(raw, requantization_params_);
+              fbgemm::Requantize<T>(raw, requantization_params_);
         }
       }
     }
@@ -33,7 +33,7 @@ class UnaryElementwiseWithArgsDNNLowPOp : public Operator<CPUContext> {
         input.template data<T>(),
         output.template mutable_data<T>());

-    PropagateOutputTensorQuantizationParams(
+    dnnlowp::PropagateOutputTensorQuantizationParams(
         this, 0, functor_.GetOutputQuantizationParams());
     return true;
   }
@@ -61,7 +61,7 @@ bool ElementwiseLinearDNNLowPOp<T>::RunOnDevice() {
 #pragma omp parallel for
 #endif
     for (int i = 0; i < b.numel(); ++i) {
-      b_quantized[i] = Quantize<int32_t>(
+      b_quantized[i] = fbgemm::Quantize<int32_t>(
           b_data[i],
           0,
           in_qparams_[0].scale * in_qparams_[1].scale,
@@ -78,7 +78,8 @@ bool ElementwiseLinearDNNLowPOp<T>::RunOnDevice() {
       int32_t raw = (X_quantized[n * D + d] - in_qparams_[0].zero_point) *
               (a_quantized_[d] - in_qparams_[1].zero_point) +
           b_quantized[d];
-      Y_quantized[n * D + d] = Requantize<T>(raw, requantization_params_);
+      Y_quantized[n * D + d] =
+          fbgemm::Requantize<T>(raw, requantization_params_);
     }
   }

@@ -101,7 +102,7 @@ bool ElementwiseLinearDNNLowPOp<T>::GetQuantizationParameters_() {
       a.template data<float>(), a.numel(), true /*weight*/);

   a_quantized_.resize(a.numel());
-  Quantize<T>(
+  fbgemm::Quantize<T>(
       a.template data<float>(),
       a_quantized_.data(),
       a_quantized_.size(),
@@ -56,7 +56,7 @@ class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
       for (int i = 0; i < C->size(); ++i) {
         int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
             (B_quantized[i] - in_qparams_[1].zero_point);
-        C_quantized[i] = Requantize<T>(raw, requantization_params_);
+        C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
       }
     } else if (B.size() == 1) {
 #ifdef _OPENMP
@@ -65,7 +65,7 @@ class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
       for (int i = 0; i < C->size(); ++i) {
         int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
             (B_quantized[0] - in_qparams_[1].zero_point);
-        C_quantized[i] = Requantize<T>(raw, requantization_params_);
+        C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
       }
     } else {
       size_t pre, n, post;
@@ -81,7 +81,7 @@ class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
               in_qparams_[0].zero_point) *
               (B_quantized[j] - in_qparams_[1].zero_point);
           C_quantized[((i * n) + j) * post + k] =
-              Requantize<T>(raw, requantization_params_);
+              fbgemm::Requantize<T>(raw, requantization_params_);
         }
       }
     }
@@ -207,7 +207,7 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
       for (int j = j_begin; j < j_end; ++j) {
         int32_t acc = 0;
         for (int i = 0; i < InputSize(); ++i) {
-          acc += Requantize<int32_t>(
+          acc += fbgemm::Requantize<int32_t>(
               input_data[i][j] - in_qparams_[i].zero_point,
               in_requantization_params[i]);
         }
@@ -215,7 +215,8 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
         if (ReluFused) {
           raw = std::max(0, raw);
         }
-        output_data[j] = Requantize<T>(raw, out_requantization_params_);
+        output_data[j] =
+            fbgemm::Requantize<T>(raw, out_requantization_params_);
       }
     }
   }
@@ -237,7 +238,7 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
       for (int j = j_begin; j < j_end; ++j) {
         int32_t acc = 0;
         for (int i = 0; i < InputSize(); ++i) {
-          acc += Quantize<int32_t>(
+          acc += fbgemm::Quantize<int32_t>(
               ((const float*)input_data[i])[j],
               intermediate_qparams_.zero_point,
               intermediate_qparams_.scale,
@@ -247,7 +248,7 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
         if (ReluFused) {
           raw = std::max(0, raw);
         }
-        output_data[j] = Requantize<T>(raw, out_requantization_params_);
+        output_data[j] = fbgemm::Requantize<T>(raw, out_requantization_params_);
       }
     }
   } // !InputTensorCPU_(0).template IsType<T>()
@@ -348,8 +348,8 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
             in_qparams_[0].zero_point * column_offsets_[j] + row_offset;
         Y_int32_[i * N + j] += b_quantized_data_[j];

-        Ydata[i * N + j] =
-            Requantize<T>(Y_int32_[i * N + j], requantization_params_);
+        Ydata[i * N + j] = fbgemm::Requantize<T>(
+            Y_int32_[i * N + j], requantization_params_);
       }
     }
   }
@@ -431,7 +431,7 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
       // Adjust for the fact that weight will actually use signed.
       in_qparams_[1].zero_point += signed_min;

-      Quantize<T_signed>(
+      fbgemm::Quantize<T_signed>(
           W.template data<float>(),
           W_quantized_.data(),
           W_quantized_.size(),
@@ -474,7 +474,7 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
       in_qparams_[1].zero_point += signed_min;

       W_quantized_.resize(W.size());
-      Quantize<T_signed>(
+      fbgemm::Quantize<T_signed>(
           W.template data<float>(),
           W_quantized_.data(),
           W_quantized_.size(),
@@ -529,7 +529,7 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
       b_dequantized_.resize(N);
       for (int j = 0; j < N; ++j) {
         b_dequantized_[j] =
-            Dequantize<int32_t>(b_quantized_data_[j], in_qparams_[2]);
+            fbgemm::Dequantize<int32_t>(b_quantized_data_[j], in_qparams_[2]);
       }
       b_dequantized_data_ = b_dequantized_.data();
     }
@@ -540,7 +540,7 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
     if (!dequantize_output_) {
       b_quantized_.resize(N);
       for (int j = 0; j < N; ++j) {
-        b_quantized_[j] = Quantize<int32_t>(
+        b_quantized_[j] = fbgemm::Quantize<int32_t>(
             b_dequantized_data_[j],
             in_qparams_[2].zero_point,
             in_qparams_[2].scale,
@@ -204,7 +204,7 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::RunOnDevice() {
             in_qparams_[0].zero_point * column_offsets_[j] +
             rowwise_qparams_[j].zero_point * row_offset;
         Y_int32_[i * N + j] += b_quantized_[j];
-        Ydata[i * N + j] = Requantize<T>(
+        Ydata[i * N + j] = fbgemm::Requantize<T>(
             Y_int32_[i * N + j], rowwise_requantization_params_[j]);
       }
     }
@@ -259,7 +259,7 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::GetQuantizationParameters_() {
           W.template data<float>() + K * i, K, true /*weight*/);
       rowwise_qparams_[i].zero_point -=
           (1 << (qfactory_->GetWeightPrecision() - 1));
-      Quantize<T_signed>(
+      fbgemm::Quantize<T_signed>(
           W.template data<float>() + K * i,
           W_quantized_.data() + K * i,
           K,
@@ -290,7 +290,7 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::GetQuantizationParameters_() {
       LOG(WARNING) << "Not supporting nonconstant weights";
       in_qparams_[1] =
           GetInputTensorQuantizationParamsOf(this, 1, qfactory_.get());
-      Quantize<T_signed>(
+      fbgemm::Quantize<T_signed>(
           W.template data<float>(),
           W_quantized_.data(),
           W_quantized_.size(),
@@ -324,7 +324,7 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::GetQuantizationParameters_() {
     const auto& b = InputTensorCPU_(2);
     const float* b_data = b.template data<float>();
     for (int j = 0; j < N; ++j) {
-      b_quantized_[j] = Quantize<int32_t>(
+      b_quantized_[j] = fbgemm::Quantize<int32_t>(
           b_data[j], 0, in_qparams_[0].scale * rowwise_qparams_[j].scale, 32);
     }
   }
@@ -232,7 +232,7 @@ void AffineBatchChannelAndRequantizeNCHWAVX2<uint8_t>(
           (Y_ptr + j));
     }
     for (int j = 0; j < r; ++j) {
-      Y_ptr[n + j] = dnnlowp::Requantize<uint8_t>(
+      Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
           static_cast<int32_t>(X_ptr[n + j]) * scale[i] + bias[i], params);
     }
   }
@@ -283,7 +283,7 @@ void AffineBatchChannelAndRequantizeNHWCAVX2<uint8_t>(
           (Y_ptr + j));
     }
     for (int j = 0; j < r; ++j) {
-      Y_ptr[n + j] = dnnlowp::Requantize<uint8_t>(
+      Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
           static_cast<int32_t>(X_ptr[n + j]) * scale[c + n + j] +
               bias[c + n + j],
           params);
@@ -364,7 +364,7 @@ void GroupNormDNNLowPOp<T>::QuantizeGamma() {
   if (dequantize_output_) {
     gamma_dequantized_.resize(C);
     gamma_dequantized_data_ = gamma_dequantized_.data();
-    dnnlowp::Dequantize<int32_t>(
+    fbgemm::Dequantize<int32_t>(
         gamma_quantized_data_,
         gamma_dequantized_.data(),
         C,
@@ -394,7 +394,7 @@ void GroupNormDNNLowPOp<T>::QuantizeGammaImpl() {
 #pragma omp parallel for
 #endif
   for (int i = 0; i < C; ++i) {
-    gamma_quantized_[i] = dnnlowp::Quantize<int32_t>(
+    gamma_quantized_[i] = fbgemm::Quantize<int32_t>(
         gamma_dequantized_data_[i],
         gamma_qparams.zero_point,
         gamma_qparams.scale,
@@ -424,7 +424,7 @@ void GroupNormDNNLowPOp<T>::QuantizeBeta() {
     if (dequantize_output_) {
       beta_dequantized_.resize(C);
       beta_dequantized_data_ = beta_dequantized_.data();
-      dnnlowp::Dequantize<int32_t>(
+      fbgemm::Dequantize<int32_t>(
          beta_quantized_data_, beta_dequantized_.data(), C, beta_qparams);
     }
   } else {
@@ -437,7 +437,7 @@ void GroupNormDNNLowPOp<T>::QuantizeBeta() {
 #pragma omp parallel for
 #endif
     for (int i = 0; i < C; ++i) {
-      beta_quantized_[i] = dnnlowp::Quantize<int32_t>(
+      beta_quantized_[i] = fbgemm::Quantize<int32_t>(
          beta_dequantized_data_[i],
          beta_qparams.zero_point,
          beta_qparams.scale,
@@ -482,7 +482,7 @@ void GroupNormDNNLowPOp<T>::QuantizedGroupMomentsNCHW(
     const float var =
         static_cast<float>(sumsq) / static_cast<float>(inner_size) -
         mean * mean;
-    rsig_dequantized_[i] = dnnlowp::Dequantize<float>(var, var_qparams);
+    rsig_dequantized_[i] = fbgemm::Dequantize<float>(var, var_qparams);
   }
   ComputeQuantizedInvStd(
       outer_size, rsig_dequantized_.data(), rsig_dequantized_.data(), rsig);
@@ -527,7 +527,7 @@ void GroupNormDNNLowPOp<T>::QuantizedGroupMomentsNHWC(
     const float var =
         static_cast<float>(sumsq) / static_cast<float>(inner_size) -
         mean * mean;
-    rsig_dequantized_[i] = dnnlowp::Dequantize<float>(var, var_qparams);
+    rsig_dequantized_[i] = fbgemm::Dequantize<float>(var, var_qparams);
   }
   ComputeQuantizedInvStd(
       outer_size, rsig_dequantized_.data(), rsig_dequantized_.data(), rsig);
@@ -547,7 +547,7 @@ void GroupNormDNNLowPOp<T>::DequantizedGroupMomentsNCHW(
   const int outer_size = N * G;
   const int inner_size = K * HxW;
   X_dequantized_.resize(size);
-  dnnlowp::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
+  fbgemm::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
   const std::array<int, 2> dims = {outer_size, inner_size};
   const int axis = 1;
   math::Moments<float, CPUContext>(
@@ -568,7 +568,7 @@ void GroupNormDNNLowPOp<T>::DequantizedGroupMomentsNHWC(
   const int size = N * C * HxW;
   const int outer_size = N * G;
   X_dequantized_.resize(size);
-  dnnlowp::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
+  fbgemm::Dequantize<T>(X, X_dequantized_.data(), size, in_qparams_[INPUT]);
   const std::array<int, 4> dims = {N, HxW, G, K};
   const std::array<int, 2> axes = {1, 3};
   math::Moments<float, CPUContext>(
@@ -644,7 +644,7 @@ bool GroupNormDNNLowPOp<T>::RunOnDeviceWithOrderNCHW() {
         bias_data);
     AffineBatchChannelQuantizedNCHW(
         N, C, HxW, X_data, scale_data, bias_data, Y_data);
-    PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+    dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
   }
   MeasureQuantizationError_();
   return true;
@@ -712,7 +712,7 @@ bool GroupNormDNNLowPOp<T>::RunOnDeviceWithOrderNHWC() {
         bias_data);
     AffineBatchChannelQuantizedNHWC(
         N, C, HxW, X_data, scale_data, bias_data, Y_data);
-    PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
+    dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
   }
   MeasureQuantizationError_();
   return true;
@@ -736,7 +736,7 @@ void GroupNormDNNLowPOp<T>::ComputeQuantizedInvStd(
 #pragma omp parallel for
 #endif
   for (int i = 0; i < N; ++i) {
-    rsig_quantized[i] = dnnlowp::Quantize<int32_t>(
+    rsig_quantized[i] = fbgemm::Quantize<int32_t>(
         rsig[i], rsig_qparams_.zero_point, rsig_qparams_.scale, 32);
   }
 }
@@ -765,7 +765,7 @@ void GroupNormDNNLowPOp<T>::ComputeQuantizedFusedParams(
       qfactory_->ChooseRequantizationMultiplier(
           real_multiplier, internal_qparams_);
   for (int i = 0; i < C; ++i) {
-    bias[i] = dnnlowp::Requantize<int32_t>(
+    bias[i] = fbgemm::Requantize<int32_t>(
         beta[i],
         internal_qparams_.zero_point,
         beta_requantization_params.multiplier,
@@ -848,7 +848,7 @@ void GroupNormDNNLowPOp<T>::AffineBatchChannelQuantizedNCHW(
             ConstEigenVectorArrayMap<int32_t>(scale, N * C).transpose())
            .rowwise() +
        ConstEigenVectorArrayMap<int32_t>(bias, N * C).transpose();
-    dnnlowp::Requantize<T>(Y_int32_data, Y, size, out_requantization_params);
+    fbgemm::Requantize<T>(Y_int32_data, Y, size, out_requantization_params);
   }
 }

@@ -883,7 +883,7 @@ void GroupNormDNNLowPOp<T>::AffineBatchChannelQuantizedNHWC(
           .colwise() +
           ConstEigenVectorArrayMap<int32_t>(bias + i * C, C);
     }
-    dnnlowp::Requantize<T>(Y_int32_.data(), Y, size, out_requantization_params);
+    fbgemm::Requantize<T>(Y_int32_.data(), Y, size, out_requantization_params);
   }
 }

@@ -128,20 +128,20 @@ static void LSTMUnit(
           H[d] = H_out_qparams.zero_point;
           C[d] = C_out_qparams.zero_point;
         } else {
-          H[d] = Requantize<T>(
+          H[d] = fbgemm::Requantize<T>(
               H_prev[d] - H_in_qparams.zero_point, h_in_to_out_params);
-          C[d] = Requantize<T>(
+          C[d] = fbgemm::Requantize<T>(
               C_prev[d] - C_in_qparams.zero_point, c_in_to_out_params);
         }
       } else {
-        T i_in =
-            Requantize<T>(X[d] - X_qparams.zero_point, x_to_sigmoid_params);
-        T f_in = Requantize<T>(
+        T i_in = fbgemm::Requantize<T>(
+            X[d] - X_qparams.zero_point, x_to_sigmoid_params);
+        T f_in = fbgemm::Requantize<T>(
             X[1 * D + d] + forget_bias - 2 * X_qparams.zero_point,
             x_to_sigmoid_params);
-        T o_in = Requantize<T>(
+        T o_in = fbgemm::Requantize<T>(
             X[2 * D + d] - X_qparams.zero_point, x_to_sigmoid_params);
-        T g_in = Requantize<T>(
+        T g_in = fbgemm::Requantize<T>(
             X[3 * D + d] - X_qparams.zero_point, x_to_tanh_params);

         const T i = sigmoid.Compute(i_in);
@@ -159,7 +159,7 @@ static void LSTMUnit(
             ((int32_t)i - sigmoid_zero_point) * ((int32_t)g - tanh_zero_point);

         // c_temp.scale = sigmoid_out.scale * tanh_out.scale
-        int32_t f_times_c_prev_rescaled = Requantize<int32_t>(
+        int32_t f_times_c_prev_rescaled = fbgemm::Requantize<int32_t>(
             f_times_c_prev,
             0,
             c_to_tanh_params.real_multiplier,
@@ -168,15 +168,17 @@ static void LSTMUnit(
       int32_t c_temp = f_times_c_prev_rescaled + i_times_g;

       // scale back to c.scale
-      C[d] = Requantize<T>(c_temp, c_out_requantization_params);
+      C[d] = fbgemm::Requantize<T>(c_temp, c_out_requantization_params);

-      T c_tanh_input = Requantize<T>(c_temp, c_tanh_requantization_params);
+      T c_tanh_input =
+          fbgemm::Requantize<T>(c_temp, c_tanh_requantization_params);
       T host_tanh_c = tanh.Compute(c_tanh_input);

       // o_times_host_tanh_c.scale = sigmoid_out.scale * tanh_out.scale
       int32_t o_times_host_tanh_c = ((int32_t)o - sigmoid_zero_point) *
           ((int32_t)host_tanh_c - tanh_zero_point);
-      H[d] = Requantize<T>(o_times_host_tanh_c, h_requantization_params);
+      H[d] =
+          fbgemm::Requantize<T>(o_times_host_tanh_c, h_requantization_params);
     }
   }
   H_prev += D;
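One subtlety worth spelling out in the forget-gate input earlier in this function: X[1 * D + d] and forget_bias are both quantized against the same input parameters, so their sum carries two copies of the zero point, and subtracting 2 * X_qparams.zero_point leaves a zero-point-free value that Requantize can then rescale into the sigmoid input range. A sketch of that bookkeeping (hypothetical helper mirroring the diff's arithmetic):

#include <cstdint>

// With scale s_x and zero point z_x:
//   x_q = round(x / s_x) + z_x  and  b_q = round(b / s_x) + z_x, so
//   x_q + b_q - 2 * z_x ~= (x + b) / s_x: a pure s_x-scaled value with no
//   zero-point offset left in it.
inline int32_t SumWithoutZeroPoint(int32_t x_q, int32_t b_q, int32_t z_x) {
  return x_q + b_q - 2 * z_x;
}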
@@ -289,7 +291,7 @@ bool LSTMUnitDNNLowPOp<T>::RunOnDevice() {
   }

   int32_t forget_bias_quantized =
-      Quantize<int32_t>(forget_bias_, G_in_qparams_);
+      fbgemm::Quantize<int32_t>(forget_bias_, G_in_qparams_);

   LSTMUnit(
       N,
@@ -313,12 +315,12 @@ bool LSTMUnitDNNLowPOp<T>::RunOnDevice() {
       qfactory_.get());

   if (dequantize_output_) {
-    Dequantize<T>(
+    fbgemm::Dequantize<T>(
         Cdata,
         OutputTensorCPU_(CELL_T)->template mutable_data<float>(),
         Ctemp.size(),
         C_out_qparams_);
-    Dequantize<T>(
+    fbgemm::Dequantize<T>(
         Hdata,
         OutputTensorCPU_(HIDDEN_T)->template mutable_data<float>(),
         Htemp.size(),
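Dequantize, used above to produce the float cell and hidden outputs, is the inverse affine map. A one-line reference model, assuming per-tensor parameters like the C_out_qparams_ and H_out_qparams_ used here:

#include <cstdint>

// Reference: real value = scale * (quantized value - zero point).
inline float DequantizeRef(uint8_t q, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(q) - zero_point);
}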
@@ -65,7 +65,7 @@ class OpWrapper {

     float min, max;
     auto& out_tensor = local_output_blobs_[index]->template Get<TensorCPU>();
-    FindMinMax(
+    fbgemm::FindMinMax(
         out_tensor.template data<float>(), &min, &max, out_tensor.numel());
     if (op_->OperatorBase::GetSingleArgument<std::string>("followed_by", "") ==
         "Relu") {
@@ -51,7 +51,7 @@ bool QuantizeDNNLowPOp<T>::RunOnDevice() {
     int i_begin, i_end;
     tie(i_begin, i_end) = Get1DPartition(
         Input(0).numel(), dnnlowp_get_num_threads(), dnnlowp_get_thread_num());
-    Quantize<T>(
+    fbgemm::Quantize<T>(
         in_data + i_begin, out_data + i_begin, i_end - i_begin, in_qparams);
   }
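Get1DPartition splits the tensor across OpenMP-style worker threads so each one quantizes a contiguous slice. A plausible implementation of such a helper, under the assumption that it hands out near-equal contiguous chunks (the real dnnlowp helper may balance work differently):

#include <algorithm>
#include <tuple>

std::tuple<int, int> Partition1D(int n, int num_threads, int tid) {
  const int chunk = (n + num_threads - 1) / num_threads; // ceil(n / threads)
  const int begin = std::min(tid * chunk, n);
  const int end = std::min(begin + chunk, n);
  return std::make_tuple(begin, end); // this thread processes [begin, end)
}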
@@ -45,14 +45,14 @@ TEST(Requantization, BatchRequantizationUnitTest) {
       real_multiplier, target_qparams);

   for (int j = 0; j < LEN; ++j) {
-    expected[j] = clamp(
+    expected[j] = fbgemm::clamp(
         target_qparams.zero_point +
             std::nearbyint(static_cast<double>(src[j]) * real_multiplier),
         8);
   }

   unsigned long long cycle_begin = __rdtsc();
-  Requantize(src.data(), actual.data(), LEN, params);
+  fbgemm::Requantize(src.data(), actual.data(), LEN, params);
   unsigned long long cycle_end = __rdtsc();
   double elements_per_cycle = (double)LEN / (cycle_end - cycle_begin);
   LOG(INFO) << elements_per_cycle << " elements_per_cycle";
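The expected[] computation in this test doubles as the definition of requantization: rescale a 32-bit value by real_multiplier, round to nearest, add the target zero point, and clamp to the 8-bit range. A scalar reference of the same math (the function being benchmarked is the vectorized fbgemm version of this):

#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t RequantizeRef(int32_t src, double real_multiplier,
                      int32_t zero_point) {
  const int32_t q = zero_point +
      static_cast<int32_t>(std::nearbyint(src * real_multiplier));
  return static_cast<uint8_t>(std::min(255, std::max(0, q))); // clamp to 8 bits
}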
@@ -118,20 +118,21 @@ TEST(Requantization, RequantizationUnitTest) {
   vector<float> src(LEN);
   for (int j = 0; j < LEN; ++j) {
     float src_orig = value_dist(gen);
-    src_q[j] = Quantize<int32_t>(
+    src_q[j] = fbgemm::Quantize<int32_t>(
         src_orig, 0, src_qparams.scale, 32, true /* signed*/);
-    src[j] = Dequantize<int32_t>(src_q[j], src_qparams);
+    src[j] = fbgemm::Dequantize<int32_t>(src_q[j], src_qparams);
     // This number shouldn't have any quantization error
     EXPECT_EQ(
-        Quantize<int32_t>(src[j], 0, src_qparams.scale, 32, true),
+        fbgemm::Quantize<int32_t>(src[j], 0, src_qparams.scale, 32, true),
         src_q[j]);
   }

   vector<uint8_t> dst_q(LEN);
-  Requantize(src_q.data(), dst_q.data(), LEN, requantization_params);
+  fbgemm::Requantize(
+      src_q.data(), dst_q.data(), LEN, requantization_params);

   for (int j = 0; j < LEN; ++j) {
-    float dst = Dequantize<uint8_t>(dst_q[j], dst_qparams);
+    float dst = fbgemm::Dequantize<uint8_t>(dst_q[j], dst_qparams);

     float err = fabsf(dst - src[j]);
     sum_sq += err * err;
@@ -20,10 +20,11 @@ TEST(Sigmoid, SigmoidUnitTest) {
   float sq_err_sum = 0, max_err = 0;
   for (int i = 0; i < NSAMPLES; ++i) {
     float x = distribution(generator);
-    uint8_t x_q =
-        Quantize<uint8_t>(x, sigmoid_approx.GetInputQuantizationParams());
+    uint8_t x_q = fbgemm::Quantize<uint8_t>(
+        x, sigmoid_approx.GetInputQuantizationParams());
     uint8_t y_q = sigmoid_approx.Compute(x_q);
-    float y = Dequantize(y_q, sigmoid_approx.GetOutputQuantizationParams());
+    float y =
+        fbgemm::Dequantize(y_q, sigmoid_approx.GetOutputQuantizationParams());
     float sigmoid = exp(x) / (exp(x) + 1);
     float err = fabs(sigmoid - y);
     sq_err_sum += err * err;
@@ -36,12 +36,14 @@ class Tanh {
   }

   float GetPassRegionEndDequantized() const {
-    return Dequantize<T>(
-        (uint8_t)(x_pq_index_ + in_qparams_.zero_point), in_qparams_);
+    return fbgemm::Dequantize<T>(
+        static_cast<uint8_t>(x_pq_index_ + in_qparams_.zero_point),
+        in_qparams_);
   }

   float GetSaturationRegionBegin() const {
-    return Dequantize<T>((T)((1 << num_in_bits_) - 1), in_qparams_);
+    return fbgemm::Dequantize<T>(
+        static_cast<T>((1 << num_in_bits_) - 1), in_qparams_);
   }

   static constexpr double DEFAULT_MAX_ABS_ERR = 0.02;
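The two accessors above expose the boundaries of dnnlowp's piecewise tanh approximation: a pass region near zero where tanh(x) ~= x, a processing region handled by the quantized kernel, and a saturation region where tanh(x) ~= +/-1. A float-level sketch of that structure, with std::tanh standing in for the quantized middle region (illustrative only, not the class's actual kernel):

#include <cmath>

float TanhThreeRegion(float x, float pass_region_end,
                      float saturation_region_begin) {
  const float a = std::fabs(x);
  float y;
  if (a < pass_region_end) {
    y = a; // pass region: the identity is within tolerance
  } else if (a >= saturation_region_begin) {
    y = 1.0f; // saturation region: output pinned to +/-1
  } else {
    y = std::tanh(a); // stand-in for the quantized lookup
  }
  return x >= 0 ? y : -y; // tanh is odd
}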
@@ -24,10 +24,11 @@ TEST(Tanh, TanhUnitTest) {
   float sq_err_sum = 0, max_err = 0;
   for (int i = 0; i < NSAMPLES; ++i) {
     float x = distribution(generator);
-    uint8_t x_q =
-        Quantize<uint8_t>(x, tanh_approx.GetInputQuantizationParams());
+    uint8_t x_q = fbgemm::Quantize<uint8_t>(
+        x, tanh_approx.GetInputQuantizationParams());
     uint8_t y_q = tanh_approx.Compute(x_q);
-    float y = Dequantize(y_q, tanh_approx.GetOutputQuantizationParams());
+    float y =
+        fbgemm::Dequantize(y_q, tanh_approx.GetOutputQuantizationParams());
     float err = fabs(tanh(x) - y);
     sq_err_sum += err * err;
     max_err = std::max(err, max_err);
@@ -50,7 +50,7 @@ bool GatherDNNLowPOp<T>::RunOnDevice() {
     out_qparams = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
   }

-  Quantize<T>(
+  fbgemm::Quantize<T>(
       static_cast<const float*>(Fp32Op_()->Get()->Output(0)->raw_data()),
       out_data,
       output->t.numel(),