mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18959 ghimport-source-id: a934163fa34cb2019732d5f49dc7290c376bf156 Differential Revision: D14831246 Pulled By: ezyang fbshipit-source-id: beb92dc4ee8c82f4c8259c081dd72e477fe7a9d0
59 lines
1.7 KiB
C++
59 lines
1.7 KiB
C++
#include "caffe2/operators/expand_op.h"
|
|
|
|
#include <algorithm>
|
|
#include <functional>
|
|
#include <vector>
|
|
|
|
#include <caffe2/utils/math.h>
|
|
|
|
namespace caffe2 {
|
|
|
|
REGISTER_CPU_OPERATOR(
|
|
Expand,
|
|
ExpandOp<
|
|
TensorTypes<std::int32_t, std::int64_t, float, double>,
|
|
CPUContext>);
|
|
|
|
REGISTER_CPU_OPERATOR(
|
|
ExpandGradient,
|
|
ExpandGradientOp<
|
|
TensorTypes<std::int32_t, std::int64_t, float, double>,
|
|
CPUContext>);
|
|
|
|
// Schema for Expand: broadcasts input `X` to the shape given by input `shape`.
// NOTE(review): `shape` is documented below as Tensor<int>; confirm against
// ExpandOp (expand_op.h) which integer element type it actually reads.
OPERATOR_SCHEMA(Expand)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Broadcast the input tensor to a materialized new tensor using the given shape.
The broadcast rule is similar to "numpy.array(input) * numpy.ones(shape)":
dimensions are right-aligned, and two corresponding dimensions must either
have the same value or one of them must equal 1.
In order to align with PyTorch's `expand`, `shape` is allowed to have entries
equal to -1, which means to preserve the size of the corresponding dimension
in `X` (so such an entry is effectively equivalent to 1).
)DOC")
    .Input(0, "X", "(*Tensor`<NumericType>`*): input tensor")
    .Input(1, "shape", "(*Tensor`<int>`*): expand shape")
    .Output(0, "Y", "(*Tensor`<NumericType>`*): expanded tensor");
|
|
|
|
// Schema for ExpandGradient, the backward op emitted by GetExpandGradient.
// Input/output order mirrors that gradient maker: {GO(0), I(0)} -> {GI(0)}.
OPERATOR_SCHEMA(ExpandGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Computes the gradient of `Expand` with respect to its input `X`.
)DOC")
    .Input(
        0,
        "dY",
        "(*Tensor`<NumericType>`*): gradient of the `Expand` output `Y`")
    .Input(1, "X", "(*Tensor`<NumericType>`*): the original `Expand` input")
    .Output(
        0,
        "dX",
        "(*Tensor`<NumericType>`*): gradient with respect to `X`");
|
|
|
|
namespace {
|
|
|
|
class GetExpandGradient final : public GradientMakerBase {
|
|
using GradientMakerBase::GradientMakerBase;
|
|
std::vector<OperatorDef> GetGradientDefs() override {
|
|
return SingleGradientDef(
|
|
"ExpandGradient",
|
|
"",
|
|
std::vector<string>{GO(0), I(0)},
|
|
std::vector<string>{GI(0)});
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
REGISTER_GRADIENT(Expand, GetExpandGradient);
|
|
} // namespace caffe2
|