pytorch/caffe2/operators/gather_op.cc
Nikita Shulga a9b0a921d5 Disable avoid-non-const-global-variables lint check (#62008)
Summary:
The GoogleTest `TEST` macro and `DEFINE_DISPATCH` are both non-compliant with this check.

All changes but the ones to `.clang-tidy` are generated using following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`;  do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008

Reviewed By: driazati, r-barnes

Differential Revision: D29838584

Pulled By: malfet

fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
2021-07-22 18:04:40 -07:00

152 lines
4.2 KiB
C++

#include "gather_op.h"
namespace caffe2 {
// Register GatherOp<CPUContext> as the CPU implementation of the "Gather" op.
REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
// Schema for Gather: two inputs (DATA, INDICES) and one output. Shape
// inference delegates to gather_helper::calc_output_shape_vector using the
// "axis" and "match_outer" arguments; the output keeps DATA's element type.
OPERATOR_SCHEMA(Gather)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
The *Gather* op accepts a *DATA* tensor of rank $r >= 1$ and *INDICES* tensor of rank $q$ as inputs. It then gathers entries of *DATA* along dimension *axis* (the outer-most dimension by default), indexed by *INDICES*, and concatenates them in an output tensor of rank $q + (r - 1)$.
Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.cc
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.h
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"Gather",
["DATA", "INDICES"],
["OUTPUT"]
)
data = np.array([[1., 1.2],[2.3, 3.4],[4.5, 5.7]])
print("DATA:\n",data)
inds = np.array([[0, 1],[1, 2]])
print("INDICES:\n",inds)
# Feed DATA and INDICES into the workspace
workspace.FeedBlob("DATA", data.astype(np.float32))
workspace.FeedBlob("INDICES", inds.astype(np.int32))
workspace.RunOperatorOnce(op)
print("OUTPUT:\n", workspace.FetchBlob("OUTPUT"))
```
**Result**
```
DATA:
[[1. 1.2]
[2.3 3.4]
[4.5 5.7]]
INDICES:
[[0 1]
[1 2]]
OUTPUT:
[[[1. 1.2]
[2.3 3.4]]
[[2.3 3.4]
[4.5 5.7]]]
```
</details>
)DOC")
    .Arg(
        "axis",
        "*(type: int; default: 0)* Dimension of *DATA* along which entries are gathered.")
    .Arg(
        "dense_gradient",
        "*(type: bool; default: False)* When *axis* is 0, produce a dense gradient instead of a sparse one. For any other *axis* the gradient is always dense and setting this to False is an error.")
    .Input(0, "DATA", "Input data tensor of rank $r>=1$")
    .Input(
        1,
        "INDICES",
        "Input indices tensor of rank $q$. This tensor must contain integers.")
    .Output(0, "OUTPUT", "Output tensor of rank $q+(r-1)$")
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      ArgumentHelper helper(def);
      const int axis = helper.GetSingleArgument<int>("axis", 0);
      const bool match_outer =
          helper.GetSingleArgument<bool>("match_outer", false);
      const auto& data_dims = GetDimsVector(in[0]);
      const auto& indices_dims = GetDimsVector(in[1]);
      // The same helper computes the output shape for both shape inference
      // and (presumably) the op itself, so the two stay consistent.
      vector<int> output_dims =
          caffe2::gather_helper::calc_output_shape_vector<int>(
              data_dims, indices_dims, axis, match_outer);
      vector<TensorShape> out(1);
      // Output element type follows DATA, not INDICES.
      out[0] = CreateTensorShape(output_dims, in[0].data_type());
      return out;
    })
    .InheritOnnxSchema();
// Gradient maker for Gather.
//
// - axis == 0: emits SparseToDense when "dense_gradient" is true; otherwise
//   registers a sparse gradient (INDICES, GO(0)) via SetSparse and emits no
//   ops.
// - axis != 0: always emits a dense BatchGatherGradient. An explicit
//   dense_gradient=false is rejected with CAFFE_ENFORCE.
class GetGatherGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper argsHelper(def_);
    const bool dense_gradient =
        argsHelper.GetSingleArgument<bool>("dense_gradient", false);
    const int axis = argsHelper.GetSingleArgument<int>("axis", 0);

    // TBD: While it hasn't been used yet, we need to add wrap_indices support
    // to gradients next.
    // if (argsHelper.HasArgument("wrap_indices_")) {
    // }

    using Op = GatherOp<CPUContext>;

    if (axis == 0) {
      if (dense_gradient) {
        return vector<OperatorDef>{CreateOperatorDef(
            "SparseToDense",
            "",
            vector<string>{I(Op::INDICES), GO(0), I(Op::DATA)},
            vector<string>{GI(Op::DATA)})};
      } else {
        // For now we don't do any reshaping as the consumer of this op would
        // probably be ScatterUpdate, which intentionally ignores shapes. We
        // might need to revisit it in the future for correctness purposes. The
        // right shape for the output would be to flatten INDICES and collapse
        // the first X dims of GRAD.
        SetSparse(Op::DATA, I(Op::INDICES), GO(0));
        return vector<OperatorDef>();
      }
    }

    // TBD: It is misleading that the gradient is sparse by default for axis 0
    // but always dense otherwise.
    if (argsHelper.HasArgument("dense_gradient")) {
      CAFFE_ENFORCE(
          dense_gradient == true,
          "Gather with axis > 0 must use dense_gradient");
    }

    std::vector<Argument> gradArgs{MakeArgument<int>("axis", axis)};
    // "match_outer" changes the layout of Gather's output (it is passed to
    // calc_output_shape_vector in the schema), so the gradient op must see the
    // same value to interpret GO(0) correctly. Forward it only when explicitly
    // set so gradient defs for default Gather ops are unchanged.
    // NOTE(review): assumes BatchGatherGradient reads a "match_outer"
    // argument; confirm in batch_gather_ops.h.
    if (argsHelper.HasArgument("match_outer")) {
      gradArgs.push_back(MakeArgument<bool>(
          "match_outer",
          argsHelper.GetSingleArgument<bool>("match_outer", false)));
    }

    return SingleGradientDef(
        "BatchGatherGradient",
        "",
        // This is the input order expected by BatchGatherGradient, which
        // differs from the SparseToDense order above.
        vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
        vector<string>{GI(Op::DATA)},
        gradArgs);
  }
};
// Use GetGatherGradient to build the gradient ops for "Gather".
REGISTER_GRADIENT(Gather, GetGatherGradient);
} // namespace caffe2