mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Because the GoogleTest `TEST` macro, as well as `DEFINE_DISPATCH`, is non-compliant with the `cppcoreguidelines-avoid-non-const-global-variables` check, that check is disabled in `.clang-tidy` and the corresponding `// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)` suppressions are removed. All changes except those to `.clang-tidy` were generated using the following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
66 lines
2.0 KiB
C++
66 lines
2.0 KiB
C++
#include "caffe2/operators/batch_gather_ops.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// CPU implementations: the forward BatchGather kernel and the kernel used
// for its backward pass (BatchGatherGradient).
REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
|
|
|
|
// Schema for BatchGather: inputs DATA and INDICES, one OUTPUT.
//
// Shape inference delegates to gather_helper::calc_output_shape_vector with
// axis = 1 and match_outer = false. The output element type is taken from
// DATA, because the gathered values are DATA's own elements.
OPERATOR_SCHEMA(BatchGather)
    .NumInputs(2)
    .NumOutputs(1)
    .TensorInferenceFunction([](const OperatorDef& /* def (unused) */,
                                const vector<TensorShape>& in) {
      vector<TensorShape> out(1);
      // Previously hard-coded to FLOAT, which mis-typed the output whenever
      // DATA was not a float tensor.
      out[0].set_data_type(in[0].data_type());

      const auto& data_dims = GetDimsVector(in[0]);
      const auto& indices_dims = GetDimsVector(in[1]);

      // Gathering along axis 1 needs DATA of rank >= 2 (see the doc below).
      // If either input shape is unknown, or DATA has too few dims, the
      // output shape cannot be computed. Report it as unknown instead of
      // handing the helper dims it cannot index at axis 1.
      if (in[0].unknown_shape() || in[1].unknown_shape() ||
          data_dims.size() < 2) {
        out[0].set_unknown_shape(true);
        return out;
      }

      vector<int> output_dims =
          caffe2::gather_helper::calc_output_shape_vector<int>(
              data_dims, indices_dims, 1, false);
      out[0] = CreateTensorShape(output_dims, in[0].data_type());
      return out;
    })
    .SetDoc(R"DOC(
Batch gather operation, first dimension in DATA is the batch size.
Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather
entries of the second outer dimension (axis == 1) of DATA indexed by INDICES,
and concatenate them in an output tensor of rank q + (r - 1).

Example:
DATA = [
[1.0, 1.2, 2.4, 4.5],
[2.3, 3.4, 3.6, 2.3],
[4.5, 5.7, 1.2, 4.5],
]
INDICES = [0, 2]

OUTPUT = [
[1.0, 2.4],
[2.3, 3.6],
[4.5, 1.2],
]
)DOC")
    .Input(0, "DATA", "Tensor of rank r >= 2.")
    .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
    .Output(0, "OUTPUT", "Tensor of rank q + (r - 1).")
    .InheritOnnxSchema();
|
|
|
|
// Schema for the backward op emitted by GetBatchGatherGradient.
// It has three inputs (forward DATA, forward INDICES, and the gradient of
// OUTPUT) and one output (the gradient of DATA).
OPERATOR_SCHEMA(BatchGatherGradient)
    .NumInputs(3)
    .NumOutputs(1);
|
|
|
|
class GetBatchGatherGradient : public GradientMakerBase {
|
|
using GradientMakerBase::GradientMakerBase;
|
|
vector<OperatorDef> GetGradientDefs() override {
|
|
using Op = BatchGatherOp<CPUContext>;
|
|
return SingleGradientDef(
|
|
"BatchGatherGradient",
|
|
"",
|
|
vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
|
|
vector<string>{GI(0)});
|
|
}
|
|
};
|
|
|
|
REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient);
|
|
|
|
} // namespace caffe2
|