pytorch/caffe2/quantization/server/quantization_error_minimization.h
#pragma once

#include "caffe2/quantization/server/dnnlowp.h"

namespace dnnlowp {

// Interface for strategies that choose TensorQuantizationParams for a tensor
// from a Histogram of its observed values, trading off range coverage against
// quantization error.
class QuantizationErrorMinimization {
 public:
  virtual TensorQuantizationParams ChooseQuantizationParams(
      const Histogram& hist,
      bool preserve_sparsity = false,
      int precision = 8) = 0;
  virtual ~QuantizationErrorMinimization() = default;
};
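
// Example (sketch): a minimal concrete implementation of this interface that
// simply spans the full observed range instead of minimizing any error
// measure. It assumes Histogram::Min()/Max() accessors and the
// scale/zero_point/precision fields of TensorQuantizationParams; both are
// assumptions to check against the dnnlowp headers in use. Needs <algorithm>
// and <cmath>.
//
//   class MinMaxQuantization : public QuantizationErrorMinimization {
//    public:
//     TensorQuantizationParams ChooseQuantizationParams(
//         const Histogram& hist,
//         bool /*preserve_sparsity*/ = false,
//         int precision = 8) override {
//       // Always include zero so that zero is exactly representable.
//       const float min = std::min(0.0f, hist.Min());
//       const float max = std::max(0.0f, hist.Max());
//       TensorQuantizationParams qparams;
//       qparams.precision = precision;
//       qparams.scale = (max - min) / ((1 << precision) - 1);
//       qparams.zero_point = qparams.scale == 0
//           ? 0
//           : static_cast<std::int32_t>(std::nearbyint(-min / qparams.scale));
//       return qparams;
//     }
//   };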

// Chooses quantization parameters that minimize the L1 or L2 norm of the
// quantization error measured over the histogram.
class NormMinimization : public QuantizationErrorMinimization {
 public:
  enum Kind {
    L1,
    L2,
  };

  NormMinimization(Kind kind) : kind_(kind) {}

  /**
   * Faster approximate search
   */
  TensorQuantizationParams NonlinearQuantizationParamsSearch(
      const Histogram& hist,
      bool preserve_sparsity = false,
      int precision = 8);

  TensorQuantizationParams ChooseQuantizationParams(
      const Histogram& hist,
      bool preserve_sparsity = false,
      int precision = 8) override;

 protected:
  Kind kind_;
};

// Convenience subclass that minimizes the L1 norm of the quantization error.
class L1ErrorMinimization : public NormMinimization {
 public:
  L1ErrorMinimization() : NormMinimization(L1) {}
};

// Chooses quantization parameters whose range covers the given fraction
// (99% by default) of the observed values, clipping the outlying tails.
class P99 : public QuantizationErrorMinimization {
 public:
  float threshold_;
  P99(float p99_threshold = 0.99) : threshold_(p99_threshold) {}
  TensorQuantizationParams ChooseQuantizationParams(
      const Histogram& hist,
      bool preserve_sparsity = true,
      int precision = 8) override;
}; // class P99

} // namespace dnnlowp
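
// Usage sketch, assuming the Histogram(nbins, min, max) constructor and
// Histogram::Add(float) from dnnlowp's dynamic_histogram.h, and the
// scale/zero_point fields of TensorQuantizationParams (verify both against
// the dnnlowp headers in use):
//
//   void CalibrateExample(const float* data, int len) {
//     // Build a histogram of the observed values (assumed Histogram API).
//     dnnlowp::Histogram hist(2048, /*min=*/-1.0f, /*max=*/1.0f);
//     for (int i = 0; i < len; ++i) {
//       hist.Add(data[i]);
//     }
//
//     // 8-bit parameters that minimize the L2 norm of the quantization error.
//     dnnlowp::NormMinimization l2_min(dnnlowp::NormMinimization::L2);
//     dnnlowp::TensorQuantizationParams l2_params =
//         l2_min.ChooseQuantizationParams(hist, /*preserve_sparsity=*/false, 8);
//
//     // Alternatively, clip to the central 99% of the distribution.
//     dnnlowp::P99 p99(0.99f);
//     dnnlowp::TensorQuantizationParams p99_params =
//         p99.ChooseQuantizationParams(hist);
//
//     // l2_params.scale / l2_params.zero_point (assumed fields) now describe
//     // the chosen affine quantization.
//     (void)l2_params;
//     (void)p99_params;
//   }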