mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15685 The declaration of "Dequantize" is in "fbsource/fbcode/deeplearning/fbgemm2/QuantUtils.h", so calls must be qualified with the "fbgemm::" namespace. <T> is actually optional, since the type can be deduced from the first argument. In some places we have "Dequantize<T>(...)", while in other places we have "Dequantize(...)"; we should unify them. For reference, all occurrences of "Quantize" use "fbgemm::Quantize<T>(...)". Reviewed By: jspark1105 Differential Revision: D13570847 fbshipit-source-id: 7fca9f7f9e4e0d9e5eb27ac44b8707adc3c80717
43 lines
1.4 KiB
C++
43 lines
1.4 KiB
C++
#include <algorithm>
#include <cmath>
#include <iostream>
#include <random>

#include "sigmoid.h"

#include <gtest/gtest.h>
#include "caffe2/core/logging.h"

using namespace dnnlowp;
using namespace std;
|
|
|
|
// Checks that the quantized Sigmoid<uint8_t> approximation stays within the
// requested maximum absolute error. For each tolerance in {0.02, 0.03, ...,
// 0.10}, 2^16 inputs drawn uniformly from [-5, 5] are quantized, passed through
// Sigmoid::Compute, dequantized, and compared against exp(x) / (exp(x) + 1).
TEST(Sigmoid, SigmoidUnitTest) {
  // Derive the tolerance from an integer step rather than accumulating 0.01
  // into a double: repeated floating-point addition drifts, so a `<= 0.1`
  // condition on an accumulated counter may or may not run the final 0.10
  // iteration.
  for (int step = 2; step <= 10; ++step) {
    const double max_abs_err = step / 100.0;
    Sigmoid<uint8_t> sigmoid_approx(max_abs_err);
    LOG(INFO) << "max_abs_err " << max_abs_err;

    const int NSAMPLES = 1 << 16;

    std::uniform_real_distribution<float> distribution(-5., 5.);
    // Default-seeded engine, re-created per tolerance, so every tolerance is
    // evaluated on the same deterministic sample sequence.
    std::default_random_engine generator;

    // Accumulate squared errors in double: summing 2^16 small values in float
    // loses precision.
    double sq_err_sum = 0;
    float max_err = 0;
    for (int i = 0; i < NSAMPLES; ++i) {
      float x = distribution(generator);
      uint8_t x_q = fbgemm::Quantize<uint8_t>(
          x, sigmoid_approx.GetInputQuantizationParams());
      uint8_t y_q = sigmoid_approx.Compute(x_q);
      float y = fbgemm::Dequantize<uint8_t>(
          y_q, sigmoid_approx.GetOutputQuantizationParams());

      // Reference sigmoid in float; exp(x) is evaluated once.
      const float exp_x = std::exp(x);
      float sigmoid = exp_x / (exp_x + 1);
      float err = std::fabs(sigmoid - y);

      sq_err_sum += err * err;
      max_err = std::max(err, max_err);
      if (err > max_abs_err) {
        LOG(INFO) << "x " << x << " sigmoid_real " << sigmoid
                  << " sigmoid_approx " << y << " err " << err << " x_q "
                  << (int)x_q << " y_q " << (int)y_q;
      }
      EXPECT_LE(err, max_abs_err);
    }
    // Root-mean-square error, sqrt(sum / N). The previous sqrt(sum) / N is not
    // an average: it shrinks like 1/sqrt(N) as the sample count grows.
    LOG(INFO) << "avg_l2_err " << std::sqrt(sq_err_sum / NSAMPLES)
              << " max_err " << max_err;
  }
}
|