Add a new attribute narrow_range to the FakeQuant* operations. When set, values are quantized into the range [1; 255] instead of [0; 255] (with the default num_bits=8).
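
For illustration, a minimal usage sketch of the new attribute through the Python API updated below (TF 1.x graph construction; tensors are only evaluated when the graph is run):

    import tensorflow as tf

    x = tf.constant([-6.5, -3.0, 0.0, 2.5, 6.5])
    # Default behaviour: quantize onto [0; 255] (256 distinct values).
    y_full = tf.fake_quant_with_min_max_args(x, min=-6.0, max=6.0, num_bits=8)
    # narrow_range=True: quantize onto [1; 255] (255 distinct values).
    y_narrow = tf.fake_quant_with_min_max_args(
        x, min=-6.0, max=6.0, num_bits=8, narrow_range=True)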

PiperOrigin-RevId: 157641054
A. Unique TensorFlower 2017-05-31 15:06:59 -07:00 committed by TensorFlower Gardener
parent c048e2938c
commit 41b87d6ceb
6 changed files with 1528 additions and 303 deletions


@@ -60,25 +60,29 @@ class FakeQuantWithMinMaxArgsOp
       : Base::UnaryElementWiseOp(context) {
     OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
     OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
-    int num_bits;
-    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, min_ < max_,
                 InvalidArgument("min has to be smaller than max, was: ", min_,
                                 " >= ", max_));
+    int num_bits;
+    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, IsNumBitsValid(num_bits),
                 InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
     FakeQuantWithMinMaxArgsFunctor<Device> functor;
     functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
-            steps_, output->flat<float>());
+            quant_min_, quant_max_, output->flat<float>());
   }
 
  private:
   float min_;
   float max_;
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 // Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in
@@ -94,14 +98,17 @@ class FakeQuantWithMinMaxArgsGradientOp
       : Base::BinaryElementWiseOp(context) {
     OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
     OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
-    int num_bits;
-    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, min_ < max_,
                 InvalidArgument("min has to be smaller than max, was: ", min_,
                                 " >= ", max_));
+    int num_bits;
+    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(context, IsNumBitsValid(num_bits),
                InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   template <int NDIMS>
@@ -116,12 +123,14 @@ class FakeQuantWithMinMaxArgsGradientOp
                 InvalidArgument("gradient and input must be the same size"));
     FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
     functor(context->eigen_device<Device>(), gradient.flat<float>(),
-            input.flat<float>(), min_, max_, steps_, output->flat<float>());
+            input.flat<float>(), min_, max_, quant_min_, quant_max_,
+            output->flat<float>());
   }
 
  private:
   float min_;
   float max_;
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
@@ -136,8 +145,9 @@ typedef Eigen::GpuDevice GPUDevice;
 // Forward declarations for functor specializations for GPU.
 template <>
 void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
-    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs, float min,
-    float max, int steps, typename TTypes<float>::Flat outputs);
+    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
+    const float min, const float max, const int quant_min, const int quant_max,
+    typename TTypes<float>::Flat outputs);
 extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
                         FakeQuantWithMinMaxArgsOp<GPUDevice>);
@@ -145,7 +155,8 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
 template <>
 void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
-    typename TTypes<float>::ConstFlat inputs, float min, float max, int steps,
+    typename TTypes<float>::ConstFlat inputs, const float min, const float max,
+    const int quant_min, const int quant_max,
     typename TTypes<float>::Flat backprops);
 REGISTER_KERNEL_BUILDER(
     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
@@ -164,7 +175,10 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, IsNumBitsValid(num_bits),
                 InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -179,12 +193,13 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel {
     FakeQuantWithMinMaxVarsFunctor<Device> functor;
     functor(context->eigen_device<Device>(), input.flat<float>(),
-            min.scalar<float>(), max.scalar<float>(), steps_,
+            min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
             output->flat<float>());
   }
 
  private:
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 // Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in
@@ -198,7 +213,10 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, IsNumBitsValid(num_bits),
                 InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -226,13 +244,13 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
     FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
     functor(context->eigen_device<Device>(), gradient.flat<float>(),
             input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
-            steps_,
-            grad_wrt_input->flat<float>(), grad_wrt_min->scalar<float>(),
-            grad_wrt_max->scalar<float>());
+            quant_min_, quant_max_, grad_wrt_input->flat<float>(),
+            grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
   }
 
  private:
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
@@ -246,8 +264,8 @@ template <>
 void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstScalar min,
-    typename TTypes<float>::ConstScalar max, int steps,
-    typename TTypes<float>::Flat output);
+    typename TTypes<float>::ConstScalar max, const int quant_min,
+    const int quant_max, typename TTypes<float>::Flat output);
 extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                             .Device(DEVICE_GPU)
@@ -260,8 +278,8 @@ void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
     typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstScalar min,
-    typename TTypes<float>::ConstScalar max, int steps,
-    typename TTypes<float>::Flat backprops_wrt_input,
+    typename TTypes<float>::ConstScalar max, const int quant_min,
+    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
     typename TTypes<float>::Scalar backprop_wrt_min,
     typename TTypes<float>::Scalar backprop_wrt_max);
 extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
@@ -284,7 +302,10 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, IsNumBitsValid(num_bits),
                 InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -309,22 +330,22 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
         FakeQuant4WithMinMaxVarsPerChannelFunctor<Device> functor;
         functor(context->eigen_device<Device>(), input.dim_size(0),
                 input.dim_size(1), input.dim_size(2), input.dim_size(3),
-                input.flat<float>(), min.vec<float>(), max.vec<float>(), steps_,
-                output->flat<float>());
+                input.flat<float>(), min.vec<float>(), max.vec<float>(),
+                quant_min_, quant_max_, output->flat<float>());
         break;
       }
       case 2: {
         FakeQuant2WithMinMaxVarsPerChannelFunctor<Device> functor;
         functor(context->eigen_device<Device>(), input.dim_size(0),
                 input.dim_size(1), input.flat<float>(), min.vec<float>(),
-                max.vec<float>(), steps_,
+                max.vec<float>(), quant_min_, quant_max_,
                 output->flat<float>());
         break;
       }
      case 1: {
         FakeQuant1WithMinMaxVarsPerChannelFunctor<Device> functor;
         functor(context->eigen_device<Device>(), input.vec<float>(),
-                min.vec<float>(), max.vec<float>(), steps_,
+                min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
                 output->vec<float>());
         break;
       }
@@ -336,7 +357,8 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
   }
 
  private:
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 // Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
@@ -350,7 +372,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     OP_REQUIRES(context, IsNumBitsValid(num_bits),
                 InvalidArgument("num_bits must be between 2 and 8, inclusive"));
-    steps_ = (1 << num_bits) - 1;
+    bool narrow_range;
+    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
+    quant_min_ = narrow_range ? 1 : 0;
+    quant_max_ = (1 << num_bits) - 1;
   }
 
   void Compute(OpKernelContext* context) override {
@@ -388,7 +413,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
         functor(context->eigen_device<Device>(), input.dim_size(0),
                 input.dim_size(1), input.dim_size(2), input.dim_size(3),
                 gradient.flat<float>(), input.flat<float>(), min.vec<float>(),
-                max.vec<float>(), steps_,
+                max.vec<float>(), quant_min_, quant_max_,
                 grad_wrt_input->flat<float>(), grad_wrt_min->vec<float>(),
                 grad_wrt_max->vec<float>());
         break;
@@ -397,7 +422,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
         FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
         functor(context->eigen_device<Device>(), input.dim_size(0),
                 input.dim_size(1), gradient.flat<float>(), input.flat<float>(),
-                min.vec<float>(), max.vec<float>(), steps_,
+                min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
                 grad_wrt_input->flat<float>(), grad_wrt_min->vec<float>(),
                 grad_wrt_max->vec<float>());
         break;
@@ -405,9 +430,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
       case 1: {
         FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
         functor(context->eigen_device<Device>(), gradient.vec<float>(),
-                input.vec<float>(), min.vec<float>(), max.vec<float>(), steps_,
-                grad_wrt_input->vec<float>(), grad_wrt_min->vec<float>(),
-                grad_wrt_max->vec<float>());
+                input.vec<float>(), min.vec<float>(), max.vec<float>(),
+                quant_min_, quant_max_, grad_wrt_input->vec<float>(),
+                grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
         break;
       }
       default:
@@ -418,7 +443,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
   }
 
  private:
-  int steps_;
+  int quant_min_;
+  int quant_max_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
@@ -433,7 +459,7 @@ template <>
 void FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstVec inputs,
     typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
-    int steps,
+    const int quant_min, const int quant_max,
     typename TTypes<float>::Vec outputs);
 extern template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;
@@ -442,8 +468,8 @@ void FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
     const GPUDevice& d, const Index batch_size, const Index depth,
     typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstFlat min,
-    typename TTypes<float>::ConstFlat max, int steps,
-    typename TTypes<float>::Flat outputs);
+    typename TTypes<float>::ConstFlat max, const int quant_min,
+    const int quant_max, typename TTypes<float>::Flat outputs);
 extern template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;
 
 template <>
@@ -452,8 +478,8 @@ void FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
     const Index width, const Index depth,
     typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstFlat min,
-    typename TTypes<float>::ConstFlat max, int steps,
-    typename TTypes<float>::Flat outputs);
+    typename TTypes<float>::ConstFlat max, const int quant_min,
+    const int quant_max, typename TTypes<float>::Flat outputs);
 extern template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;
 
 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
@@ -467,7 +493,7 @@ void FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
     const GPUDevice& d, typename TTypes<float>::ConstVec gradients,
     typename TTypes<float>::ConstVec inputs,
     typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
-    int steps,
+    const int quant_min, const int quant_max,
     typename TTypes<float>::Vec backprops_wrt_input,
     typename TTypes<float>::Vec backprop_wrt_min,
     typename TTypes<float>::Vec backprop_wrt_max);
@@ -480,7 +506,7 @@ void FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
     typename TTypes<float>::ConstFlat gradients,
     typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
-    int steps,
+    const int quant_min, const int quant_max,
     typename TTypes<float>::Flat backprops_wrt_input,
     typename TTypes<float>::Vec backprop_wrt_min,
     typename TTypes<float>::Vec backprop_wrt_max);
@@ -494,7 +520,7 @@ void FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
     typename TTypes<float>::ConstFlat gradients,
     typename TTypes<float>::ConstFlat inputs,
     typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
-    int steps,
+    const int quant_min, const int quant_max,
     typename TTypes<float>::Flat backprops_wrt_input,
     typename TTypes<float>::Vec backprop_wrt_min,
     typename TTypes<float>::Vec backprop_wrt_max);


@@ -37,25 +37,27 @@ namespace tensorflow {
 // Gymnastics with nudged zero point is to ensure that real zero maps to
 // an integer, which is required for e.g. zero-padding in convolutional layers.
-// Returns (nudged_min, nudged_max, nudged_scale).
+// Outputs nudged_min, nudged_max, nudged_scale.
 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge(
-    const float min, const float max, const int steps, float* nudged_min,
-    float* nudged_max, float* scale) {
-  const float steps_float = static_cast<float>(steps);
-  *scale = (max - min) / (steps_float - 0.0f);
-  const float zero_point_from_min = 0.0f - min / *scale;
-  const uint8 nudged_zero_point = [zero_point_from_min, steps, steps_float] {
-    if (zero_point_from_min < 0.0f) {
-      return static_cast<uint8>(0);
+    const float min, const float max, const int quant_min, const int quant_max,
+    float* nudged_min, float* nudged_max, float* scale) {
+  const float quant_min_float = static_cast<float>(quant_min);
+  const float quant_max_float = static_cast<float>(quant_max);
+  *scale = (max - min) / (quant_max_float - quant_min_float);
+  const float zero_point_from_min = quant_min_float - min / *scale;
+  const uint8 nudged_zero_point = [zero_point_from_min, quant_min,
+                                   quant_min_float, quant_max,
+                                   quant_max_float] {
+    if (zero_point_from_min < quant_min_float) {
+      return static_cast<uint8>(quant_min);
     }
-    if (zero_point_from_min > steps_float) {
-      return static_cast<uint8>(steps);
+    if (zero_point_from_min > quant_max_float) {
+      return static_cast<uint8>(quant_max);
     }
     return static_cast<uint8>(StdRound(zero_point_from_min));
   }();
-
-  *nudged_min = (0.0f - nudged_zero_point) * (*scale);
-  *nudged_max = (steps_float - nudged_zero_point) * (*scale);
+  *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
+  *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
 }
 
 template <typename T>
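
For intuition, the nudging above can be read as the following rough Python transcription (a sketch of the same arithmetic, not the kernel itself); quant_min/quant_max are the integer endpoints, e.g. 0 (or 1 with narrow_range) and 255 for num_bits=8:

    def nudge(min_val, max_val, quant_min, quant_max):
        # One integer step corresponds to this float increment of [min; max].
        scale = (max_val - min_val) / float(quant_max - quant_min)
        # Integer position that real 0.0 would map to, before clamping/rounding.
        zero_point_from_min = quant_min - min_val / scale
        if zero_point_from_min < quant_min:
            nudged_zero_point = quant_min
        elif zero_point_from_min > quant_max:
            nudged_zero_point = quant_max
        else:
            nudged_zero_point = int(round(zero_point_from_min))
        # Shift the representable float range so real 0.0 sits exactly on a step.
        nudged_min = (quant_min - nudged_zero_point) * scale
        nudged_max = (quant_max - nudged_zero_point) * scale
        return nudged_min, nudged_max, scale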
@@ -76,13 +78,15 @@ using Flat = typename tensorflow::TTypes<T>::Flat;
 template <typename Device>
 struct FakeQuantWithMinMaxArgsFunctor {
   void operator()(const Device& d, ConstFlat<float> inputs, const float min,
-                  const float max, const int steps, Flat<float> outputs) {
+                  const float max, const int quant_min, const int quant_max,
+                  Flat<float> outputs) {
     eigen_assert(min <= 0.0f && "min should be <= 0.0");
     eigen_assert(max >= 0.0f && "max should be >= 0.0");
     eigen_assert(min < max && "min should be < max");
 
     float nudged_min, nudged_max, nudged_scale;
-    Nudge(min, max, steps, &nudged_min, &nudged_max, &nudged_scale);
+    Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max,
+          &nudged_scale);
     const float inv_nudged_scale = 1.0f / nudged_scale;
 
     auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
@@ -99,13 +103,15 @@ template <typename Device>
 struct FakeQuantWithMinMaxArgsGradientFunctor {
   void operator()(const Device& d, ConstFlat<float> gradients,
                   ConstFlat<float> inputs, const float min, const float max,
-                  const int steps, Flat<float> backprops) {
+                  const int quant_min, const int quant_max,
+                  Flat<float> backprops) {
     eigen_assert(min <= 0.0f && "min should be <= 0.0");
     eigen_assert(max >= 0.0f && "max should be >= 0.0");
     eigen_assert(min < max && "min should be < max");
 
     float nudged_min, nudged_max, nudged_scale;
-    Nudge(min, max, steps, &nudged_min, &nudged_max, &nudged_scale);
+    Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max,
+          &nudged_scale);
 
     auto between_nudged_min_max =
         (inputs >= nudged_min && inputs <= nudged_max)
@@ -120,10 +126,11 @@ template <typename Device>
 struct FakeQuantWithMinMaxVarsFunctor {
   void operator()(const Device& d, ConstFlat<float> inputs,
                   ConstScalar<float> min, ConstScalar<float> max,
-                  const int steps,
+                  const int quant_min, const int quant_max,
                   Flat<float> outputs) {
     float nudged_min, nudged_max, nudged_scale;
-    Nudge(min(), max(), steps, &nudged_min, &nudged_max, &nudged_scale);
+    Nudge(min(), max(), quant_min, quant_max, &nudged_min, &nudged_max,
+          &nudged_scale);
     const auto nudged_scale_repl = inputs.constant(nudged_scale);
 
     const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
@@ -140,12 +147,13 @@ template <typename Device>
 struct FakeQuantWithMinMaxVarsGradientFunctor {
   void operator()(const Device& d, ConstFlat<float> gradients,
                   ConstFlat<float> inputs, ConstScalar<float> min,
-                  ConstScalar<float> max, const int steps,
-                  Flat<float> backprops_wrt_input,
+                  ConstScalar<float> max, const int quant_min,
+                  const int quant_max, Flat<float> backprops_wrt_input,
                   Scalar<float> backprop_wrt_min,
                   Scalar<float> backprop_wrt_max) {
     float nudged_min, nudged_max, nudged_scale;
-    Nudge(min(), max(), steps, &nudged_min, &nudged_max, &nudged_scale);
+    Nudge(min(), max(), quant_min, quant_max, &nudged_min, &nudged_max,
+          &nudged_scale);
 
     const auto between_min_max =
         (inputs >= nudged_min && inputs <= nudged_max)
@@ -173,11 +181,12 @@ using Index = typename tensorflow::TTypes<float>::ConstTensor::Index;
 template <typename Device>
 struct FakeQuant1WithMinMaxVarsPerChannelFunctor {
   void operator()(const Device& d, ConstVec<float> inputs, ConstVec<float> min,
-                  ConstVec<float> max, const int steps,
+                  ConstVec<float> max, const int quant_min, const int quant_max,
                   Vec<float> outputs) {
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
       const float clamped =
           std::max(std::min(inputs(i), nudged_max), nudged_min);
       const float clamped_shifted = clamped - nudged_min;
@@ -194,13 +203,14 @@ template <typename Device>
 struct FakeQuant2WithMinMaxVarsPerChannelFunctor {
   void operator()(const Device& d, const Index batch_size, const Index depth,
                   ConstFlat<float> inputs, ConstVec<float> min,
-                  ConstVec<float> max, const int steps,
+                  ConstVec<float> max, const int quant_min, const int quant_max,
                   Flat<float> outputs) {
     Eigen::DSizes<Index, 2> restored(batch_size, depth);
     const auto inputs_restored = inputs.reshape(restored);
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
       const auto clamped =
           inputs_restored.chip<1>(i).cwiseMin(nudged_max).cwiseMax(nudged_min);
       const auto clamped_shifted = clamped - nudged_min;
@@ -218,13 +228,14 @@ template <typename Device>
 struct FakeQuant4WithMinMaxVarsPerChannelFunctor {
   void operator()(const Device& d, const Index batch_size, const Index height,
                   const Index width, const Index depth, ConstFlat<float> inputs,
-                  ConstVec<float> min, ConstVec<float> max, const int steps,
-                  Flat<float> outputs) {
+                  ConstVec<float> min, ConstVec<float> max, const int quant_min,
+                  const int quant_max, Flat<float> outputs) {
     Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
     const auto inputs_restored = inputs.reshape(restored);
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
       const auto clamped =
           inputs_restored.chip<3>(i).cwiseMin(nudged_max).cwiseMax(nudged_min);
       const auto clamped_shifted = clamped - nudged_min;
@@ -245,12 +256,13 @@ template <typename Device>
 struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor {
   void operator()(const Device& d, ConstVec<float> gradients,
                   ConstVec<float> inputs, ConstVec<float> min,
-                  ConstVec<float> max, const int steps,
+                  ConstVec<float> max, const int quant_min, const int quant_max,
                   Vec<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
                   Vec<float> backprop_wrt_max) {
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
 
       const bool between_min_max =
           inputs(i) >= nudged_min && inputs(i) <= nudged_max;
@@ -271,15 +283,16 @@ template <typename Device>
 struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor {
   void operator()(const Device& d, const Index batch_size, const Index depth,
                   ConstFlat<float> gradients, ConstFlat<float> inputs,
-                  ConstVec<float> min, ConstVec<float> max, const int steps,
-                  Flat<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
-                  Vec<float> backprop_wrt_max) {
+                  ConstVec<float> min, ConstVec<float> max, const int quant_min,
+                  const int quant_max, Flat<float> backprops_wrt_input,
+                  Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
     Eigen::DSizes<Index, 2> restored(batch_size, depth);
     const auto gradients_restored = gradients.reshape(restored);
     const auto inputs_restored = inputs.reshape(restored);
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
 
       const auto gradients_chip = gradients_restored.chip<1>(i);
       const auto inputs_chip = inputs_restored.chip<1>(i);
@@ -312,15 +325,16 @@ struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor {
   void operator()(const Device& d, const Index batch_size, const Index height,
                   const Index width, const Index depth,
                   ConstFlat<float> gradients, ConstFlat<float> inputs,
-                  ConstVec<float> min, ConstVec<float> max, const int steps,
-                  Flat<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
-                  Vec<float> backprop_wrt_max) {
+                  ConstVec<float> min, ConstVec<float> max, const int quant_min,
+                  const int quant_max, Flat<float> backprops_wrt_input,
+                  Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
     Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
     const auto gradients_restored = gradients.reshape(restored);
     const auto inputs_restored = inputs.reshape(restored);
     for (Index i = 0; i < min.size(); ++i) {
       float nudged_min, nudged_max, nudged_scale;
-      Nudge(min(i), max(i), steps, &nudged_min, &nudged_max, &nudged_scale);
+      Nudge(min(i), max(i), quant_min, quant_max, &nudged_min, &nudged_max,
+            &nudged_scale);
 
       const auto gradients_chip = gradients_restored.chip<3>(i);
       const auto inputs_chip = inputs_restored.chip<3>(i);

File diff suppressed because it is too large.


@ -4918,16 +4918,18 @@ REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0") .Attr("min: float = -6.0")
.Attr("max: float = 6.0") .Attr("max: float = 6.0")
.Attr("num_bits: int = 8") .Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("inputs: float") .Input("inputs: float")
.Output("outputs: float") .Output("outputs: float")
.SetShapeFn(shape_inference::UnchangedShape) .SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc( .Doc(R"doc(
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
Attributes [min; max] define the clamping range for the 'inputs' data. Op Attributes `[min; max]` define the clamping range for the `inputs` data.
divides this range into 255 steps (total of 256 values), then replaces each `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
'inputs' value with the closest of the quantized step values. when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
Quantization is called fake since the output is still in floating point. Quantization is called fake since the output is still in floating point.
)doc"); )doc");
@ -4936,6 +4938,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
.Attr("min: float = -6.0") .Attr("min: float = -6.0")
.Attr("max: float = 6.0") .Attr("max: float = 6.0")
.Attr("num_bits: int = 8") .Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("gradients: float") .Input("gradients: float")
.Input("inputs: float") .Input("inputs: float")
.Output("backprops: float") .Output("backprops: float")
@@ -4951,6 +4954,7 @@ backprops: Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
 
 REGISTER_OP("FakeQuantWithMinMaxVars")
     .Attr("num_bits: int = 8")
+    .Attr("narrow_range: bool = false")
     .Input("inputs: float")
     .Input("min: float")
     .Input("max: float")
@ -4966,16 +4970,19 @@ REGISTER_OP("FakeQuantWithMinMaxVars")
Fake-quantize the 'inputs' tensor of type float via global float scalars `min` Fake-quantize the 'inputs' tensor of type float via global float scalars `min`
and `max` to 'outputs' tensor of same shape as `inputs`. and `max` to 'outputs' tensor of same shape as `inputs`.
[min; max] is the clamping range for the 'inputs' data. Op divides this range `[min; max]` define the clamping range for the `inputs` data.
into 255 steps (total of 256 values), then replaces each 'inputs' value with the `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
closest of the quantized step values. when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive. then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
This operation has a gradient and thus allows for training `min` and `max` values. This operation has a gradient and thus allows for training `min` and `max`
values.
)doc"); )doc");
REGISTER_OP("FakeQuantWithMinMaxVarsGradient") REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
.Attr("num_bits: int = 8") .Attr("num_bits: int = 8")
.Attr("narrow_range: bool = false")
.Input("gradients: float") .Input("gradients: float")
.Input("inputs: float") .Input("inputs: float")
.Input("min: float") .Input("min: float")
@@ -5005,6 +5012,7 @@ gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
 inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
 min, max: Quantization interval, scalar floats.
 num_bits: The bitwidth of the quantization; between 2 and 8, inclusive.
+narrow_range: Whether to quantize into 2^num_bits - 1 distinct values.
 backprops_wrt_input: Backpropagated gradients w.r.t. inputs:
   `gradients * (inputs >= min && inputs <= max)`.
 backprop_wrt_min: Backpropagated gradients w.r.t. min parameter:
@@ -5015,6 +5023,7 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter:
 
 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
     .Attr("num_bits: int = 8")
+    .Attr("narrow_range: bool = false")
     .Input("inputs: float")
     .Input("min: float")
     .Input("max: float")
@@ -5038,16 +5047,19 @@ Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
 `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
 to 'outputs' tensor of same shape as `inputs`.
 
-[min; max] is the clamping range for the 'inputs' data in the corresponding
-depth channel. Op divides this range into 255 steps (total of 256 values), then
-replaces each 'inputs' value with the closest of the quantized step values.
-'num_bits' is the bitwidth of the quantization; between 2 and 8, inclusive.
+`[min; max]` define the clamping range for the `inputs` data.
+`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
+when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
+then de-quantized and output as floats in `[min; max]` interval.
+`num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
 
-This operation has a gradient and thus allows for training `min` and `max` values.
+This operation has a gradient and thus allows for training `min` and `max`
+values.
 )doc");
 
 REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
     .Attr("num_bits: int = 8")
+    .Attr("narrow_range: bool = false")
     .Input("gradients: float")
     .Input("inputs: float")
     .Input("min: float")
@@ -5082,6 +5094,7 @@ inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape
   same as `gradients`.
 min, max: Quantization interval, floats of shape `[d]`.
 num_bits: The bitwidth of the quantization; between 2 and 8, inclusive.
+narrow_range: Whether to quantize into 2^num_bits - 1 distinct values.
 backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
   `inputs`:
   `gradients * (inputs >= min && inputs <= max)`.


@@ -1887,22 +1887,36 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
 def _FakeQuantWithMinMaxArgsGradient(op, grad):
   """Gradient for FakeQuantWithMinMaxArgs op."""
   return fake_quant_with_min_max_args_gradient(
-      grad, op.inputs[0], min=op.get_attr("min"), max=op.get_attr("max"))
+      grad,
+      op.inputs[0],
+      min=op.get_attr("min"),
+      max=op.get_attr("max"),
+      num_bits=op.get_attr("num_bits"),
+      narrow_range=op.get_attr("narrow_range"))
 
 
 @ops.RegisterGradient("FakeQuantWithMinMaxVars")
 def _FakeQuantWithMinMaxVarsGradient(op, grad):
   """Gradient for FakeQuantWithMinMaxVars op."""
-  return fake_quant_with_min_max_vars_gradient(grad, op.inputs[0], op.inputs[1],
-                                               op.inputs[2])
+  return fake_quant_with_min_max_vars_gradient(
+      grad,
+      op.inputs[0],
+      op.inputs[1],
+      op.inputs[2],
+      num_bits=op.get_attr("num_bits"),
+      narrow_range=op.get_attr("narrow_range"))
 
 
 @ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
 def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
   """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
-  return fake_quant_with_min_max_vars_per_channel_gradient(grad, op.inputs[0],
-                                                           op.inputs[1],
-                                                           op.inputs[2])
+  return fake_quant_with_min_max_vars_per_channel_gradient(
+      grad,
+      op.inputs[0],
+      op.inputs[1],
+      op.inputs[2],
+      num_bits=op.get_attr("num_bits"),
+      narrow_range=op.get_attr("narrow_range"))
 
 
 def required_space_to_batch_paddings(input_shape,
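
A hedged end-to-end sketch of these gradient registrations in use (TF 1.x graph mode; the exact numbers assume the default clamping range, so treat them as illustrative):

    import tensorflow as tf

    x = tf.constant([0.5, 3.0, 7.0])
    y = tf.fake_quant_with_min_max_args(
        x, min=-6.0, max=6.0, num_bits=8, narrow_range=True)
    # The registered gradient forwards num_bits and narrow_range to the
    # gradient op; inputs outside the (nudged) clamping range get zero gradient.
    dy_dx = tf.gradients(y, x)[0]
    with tf.Session() as sess:
        print(sess.run(dy_dx))  # approximately [1., 1., 0.]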


@@ -898,27 +898,27 @@ tf_module {
   }
   member_method {
     name: "fake_quant_with_min_max_args"
-    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fake_quant_with_min_max_args_gradient"
-    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fake_quant_with_min_max_vars"
-    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fake_quant_with_min_max_vars_gradient"
-    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fake_quant_with_min_max_vars_per_channel"
-    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fake_quant_with_min_max_vars_per_channel_gradient"
-    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "fft"