#include "gather_op.h"
|
|
namespace caffe2 {
|
|
|
|
REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
|
|
|
|
OPERATOR_SCHEMA(Gather)
|
|
.NumInputs(2)
|
|
.NumOutputs(1)
|
|
.SetDoc(R"DOC(
|
|
|
|
The *Gather* op accepts a *DATA* tensor of rank $r >= 1$ and *INDICES* tensor of rank $q$ as inputs. It then gathers entries of the outer-most dimension of *DATA*, indexed by *INDICES*, and concatenate them in an output tensor of rank $q + (r - 1)$.
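For example, gathering a *DATA* tensor of shape $(3, 2)$ (rank $r = 2$) with an *INDICES* tensor of shape $(2, 2)$ (rank $q = 2$), as in the example below, produces an output of shape $(2, 2, 2)$, since $q + (r - 1) = 2 + 1 = 3$.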
Github Links:

- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.cc
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.h

<details>

<summary> <b>Example</b> </summary>

**Code**

```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Gather",
    ["DATA", "INDICES"],
    ["OUTPUT"]
)
data = np.array([[1., 1.2],[2.3, 3.4],[4.5, 5.7]])
print("DATA:\n", data)

inds = np.array([[0, 1],[1, 2]])
print("INDICES:\n", inds)

# Feed DATA and INDICES into the workspace
workspace.FeedBlob("DATA", data.astype(np.float32))
workspace.FeedBlob("INDICES", inds.astype(np.int32))

workspace.RunOperatorOnce(op)
print("OUTPUT:\n", workspace.FetchBlob("OUTPUT"))

```
**Result**

```

DATA:
 [[1. 1.2]
 [2.3 3.4]
 [4.5 5.7]]
INDICES:
 [[0 1]
 [1 2]]
OUTPUT:
 [[[1. 1.2]
  [2.3 3.4]]

 [[2.3 3.4]
  [4.5 5.7]]]

```
</details>

)DOC")
    .Input(0, "DATA", "Input data tensor of rank $r>=1$")
    .Input(
        1,
        "INDICES",
        "Input indices tensor of rank $q$. This tensor must contain integers.")
    .Output(0, "OUTPUT", "Output tensor of rank $q+(r-1)$")
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      ArgumentHelper helper(def);
      const int axis = helper.GetSingleArgument<int>("axis", 0);
      const bool match_outer =
          helper.GetSingleArgument<bool>("match_outer", false);
      const auto& data_dims = GetDimsVector(in[0]);
      const auto& indices_dims = GetDimsVector(in[1]);
      vector<int> output_dims =
          caffe2::gather_helper::calc_output_shape_vector<int>(
              data_dims, indices_dims, axis, match_outer);
      vector<TensorShape> out(1);
      out[0] = CreateTensorShape(output_dims, in[0].data_type());
      return out;
    })
    .InheritOnnxSchema();

class GetGatherGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper argsHelper(def_);
    const bool dense_gradient =
        argsHelper.GetSingleArgument<bool>("dense_gradient", false);
    const int axis = argsHelper.GetSingleArgument<int>("axis", 0);

    // TBD: While it hasn't been used yet, we need to add wrap_indices support
    // to gradients next.
    // if (argsHelper.HasArgument("wrap_indices_")) {
    // }

    using Op = GatherOp<CPUContext>;
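    // Gradient strategy, summarizing the cases below:
    //  - axis == 0 with dense_gradient: densify via a SparseToDense op.
    //  - axis == 0 without dense_gradient (default): register a sparse
    //    gradient for DATA keyed by INDICES; no extra ops are emitted.
    //  - axis > 0: fall back to BatchGatherGradient; if dense_gradient is
    //    given explicitly, it must be true.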
    if (axis == 0) {
      if (dense_gradient) {
        return vector<OperatorDef>{CreateOperatorDef(
            "SparseToDense",
            "",
            vector<string>{I(Op::INDICES), GO(0), I(Op::DATA)},
            vector<string>{GI(Op::DATA)})};
      } else {
        // For now we don't do any reshaping, as the consumer of this op would
        // probably be ScatterUpdate, which intentionally ignores shapes. We
        // might need to revisit this in the future for correctness purposes.
        // The right shape for the output would be to flatten INDICES and
        // collapse the first X dims of GRAD.
        SetSparse(Op::DATA, I(Op::INDICES), GO(0));
        return vector<OperatorDef>();
      }
    }
    // TBD: It is misleading to use dense_gradient by default for axis 0
    // and not otherwise....
    if (argsHelper.HasArgument("dense_gradient")) {
      CAFFE_ENFORCE(
          dense_gradient == true,
          "Gather with axis > 0 must use dense_gradient");
    }

    Argument axisArg = MakeArgument<int>("axis", axis);
    return SingleGradientDef(
        "BatchGatherGradient",
        "",
        // This is the input order expected by BatchGatherGradient, which
        // differs from SparseToDense above.
        vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
        vector<string>{GI(0)},
        std::vector<Argument>{axisArg});
  }
};
REGISTER_GRADIENT(Gather, GetGatherGradient);

} // namespace caffe2