#include "utility_dnnlowp_ops.h" namespace caffe2 { template GatherDNNLowPOp::GatherDNNLowPOp( const OperatorDef& operator_def, Workspace* ws) : GatherOp(operator_def, ws), qfactory_(dnnlowp::GetQuantizationFactoryOf(this)) {} template GatherDNNLowPOp::~GatherDNNLowPOp() { if (measure_quantization_error_) { dnnlowp::ReportQuantizationError(this, quantization_error_stats_); } } template bool GatherDNNLowPOp::RunOnDevice() { using namespace dnnlowp; if (!arguments_parsed_) { dnnlowp::ParseDNNLowPOperatorArguments( this, &dequantize_output_, &measure_quantization_error_); arguments_parsed_ = true; } if (!InputIsType(DATA)) { if (dequantize_output_) { return GatherOp::RunOnDevice(); } else { // If input or output is float, delegate to fp32 op Fp32Op_()->DequantizeInput(); // dequantize input if it's not already float if (!Fp32Op_()->Get()->RunOnDevice()) { return false; } int8::Int8TensorCPU* output = Outputs()[0]->template GetMutable(); output->t.ResizeLike(*Fp32Op_()->Get()->Output(0)); T* out_data = output->t.template mutable_data(); TensorQuantizationParams out_qparams; if (HasStaticQuantization(this)) { out_qparams = GetStaticQuantizationParamsOf(this, 0); } else { out_qparams = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get()); } fbgemm::Quantize( static_cast(Fp32Op_()->Get()->Output(0)->raw_data()), out_data, output->t.numel(), out_qparams); PropagateOutputTensorQuantizationParams(this, 0, out_qparams); } } else { DispatchHelper>::call(this, Input(INDICES)); TensorQuantizationParams in_qparams = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get()); PropagateOutputTensorQuantizationParams(this, 0, in_qparams); } return true; } REGISTER_CPU_OPERATOR_WITH_ENGINE(Gather, DNNLOWP, GatherDNNLowPOp); REGISTER_CPU_OPERATOR_WITH_ENGINE( Int8Gather, DNNLOWP, GatherDNNLowPOp); } // namespace caffe2