mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] Implement MirrorPad op.
Addresses #11890 * Improves the shape inference error message for concatenate. * Adds a helper to Literal that gets an integral value converted to int64. PiperOrigin-RevId: 163829437
This commit is contained in:
parent
c7b674fa28
commit
3cc5fc0886
|
|
@ -650,6 +650,80 @@ class BinaryOpsTest(XLATestCase):
|
||||||
[0, 0, 0, 0, 0, 0]],
|
[0, 0, 0, 0, 0, 0]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
def testMirrorPad(self):
|
||||||
|
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
|
||||||
|
for dtype in self.numeric_types:
|
||||||
|
self._testBinary(
|
||||||
|
mirror_pad,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[1, 2, 3], #
|
||||||
|
[4, 5, 6], #
|
||||||
|
],
|
||||||
|
dtype=dtype),
|
||||||
|
np.array([[
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
], [2, 2]], dtype=np.int32),
|
||||||
|
expected=np.array(
|
||||||
|
[
|
||||||
|
[6, 5, 4, 5, 6, 5, 4], #
|
||||||
|
[3, 2, 1, 2, 3, 2, 1], #
|
||||||
|
[6, 5, 4, 5, 6, 5, 4], #
|
||||||
|
[3, 2, 1, 2, 3, 2, 1]
|
||||||
|
],
|
||||||
|
dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
mirror_pad,
|
||||||
|
np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
|
||||||
|
np.array([[0, 0], [0, 0]], dtype=np.int32),
|
||||||
|
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
mirror_pad,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[1, 2, 3], #
|
||||||
|
[4, 5, 6], #
|
||||||
|
[7, 8, 9]
|
||||||
|
],
|
||||||
|
dtype=dtype),
|
||||||
|
np.array([[2, 2], [0, 0]], dtype=np.int32),
|
||||||
|
expected=np.array(
|
||||||
|
[
|
||||||
|
[7, 8, 9], #
|
||||||
|
[4, 5, 6], #
|
||||||
|
[1, 2, 3], #
|
||||||
|
[4, 5, 6], #
|
||||||
|
[7, 8, 9], #
|
||||||
|
[4, 5, 6], #
|
||||||
|
[1, 2, 3]
|
||||||
|
],
|
||||||
|
dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
mirror_pad,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[[1, 2, 3], [4, 5, 6]],
|
||||||
|
[[7, 8, 9], [10, 11, 12]],
|
||||||
|
], dtype=dtype),
|
||||||
|
np.array([[0, 0], [1, 1], [1, 1]], dtype=np.int32),
|
||||||
|
expected=np.array(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[5, 4, 5, 6, 5], #
|
||||||
|
[2, 1, 2, 3, 2], #
|
||||||
|
[5, 4, 5, 6, 5], #
|
||||||
|
[2, 1, 2, 3, 2], #
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[11, 10, 11, 12, 11], #
|
||||||
|
[8, 7, 8, 9, 8], #
|
||||||
|
[11, 10, 11, 12, 11], #
|
||||||
|
[8, 7, 8, 9, 8], #
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=dtype))
|
||||||
|
|
||||||
def testReshape(self):
|
def testReshape(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||||
{"Min", "reduction_indices"},
|
{"Min", "reduction_indices"},
|
||||||
{"OneHot", "depth"},
|
{"OneHot", "depth"},
|
||||||
{"Pad", "paddings"},
|
{"Pad", "paddings"},
|
||||||
|
{"MirrorPad", "paddings"},
|
||||||
{"Prod", "reduction_indices"},
|
{"Prod", "reduction_indices"},
|
||||||
{"RandomStandardNormal", "shape"},
|
{"RandomStandardNormal", "shape"},
|
||||||
{"RandomUniform", "shape"},
|
{"RandomUniform", "shape"},
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ tf_kernel_library(
|
||||||
"l2loss_op.cc",
|
"l2loss_op.cc",
|
||||||
"lrn_ops.cc",
|
"lrn_ops.cc",
|
||||||
"matmul_op.cc",
|
"matmul_op.cc",
|
||||||
|
"mirror_pad_op.cc",
|
||||||
"no_op.cc",
|
"no_op.cc",
|
||||||
"one_hot_op.cc",
|
"one_hot_op.cc",
|
||||||
"pack_op.cc",
|
"pack_op.cc",
|
||||||
|
|
|
||||||
98
tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
Normal file
98
tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class MirrorPadOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
|
||||||
|
|
||||||
|
xla::StatusOr<xla::ComputationDataHandle> DoMirrorPad(
|
||||||
|
const xla::ComputationDataHandle& t, const xla::Shape& original_shape,
|
||||||
|
const xla::Literal& pad_literal, xla::ComputationBuilder* b) {
|
||||||
|
xla::ComputationDataHandle accum = t;
|
||||||
|
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
|
||||||
|
--dimno) {
|
||||||
|
auto t_rev = b->Rev(accum, {dimno});
|
||||||
|
TF_ASSIGN_OR_RETURN(int64 lhs_padding,
|
||||||
|
pad_literal.GetIntegralAsS64({dimno, 0}));
|
||||||
|
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
|
||||||
|
pad_literal.GetIntegralAsS64({dimno, 1}));
|
||||||
|
int64 dim_size = original_shape.dimensions(dimno);
|
||||||
|
auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding,
|
||||||
|
dim_size - 1, 1, dimno);
|
||||||
|
auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
|
||||||
|
accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno);
|
||||||
|
}
|
||||||
|
return accum;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
const TensorShape input_shape = ctx->InputShape(0);
|
||||||
|
const TensorShape pad_shape = ctx->InputShape(1);
|
||||||
|
|
||||||
|
MirrorPadMode mode;
|
||||||
|
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
|
||||||
|
OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT,
|
||||||
|
xla::Unimplemented(
|
||||||
|
"Only REFLECT MirrorPad mode is currently supported"));
|
||||||
|
|
||||||
|
const int dims = input_shape.dims();
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
|
||||||
|
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||||
|
pad_shape.DebugString()));
|
||||||
|
const int fixed_dims =
|
||||||
|
(allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
|
||||||
|
? 1
|
||||||
|
: dims;
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, fixed_dims == pad_shape.dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"The first dimension of paddings must be the rank of inputs",
|
||||||
|
pad_shape.DebugString(), " ", input_shape.DebugString()));
|
||||||
|
|
||||||
|
// Evaluate the 'padding' constant input, reshaping to a matrix.
|
||||||
|
xla::Literal pad_literal;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
|
||||||
|
|
||||||
|
xla::ComputationBuilder* b = ctx->builder();
|
||||||
|
auto in0 = ctx->Input(0);
|
||||||
|
xla::StatusOr<std::unique_ptr<xla::Shape>> in0_shape = b->GetShape(in0);
|
||||||
|
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
|
||||||
|
xla::StatusOr<xla::ComputationDataHandle> accum_status =
|
||||||
|
DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b);
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(ctx, accum_status.status());
|
||||||
|
|
||||||
|
ctx->SetOutput(0, accum_status.ValueOrDie());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
@ -284,6 +284,25 @@ ComputationDataHandle ComputationBuilder::Slice(
|
||||||
return ParseOpResponse(s, &response);
|
return ParseOpResponse(s, &response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::SliceInDim(
|
||||||
|
const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
|
||||||
|
int64 stride, int64 dimno) {
|
||||||
|
StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
|
||||||
|
if (!shape_status.ok()) {
|
||||||
|
NoteError(shape_status.status());
|
||||||
|
return ComputationDataHandle{};
|
||||||
|
}
|
||||||
|
const Shape& shape = *shape_status.ValueOrDie();
|
||||||
|
std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
|
||||||
|
std::vector<int64> limits(shape.dimensions().begin(),
|
||||||
|
shape.dimensions().end());
|
||||||
|
std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
|
||||||
|
starts[dimno] = start_index;
|
||||||
|
limits[dimno] = limit_index;
|
||||||
|
strides[dimno] = stride;
|
||||||
|
return Slice(operand, starts, limits, strides);
|
||||||
|
}
|
||||||
|
|
||||||
ComputationDataHandle ComputationBuilder::DynamicSlice(
|
ComputationDataHandle ComputationBuilder::DynamicSlice(
|
||||||
const ComputationDataHandle& operand,
|
const ComputationDataHandle& operand,
|
||||||
const ComputationDataHandle& start_indices,
|
const ComputationDataHandle& start_indices,
|
||||||
|
|
|
||||||
|
|
@ -217,6 +217,16 @@ class ComputationBuilder {
|
||||||
tensorflow::gtl::ArraySlice<int64> limit_indices,
|
tensorflow::gtl::ArraySlice<int64> limit_indices,
|
||||||
tensorflow::gtl::ArraySlice<int64> stride);
|
tensorflow::gtl::ArraySlice<int64> stride);
|
||||||
|
|
||||||
|
// Enqueues a slice operation in a given dimension, taking all other
|
||||||
|
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
|
||||||
|
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
|
||||||
|
// for:
|
||||||
|
//
|
||||||
|
// array[:, 2:4:1, :]
|
||||||
|
ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
|
||||||
|
int64 start_index, int64 limit_index,
|
||||||
|
int64 stride, int64 dimno);
|
||||||
|
|
||||||
// Enqueues a slice operation onto the computation that slices the 'operand'
|
// Enqueues a slice operation onto the computation that slices the 'operand'
|
||||||
// from dynamic start indices which are passed in 'start_indices'.
|
// from dynamic start indices which are passed in 'start_indices'.
|
||||||
// The size of the slice in each dimension is passed in 'slice_sizes',
|
// The size of the slice in each dimension is passed in 'slice_sizes',
|
||||||
|
|
|
||||||
|
|
@ -503,6 +503,28 @@ string Literal::GetAsString(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<int64> Literal::GetIntegralAsS64(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> multi_index) const {
|
||||||
|
switch (shape().element_type()) {
|
||||||
|
case PRED:
|
||||||
|
return Get<bool>(multi_index);
|
||||||
|
case U8:
|
||||||
|
return Get<uint8>(multi_index);
|
||||||
|
case S32:
|
||||||
|
return Get<int32>(multi_index);
|
||||||
|
case S64:
|
||||||
|
return Get<int64>(multi_index);
|
||||||
|
case U32:
|
||||||
|
return Get<uint32>(multi_index);
|
||||||
|
case U64:
|
||||||
|
return Get<uint64>(multi_index);
|
||||||
|
default:
|
||||||
|
return FailedPrecondition(
|
||||||
|
"Array element type is not integral: %s",
|
||||||
|
PrimitiveType_Name(shape().element_type()).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int64 Literal::LinearIndex(
|
int64 Literal::LinearIndex(
|
||||||
tensorflow::gtl::ArraySlice<int64> multi_index) const {
|
tensorflow::gtl::ArraySlice<int64> multi_index) const {
|
||||||
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
|
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
|
||||||
|
|
|
||||||
|
|
@ -390,6 +390,11 @@ class Literal {
|
||||||
// into text.
|
// into text.
|
||||||
string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
|
string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
|
||||||
|
|
||||||
|
// As Get(), but determines the correct type and converts the value into
|
||||||
|
// int64.
|
||||||
|
StatusOr<int64> GetIntegralAsS64(
|
||||||
|
tensorflow::gtl::ArraySlice<int64> multi_index) const;
|
||||||
|
|
||||||
// Returns an identity matrix (rank 2) with the given row and column count.
|
// Returns an identity matrix (rank 2) with the given row and column count.
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
|
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
|
||||||
|
|
|
||||||
|
|
@ -269,9 +269,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"cannot concatenate arrays that differ in dimensions other than "
|
"cannot concatenate arrays that differ in dimensions other than "
|
||||||
"the one being concatenated (the other array dimensions must be "
|
"the one being concatenated (the other array dimensions must be "
|
||||||
"the same): %s vs %s",
|
"the same): %s vs %s in dimension %lld",
|
||||||
ShapeUtil::HumanString(*arg_shape).c_str(),
|
ShapeUtil::HumanString(*arg_shape).c_str(),
|
||||||
ShapeUtil::HumanString(*shape).c_str());
|
ShapeUtil::HumanString(*shape).c_str(), dimension);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user