Add a new attribute, narrow_range, to the FakeQuant* operations. When enabled, values are quantized into the range [1; 255] instead of [0; 255] (in general, [1; 2^num_bits - 1] instead of [0; 2^num_bits - 1]).

PiperOrigin-RevId: 157641054
A. Unique TensorFlower 2017-05-31 15:06:59 -07:00 committed by TensorFlower Gardener
parent c048e2938c
commit 41b87d6ceb
6 changed files with 1528 additions and 303 deletions
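
For illustration, a minimal usage sketch of the new attribute from the Python API (values are made up; this assumes the generated wrapper tf.fake_quant_with_min_max_args accepts the num_bits and narrow_range keyword arguments listed in the API golden at the end of this commit):

import tensorflow as tf

x = tf.constant([-8.0, -1.0, 0.0, 0.1, 5.9, 7.0])

# Default: quantize into [0; 255] (256 values), then de-quantize back to float.
y_full = tf.fake_quant_with_min_max_args(
    x, min=-6.0, max=6.0, num_bits=8, narrow_range=False)

# narrow_range=True: quantize into [1; 255] (255 values) instead.
y_narrow = tf.fake_quant_with_min_max_args(
    x, min=-6.0, max=6.0, num_bits=8, narrow_range=True)

with tf.Session() as sess:
    print(sess.run([y_full, y_narrow]))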


@ -60,25 +60,29 @@ class FakeQuantWithMinMaxArgsOp
: Base::UnaryElementWiseOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, min_ < max_,
InvalidArgument("min has to be smaller than max, was: ", min_,
" >= ", max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
FakeQuantWithMinMaxArgsFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
steps_, output->flat<float>());
quant_min_, quant_max_, output->flat<float>());
}
private:
float min_;
float max_;
int steps_;
int quant_min_;
int quant_max_;
};
// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in
@ -94,14 +98,17 @@ class FakeQuantWithMinMaxArgsGradientOp
: Base::BinaryElementWiseOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, min_ < max_,
InvalidArgument("min has to be smaller than max, was: ", min_,
" >= ", max_));
int num_bits;
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
template <int NDIMS>
@ -116,12 +123,14 @@ class FakeQuantWithMinMaxArgsGradientOp
InvalidArgument("gradient and input must be the same size"));
FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), gradient.flat<float>(),
input.flat<float>(), min_, max_, steps_, output->flat<float>());
input.flat<float>(), min_, max_, quant_min_, quant_max_,
output->flat<float>());
}
private:
float min_;
float max_;
int steps_;
int quant_min_;
int quant_max_;
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
@ -136,8 +145,9 @@ typedef Eigen::GpuDevice GPUDevice;
// Forward declarations for functor specializations for GPU.
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstFlat inputs, float min,
float max, int steps, typename TTypes<float>::Flat outputs);
const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
const float min, const float max, const int quant_min, const int quant_max,
typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
FakeQuantWithMinMaxArgsOp<GPUDevice>);
@ -145,7 +155,8 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs, float min, float max, int steps,
typename TTypes<float>::ConstFlat inputs, const float min, const float max,
const int quant_min, const int quant_max,
typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
@ -164,7 +175,10 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compute(OpKernelContext* context) override {
@ -179,12 +193,13 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
FakeQuantWithMinMaxVarsFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.flat<float>(),
min.scalar<float>(), max.scalar<float>(), steps_,
min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
output->flat<float>());
}
private:
int steps_;
int quant_min_;
int quant_max_;
};
// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in
@ -198,7 +213,10 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compute(OpKernelContext* context) override {
@ -226,13 +244,13 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), gradient.flat<float>(),
input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
steps_,
grad_wrt_input->flat<float>(), grad_wrt_min->scalar<float>(),
grad_wrt_max->scalar<float>());
quant_min_, quant_max_, grad_wrt_input->flat<float>(),
grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
}
private:
int steps_;
int quant_min_;
int quant_max_;
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
@ -246,8 +264,8 @@ template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstScalar min,
typename TTypes<float>::ConstScalar max, int steps,
typename TTypes<float>::Flat output);
typename TTypes<float>::ConstScalar max, const int quant_min,
const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
.Device(DEVICE_GPU)
@ -260,8 +278,8 @@ void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstScalar min,
typename TTypes<float>::ConstScalar max, int steps,
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::ConstScalar max, const int quant_min,
const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Scalar backprop_wrt_min,
typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
@ -284,7 +302,10 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compute(OpKernelContext* context) override {
@ -309,22 +330,22 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
FakeQuant4WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), input.dim_size(2), input.dim_size(3),
input.flat<float>(), min.vec<float>(), max.vec<float>(), steps_,
output->flat<float>());
input.flat<float>(), min.vec<float>(), max.vec<float>(),
quant_min_, quant_max_, output->flat<float>());
break;
}
case 2: {
FakeQuant2WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), input.flat<float>(), min.vec<float>(),
max.vec<float>(), steps_,
max.vec<float>(), quant_min_, quant_max_,
output->flat<float>());
break;
}
case 1: {
FakeQuant1WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.vec<float>(),
min.vec<float>(), max.vec<float>(), steps_,
min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
output->vec<float>());
break;
}
@ -336,7 +357,8 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
}
private:
int steps_;
int quant_min_;
int quant_max_;
};
// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
@ -350,7 +372,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
OP_REQUIRES(context, IsNumBitsValid(num_bits),
InvalidArgument("num_bits must be between 2 and 8, inclusive"));
steps_ = (1 << num_bits) - 1;
bool narrow_range;
OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compute(OpKernelContext* context) override {
@ -388,7 +413,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), input.dim_size(2), input.dim_size(3),
gradient.flat<float>(), input.flat<float>(), min.vec<float>(),
max.vec<float>(), steps_,
max.vec<float>(), quant_min_, quant_max_,
grad_wrt_input->flat<float>(), grad_wrt_min->vec<float>(),
grad_wrt_max->vec<float>());
break;
@ -397,7 +422,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), gradient.flat<float>(), input.flat<float>(),
min.vec<float>(), max.vec<float>(), steps_,
min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
grad_wrt_input->flat<float>(), grad_wrt_min->vec<float>(),
grad_wrt_max->vec<float>());
break;
@ -405,9 +430,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
case 1: {
FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), gradient.vec<float>(),
input.vec<float>(), min.vec<float>(), max.vec<float>(), steps_,
grad_wrt_input->vec<float>(), grad_wrt_min->vec<float>(),
grad_wrt_max->vec<float>());
input.vec<float>(), min.vec<float>(), max.vec<float>(),
quant_min_, quant_max_, grad_wrt_input->vec<float>(),
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
break;
}
default:
@ -418,7 +443,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
}
private:
int steps_;
int quant_min_;
int quant_max_;
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
@ -433,7 +459,7 @@ template <>
void FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstVec inputs,
typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
int steps,
const int quant_min, const int quant_max,
typename TTypes<float>::Vec outputs);
extern template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;
@ -442,8 +468,8 @@ void FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const GPUDevice& d, const Index batch_size, const Index depth,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstFlat min,
typename TTypes<float>::ConstFlat max, int steps,
typename TTypes<float>::Flat outputs);
typename TTypes<float>::ConstFlat max, const int quant_min,
const int quant_max, typename TTypes<float>::Flat outputs);
extern template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template <>
@ -452,8 +478,8 @@ void FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const Index width, const Index depth,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstFlat min,
typename TTypes<float>::ConstFlat max, int steps,
typename TTypes<float>::Flat outputs);
typename TTypes<float>::ConstFlat max, const int quant_min,
const int quant_max, typename TTypes<float>::Flat outputs);
extern template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
@ -467,7 +493,7 @@ void FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<float>::ConstVec gradients,
typename TTypes<float>::ConstVec inputs,
typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
int steps,
const int quant_min, const int quant_max,
typename TTypes<float>::Vec backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);
@ -480,7 +506,7 @@ void FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
int steps,
const int quant_min, const int quant_max,
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);
@ -494,7 +520,7 @@ void FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
int steps,
const int quant_min, const int quant_max,
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);


@ -37,25 +37,27 @@ namespace tensorflow {
// Gymnastics with nudged zero point is to ensure that real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
// Returns (nudged_min, nudged_max, nudged_scale).
// Outputs nudged_min, nudged_max, nudged_scale.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge(
const float min, const float max, const int steps, float* nudged_min,
float* nudged_max, float* scale) {
const float steps_float = static_cast<float>(steps);
*scale = (max - min) / (steps_float - 0.0f);
const float zero_point_from_min = 0.0f - min / *scale;
const uint8 nudged_zero_point = [zero_point_from_min, steps, steps_float] {
if (zero_point_from_min < 0.0f) {
return static_cast<uint8>(0);
const float min, const float max, const int quant_min, const int quant_max,
float* nudged_min, float* nudged_max, float* scale) {
const float quant_min_float = static_cast<float>(quant_min);
const float quant_max_float = static_cast<float>(quant_max);
*scale = (max - min) / (quant_max_float - quant_min_float);
const float zero_point_from_min = quant_min_float - min / *scale;
const uint8 nudged_zero_point = [zero_point_from_min, quant_min,
quant_min_float, quant_max,
quant_max_float] {
if (zero_point_from_min < quant_min_float) {
return static_cast<uint8>(quant_min);
}
if (zero_point_from_min > steps_float) {
return static_cast<uint8>(steps);
if (zero_point_from_min > quant_max_float) {
return static_cast<uint8>(quant_max);
}
return static_cast<uint8>(StdRound(zero_point_from_min));
}();
*nudged_min = (0.0f - nudged_zero_point) * (*scale);
*nudged_max = (steps_float - nudged_zero_point) * (*scale);
*nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
*nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}
template <typename T>
@ -76,13 +78,15 @@ using Flat = typename tensorflow::TTypes<T>::Flat;
template <typename Device>
struct FakeQuantWithMinMaxArgsFunctor {
void operator()(const Device& d, ConstFlat<float> inputs, const float min,
const float max, const int steps, Flat<float> outputs) {
const float max, const int quant_min, const int quant_max,
Flat<float> outputs) {
eigen_assert(min <= 0.0f && "min should be <= 0.0");
eigen_assert(max >= 0.0f && "max should be >= 0.0");
eigen_assert(min < max && "min should be < max");
float nudged_min, nudged_max, nudged_scale;
Nudge(min, max, steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const float inv_nudged_scale = 1.0f / nudged_scale;
auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
@ -99,13 +103,15 @@ template <typename Device>
struct FakeQuantWithMinMaxArgsGradientFunctor {
void operator()(const Device& d, ConstFlat<float> gradients,
ConstFlat<float> inputs, const float min, const float max,
const int steps, Flat<float> backprops) {
const int quant_min, const int quant_max,
Flat<float> backprops) {
eigen_assert(min <= 0.0f && "min should be <= 0.0");
eigen_assert(max >= 0.0f && "max should be >= 0.0");
eigen_assert(min < max && "min should be < max");
float nudged_min, nudged_max, nudged_scale;
Nudge(min, max, steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
auto between_nudged_min_max =
(inputs >= nudged_min && inputs <= nudged_max)
@ -120,10 +126,11 @@ template <typename Device>
struct FakeQuantWithMinMaxVarsFunctor {
void operator()(const Device& d, ConstFlat<float> inputs,
ConstScalar<float> min, ConstScalar<float> max,
const int steps,
const int quant_min, const int quant_max,
Flat<float> outputs) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(), max(), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(), max(), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto nudged_scale_repl = inputs.constant(nudged_scale);
const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
@ -140,12 +147,13 @@ template <typename Device>
struct FakeQuantWithMinMaxVarsGradientFunctor {
void operator()(const Device& d, ConstFlat<float> gradients,
ConstFlat<float> inputs, ConstScalar<float> min,
ConstScalar<float> max, const int steps,
Flat<float> backprops_wrt_input,
ConstScalar<float> max, const int quant_min,
const int quant_max, Flat<float> backprops_wrt_input,
Scalar<float> backprop_wrt_min,
Scalar<float> backprop_wrt_max) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(), max(), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(), max(), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto between_min_max =
(inputs >= nudged_min && inputs <= nudged_max)
@ -173,11 +181,12 @@ using Index = typename tensorflow::TTypes<float>::ConstTensor::Index;
template <typename Device>
struct FakeQuant1WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, ConstVec<float> inputs, ConstVec<float> min,
ConstVec<float> max, const int steps,
ConstVec<float> max, const int quant_min, const int quant_max,
Vec<float> outputs) {
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const float clamped =
std::max(std::min(inputs(i), nudged_max), nudged_min);
const float clamped_shifted = clamped - nudged_min;
@ -194,13 +203,14 @@ template <typename Device>
struct FakeQuant2WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, const Index batch_size, const Index depth,
ConstFlat<float> inputs, ConstVec<float> min,
ConstVec<float> max, const int steps,
ConstVec<float> max, const int quant_min, const int quant_max,
Flat<float> outputs) {
Eigen::DSizes<Index, 2> restored(batch_size, depth);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto clamped =
inputs_restored.chip<1>(i).cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
@ -218,13 +228,14 @@ template <typename Device>
struct FakeQuant4WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, const Index batch_size, const Index height,
const Index width, const Index depth, ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max, const int steps,
Flat<float> outputs) {
ConstVec<float> min, ConstVec<float> max, const int quant_min,
const int quant_max, Flat<float> outputs) {
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto clamped =
inputs_restored.chip<3>(i).cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
@ -245,12 +256,13 @@ template <typename Device>
struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d, ConstVec<float> gradients,
ConstVec<float> inputs, ConstVec<float> min,
ConstVec<float> max, const int steps,
ConstVec<float> max, const int quant_min, const int quant_max,
Vec<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
Vec<float> backprop_wrt_max) {
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const bool between_min_max =
inputs(i) >= nudged_min && inputs(i) <= nudged_max;
@ -271,15 +283,16 @@ template <typename Device>
struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d, const Index batch_size, const Index depth,
ConstFlat<float> gradients, ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max, const int steps,
Flat<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
Vec<float> backprop_wrt_max) {
ConstVec<float> min, ConstVec<float> max, const int quant_min,
const int quant_max, Flat<float> backprops_wrt_input,
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
Eigen::DSizes<Index, 2> restored(batch_size, depth);
const auto gradients_restored = gradients.reshape(restored);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto gradients_chip = gradients_restored.chip<1>(i);
const auto inputs_chip = inputs_restored.chip<1>(i);
@ -312,15 +325,16 @@ struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d, const Index batch_size, const Index height,
const Index width, const Index depth,
ConstFlat<float> gradients, ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max, const int steps,
Flat<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
Vec<float> backprop_wrt_max) {
ConstVec<float> min, ConstVec<float> max, const int quant_min,
const int quant_max, Flat<float> backprops_wrt_input,
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
const auto gradients_restored = gradients.reshape(restored);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
&nudged_scale);
const auto gradients_chip = gradients_restored.chip<3>(i);
const auto inputs_chip = inputs_restored.chip<3>(i);

File diff suppressed because it is too large.


@ -4918,16 +4918,18 @@ REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("inputs: float")
.Output("outputs: float")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
Attributes [min; max] define the clamping range for the 'inputs' data. Op
divides this range into 255 steps (total of 256 values), then replaces each
'inputs' value with the closest of the quantized step values.
'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive.
Attributes `[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
Quantization is called fake since the output is still in floating point.
)doc");
@ -4936,6 +4938,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("gradients: float")
.Input("inputs: float")
.Output("backprops: float")
@ -4951,6 +4954,7 @@ backprops: Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
REGISTER_OP("FakeQuantWithMinMaxVars")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
@ -4966,16 +4970,19 @@ REGISTER_OP("FakeQuantWithMinMaxVars")
Fake-quantize the 'inputs' tensor of type float via global float scalars `min`
and `max` to 'outputs' tensor of same shape as `inputs`.
[min; max] is the clamping range for the 'inputs' data. Op divides this range
into 255 steps (total of 256 values), then replaces each 'inputs' value with the
closest of the quantized step values.
'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive.
`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
This operation has a gradient and thus allows for training `min` and `max` values.
This operation has a gradient and thus allows for training `min` and `max`
values.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("gradients: float")
.Input("inputs: float")
.Input("min: float")
@ -5005,6 +5012,7 @@ gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
min, max: Quantization interval, scalar floats.
num_bits: The bitwidth of the quantization; between 2 and 8, inclusive.
narrow_range: Whether to quantize into 2^num_bits - 1 distinct values.
backprops_wrt_input: Backpropagated gradients w.r.t. inputs:
`gradients * (inputs >= min && inputs <= max)`.
backprop_wrt_min: Backpropagated gradients w.r.t. min parameter:
@ -5015,6 +5023,7 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter:
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
@ -5038,16 +5047,19 @@ Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.
[min; max] is the clamping range for the 'inputs' data in the corresponding
depth channel. Op divides this range into 255 steps (total of 256 values), then
replaces each 'inputs' value with the closest of the quantized step values.
'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive.
`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
This operation has a gradient and thus allows for training `min` and `max` values.
This operation has a gradient and thus allows for training `min` and `max`
values.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
.Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("gradients: float")
.Input("inputs: float")
.Input("min: float")
@ -5082,6 +5094,7 @@ inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape
same as `gradients`.
min, max: Quantization interval, floats of shape `[d]`.
num_bits: The bitwidth of the quantization; between 2 and 8, inclusive.
narrow_range: Whether to quantize into 2^num_bits - 1 distinct values.
backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
`inputs`:
`gradients * (inputs >= min && inputs <= max)`.


@ -1887,22 +1887,36 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
def _FakeQuantWithMinMaxArgsGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxArgs op."""
return fake_quant_with_min_max_args_gradient(
grad, op.inputs[0], min=op.get_attr("min"), max=op.get_attr("max"))
grad,
op.inputs[0],
min=op.get_attr("min"),
max=op.get_attr("max"),
num_bits=op.get_attr("num_bits"),
narrow_range=op.get_attr("narrow_range"))
@ops.RegisterGradient("FakeQuantWithMinMaxVars")
def _FakeQuantWithMinMaxVarsGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxVars op."""
return fake_quant_with_min_max_vars_gradient(grad, op.inputs[0], op.inputs[1],
op.inputs[2])
return fake_quant_with_min_max_vars_gradient(
grad,
op.inputs[0],
op.inputs[1],
op.inputs[2],
num_bits=op.get_attr("num_bits"),
narrow_range=op.get_attr("narrow_range"))
@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
return fake_quant_with_min_max_vars_per_channel_gradient(grad, op.inputs[0],
op.inputs[1],
op.inputs[2])
return fake_quant_with_min_max_vars_per_channel_gradient(
grad,
op.inputs[0],
op.inputs[1],
op.inputs[2],
num_bits=op.get_attr("num_bits"),
narrow_range=op.get_attr("narrow_range"))
def required_space_to_batch_paddings(input_shape,


@ -898,27 +898,27 @@ tf_module {
}
member_method {
name: "fake_quant_with_min_max_args"
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "fake_quant_with_min_max_args_gradient"
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "fake_quant_with_min_max_vars"
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "fake_quant_with_min_max_vars_gradient"
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "fake_quant_with_min_max_vars_per_channel"
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "fake_quant_with_min_max_vars_per_channel_gradient"
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "fft"