Support negative axis for Split op

PiperOrigin-RevId: 157628162
A. Unique TensorFlower 2017-05-31 13:39:29 -07:00 committed by TensorFlower Gardener
parent bc236cfc3b
commit 0b8070253d
8 changed files with 145 additions and 50 deletions
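In user-facing terms, the change lets `tf.split` take a negative `axis`, counted from the last dimension, matching NumPy-style indexing. A minimal sketch, assuming a TF 1.x-era build that includes this commit:

import tensorflow as tf

value = tf.ones([4, 6])
# axis=-1 on a rank-2 tensor is now equivalent to axis=1.
parts = tf.split(value, num_or_size_splits=3, axis=-1)
# Each of the 3 parts has shape [4, 2].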

tensorflow/core/framework/shape_inference.cc

@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
return MakeShapeFromPartialTensorShape(partial_shape, out);
}
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
const Tensor* t = input_tensor(idx);
if (t == nullptr) {
*out = UnknownDim();
return Status::OK();
}
Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
// Caller must ensure that <t> is not NULL.
const int rank = t->dims();
if (rank != 0) {
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
}
int64 val;
if (t->dtype() == DT_INT32) {
val = t->scalar<int32>()();
*val = t->scalar<int32>()();
return Status::OK();
} else if (t->dtype() == DT_INT64) {
val = t->scalar<int64>()();
*val = t->scalar<int64>()();
return Status::OK();
} else {
return errors::InvalidArgument(
"Scalar input for dim size must be int32 or int64");
}
}
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
int64 val;
const Tensor* t = input_tensor(idx);
if (t == nullptr) {
*out = UnknownDim();
return Status::OK();
}
TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
if (val < 0) {
return errors::InvalidArgument("Dimension size, given by scalar input ",
idx, ", must be non-negative but is ", val);
@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
return Status::OK();
}
Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
int idx, int input_rank, DimensionHandle* out) {
int64 val;
const Tensor* t = input_tensor(idx);
if (t == nullptr) {
*out = UnknownDim();
return Status::OK();
}
TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
if (val < 0) {
if (input_rank < 0) {
*out = UnknownDim();
return Status::OK();
} else if (val + input_rank < 0) {
return errors::InvalidArgument("Dimension size, given by scalar input ",
val, " must be in range [-", input_rank,
", ", input_rank, ")");
} else {
val += input_rank;
}
} else if (input_rank >= 0 && val >= input_rank) {
return errors::InvalidArgument("Dimension size, given by scalar input ",
val, " must be in range [-", input_rank,
", ", input_rank, ")");
}
*out = MakeDim(val);
return Status::OK();
}
Status InferenceContext::Divide(DimensionHandle dividend,
DimensionOrConstant divisor,
bool evenly_divisible, DimensionHandle* out) {
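For reference, a pure-Python sketch of the bounds logic in MakeDimForScalarInputWithNegativeIndexing; the helper name resolve_split_dim is hypothetical, and None stands in for UnknownDim():

def resolve_split_dim(val, input_rank):
    # val < 0 indexes from the end; input_rank < 0 means the rank is unknown.
    if val < 0:
        if input_rank < 0:
            return None  # rank unknown, so the dimension stays unknown
        if val + input_rank < 0:
            raise ValueError("scalar input %d must be in range [-%d, %d)"
                             % (val, input_rank, input_rank))
        return val + input_rank
    if input_rank >= 0 and val >= input_rank:
        raise ValueError("scalar input %d must be in range [-%d, %d)"
                         % (val, input_rank, input_rank))
    return val

assert resolve_split_dim(-1, 3) == 2      # last axis of a rank-3 tensor
assert resolve_split_dim(2, 3) == 2
assert resolve_split_dim(-1, -1) is None  # unknown rank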

tensorflow/core/framework/shape_inference.h

@ -401,11 +401,24 @@ class InferenceContext {
inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
// Returns in <val> a scalar value from an input tensor <t>. The input tensor
// must be a 0-dimensional (scalar) int32 or int64 tensor. Caller must ensure
// that the input tensor is not NULL.
Status GetScalarFromTensor(const Tensor* t, int64* val);
// Returns a new dimension whose value is given by a scalar input tensor.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
Status MakeDimForScalarInput(int idx, DimensionHandle* out);
// Returns a new dimension whose value is given by a scalar input tensor.
// A negative value is interpreted as an index from the end, resolved against
// input_rank, the rank of a separate tensor; a negative input_rank means
// that rank is unknown.
// The input tensor must be in host memory, since it is dereferenced to get
// the value.
Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
DimensionHandle* out);
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
// the attr does not have a matching type, a non-ok status will be returned.

tensorflow/core/kernels/split_op.cc

@ -46,15 +46,18 @@ class SplitOpBase : public OpKernel {
explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const int32 split_dim = context->input(0).flat<int32>()(0);
const int32 num_split = num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
const int32 split_dim_orig = context->input(0).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
OP_REQUIRES(
context, 0 <= split_dim && split_dim < input_shape.dims(),
errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
input_shape.dims(), "), but got ", split_dim));
errors::InvalidArgument("-input rank(-", input.dims(),
") <= split_dim < input rank (", input.dims(),
"), but got ", split_dim_orig));
OP_REQUIRES(
context, num_split > 0,
@ -129,10 +132,12 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
if (!context->status().ok() || done) {
return;
}
const int32 split_dim = context->input(0).flat<int32>()(0);
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
const int32 split_dim_orig = context->input(0).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
@ -204,15 +209,16 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
if (!context->status().ok() || done) {
return;
}
const int32 split_dim = context->input(0).flat<int32>()(0);
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
const int32 split_dim_orig = context->input(0).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = Base::num_outputs();
OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
std::numeric_limits<int32>::max()),
errors::InvalidArgument("Split on GPU requires input size "
"< max int32"));
int32 prefix_dim_size;
int32 split_dim_size;
int32 suffix_dim_size;
@ -260,10 +266,12 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
if (!context->status().ok() || done) {
return;
}
const int32 split_dim = context->input(0).flat<int32>()(0);
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
const int32 split_dim_orig = context->input(0).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = Base::num_outputs();
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
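Each kernel variant above applies the same one-line normalization before the existing bounds check; a worked Python equivalent (hypothetical helper name):

def normalize_split_dim(split_dim_orig, rank):
    # Mirrors: split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig
    return split_dim_orig + rank if split_dim_orig < 0 else split_dim_orig

assert normalize_split_dim(-1, 3) == 2
assert normalize_split_dim(1, 3) == 1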

tensorflow/core/kernels/split_v_op.cc

@ -55,7 +55,9 @@ class SplitVOpBase : public OpKernel {
const TensorShape& input_shape = input.shape();
const Tensor& split_tensor = context->input(1);
const int32 split_dim = context->input(2).flat<int32>()(0);
const int32 split_dim_orig = context->input(2).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
OP_REQUIRES(
context,
@ -79,8 +81,9 @@ class SplitVOpBase : public OpKernel {
OP_REQUIRES(
context, 0 <= split_dim && split_dim < input.dims(),
errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
input.dims(), "), but got ", split_dim));
errors::InvalidArgument("-input rank(-", input.dims(),
") <= split_dim < input rank (", input.dims(),
"), but got ", split_dim_orig));
Tlen input_size_split_dim = input_shape.dim_size(split_dim);
@ -187,7 +190,9 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
const int32 split_dim = context->input(2).flat<int32>()(0);
const int32 split_dim_orig = context->input(2).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
// Android also uses int32 indexing, so check here also.
OP_REQUIRES(
@ -257,7 +262,9 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
const int32 num_split = Base::num_outputs();
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
const int32 split_dim = context->input(2).flat<int32>()(0);
const int32 split_dim_orig = context->input(2).flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
std::numeric_limits<int32>::max()),
errors::InvalidArgument("Split on GPU requires input size "

tensorflow/core/ops/array_ops.cc

@ -456,9 +456,10 @@ REGISTER_OP("Split")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle split_dimension;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
int num_split = c->num_outputs();
ShapeHandle input = c->input(1);
TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
0, c->Rank(input), &split_dimension));
int num_split = c->num_outputs();
ShapeHandle out;
if (!c->ValueKnown(split_dimension)) {
if (c->RankKnown(input)) {
@ -484,7 +485,7 @@ REGISTER_OP("Split")
Splits a tensor into `num_split` tensors along one dimension.
split_dim: 0-D. The dimension along which to split. Must be in the range
`[0, rank(value))`.
`[-rank(value), rank(value))`.
num_split: The number of ways to split. Must evenly divide
`value.shape[split_dim]`.
value: The tensor to split.
@ -503,9 +504,10 @@ REGISTER_OP("SplitV")
.Attr("Tlen: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle split_dimension;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
int32 num_outputs = c->num_outputs();
ShapeHandle input = c->input(0);
TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
2, c->Rank(input), &split_dimension));
int32 num_outputs = c->num_outputs();
int32 rank = c->Rank(input);
ShapeHandle output_shape;
const Tensor* size_splits = c->input_tensor(1);
@ -594,7 +596,7 @@ size_splits: list containing the sizes of each output tensor along the split
dimension. Must sum to the dimension of value along split_dim.
Can contain one -1 indicating that dimension is to be inferred.
split_dim: 0-D. The dimension along which to split. Must be in the range
`[0, rank(value))`.
`[-rank(value), rank(value))`.
output: Tensors whose shape matches that of `value`
except along `split_dim`, where their sizes are
`size_splits[i]`.
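Because the shape function now resolves negative axes, static shape inference succeeds for them too. A sketch, assuming the TF 1.x placeholder API:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1, 4, 8])
a, b = tf.split(x, num_or_size_splits=2, axis=-1)
print(a.get_shape())  # (1, 4, 4): axis -1 resolves to dimension 2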

tensorflow/core/ops/array_ops_test.cc

@ -993,6 +993,9 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
// If the rank is known, we know the rank of each output.
INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
// split_dim is unknown but other inputs are known.
INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");
// split_dim is known.
Tensor split_dim = test::AsTensor<int32>({1, 2});
op.input_tensors[0] = &split_dim;
@ -1004,6 +1007,26 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
"?;[1,5]");
// split_dim too large.
split_dim = test::AsScalar<int32>(3);
INFER_ERROR(
"Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
"?;[1,4,8]");
// Negative split_dim.
split_dim = test::AsScalar<int32>(-1);
INFER_OK(op, "?;?", "?;?");
INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
split_dim = test::AsScalar<int32>(-2);
INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
split_dim = test::AsScalar<int32>(-4);
INFER_ERROR(
"Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
"?;[1,4,8]");
}
TEST(ArrayOpsTest, Tile_ShapeFn) {

tensorflow/python/kernel_tests/split_op_test.py

@ -53,20 +53,21 @@ class SplitOpTest(test.TestCase):
model_input = array_ops.placeholder(dtypes.float32)
inp = np.zeros((1, 10))
# check that we still fail at runtime if the shapes were unknown
with self.test_session(use_gpu=False) as sess:
with self.test_session(use_gpu=True) as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(array_ops.split(model_input, [4]), {model_input: inp})
# test that we can pass a scalar Tensor as num_splits
with self.test_session(use_gpu=False) as sess:
result = sess.run(
array_ops.split(
array_ops.ones([4, 4]),
num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
axis=0))
for axis in [0, -2]:
with self.test_session(use_gpu=True) as sess:
result = sess.run(
array_ops.split(
array_ops.ones([4, 4]),
num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
axis=axis))
self.assertEqual(result[0].shape, (2, 4))
self.assertEqual(result[1].shape, (2, 4))
self.assertEqual(result[0].shape, (2, 4))
self.assertEqual(result[1].shape, (2, 4))
# test that non-split dimensions remain, even if we don't know how
# the split_dim will be split, but we do know the axis
@ -80,7 +81,7 @@ class SplitOpTest(test.TestCase):
model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
result = array_ops.split(model_input2, [2, 2], axis=0)[0]
with self.test_session(use_gpu=False) as sess:
with self.test_session(use_gpu=True) as sess:
sess.run(result, feed_dict={model_input2: np.ones([4, 2])})
def testExplicitNum(self):
@ -116,7 +117,7 @@ class SplitOpTest(test.TestCase):
def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
# Random dims of rank 5
shape = np.random.randint(1, 5, size=5)
split_dim = np.random.randint(0, 5)
split_dim = np.random.randint(-5, 5)
if large_num_splits:
num_split = np.random.randint(16, 25)
else:
@ -180,12 +181,13 @@ class SplitOpTest(test.TestCase):
self.assertAllEqual(result[:, 1:4], inp_grads[1])
def testOutputShape(self):
with self.test_session(use_gpu=True):
tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
size_splits = [3, 7, 2]
outputs = array_ops.split(tensor, size_splits, 1)
for i, output in enumerate(outputs):
self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
for axis in [1, -1]:
with self.test_session(use_gpu=True):
tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
size_splits = [3, 7, 2]
outputs = array_ops.split(tensor, size_splits, axis)
for i, output in enumerate(outputs):
self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
def _compare(self, x, dim, num):
np_ans = np.split(x, num, dim)
@ -246,7 +248,7 @@ class SplitOpTest(test.TestCase):
def _RunAndVerify(self, dtype, large_num_splits=False):
# Random dims of rank 5
shape = np.random.randint(0, 5, size=5)
split_dim = np.random.randint(0, 5)
split_dim = np.random.randint(-5, 5)
if large_num_splits:
num_split = np.random.randint(9, 15)
else:
@ -295,6 +297,10 @@ class SplitOpTest(test.TestCase):
with self.assertRaises(ValueError):
array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2)
# split dim less than -(rank of input)
with self.assertRaises(ValueError):
array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)
# num_split does not evenly divide the size in split_dim.
with self.assertRaisesRegexp(ValueError, "should evenly divide"):
array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0)
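The size_splits (SplitV) path accepts a negative axis as well; a sketch mirroring testOutputShape above:

import tensorflow as tf

x = tf.ones([4, 12])
parts = tf.split(x, num_or_size_splits=[3, 7, 2], axis=-1)
print([p.get_shape().as_list() for p in parts])  # [[4, 3], [4, 7], [4, 2]]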

tensorflow/python/ops/array_ops.py

@ -1196,7 +1196,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
split dimension must match that of the `value`.
axis: A 0-D `int32` `Tensor`. The dimension along which to split.
Must be in the range `[0, rank(value))`. Defaults to 0.
Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
num: Optional, used to specify the number of outputs when it cannot be
inferred from the shape of `size_splits`.
name: A name for the operation (optional).