Support negative axis for Split op
PiperOrigin-RevId: 157628162
parent bc236cfc3b
commit 0b8070253d
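Note (illustrative, not part of the diff below): after this change a negative axis passed to tf.split counts from the end of the shape, NumPy-style. A minimal usage sketch, assuming the TF 1.x Python API this commit targets:

  import tensorflow as tf

  x = tf.ones([4, 6])
  # axis=-1 now refers to the last dimension (axis 1 for a rank-2 tensor).
  a, b = tf.split(x, num_or_size_splits=2, axis=-1)
  with tf.Session() as sess:
    print(sess.run(tf.shape(a)))  # [4, 3]
    print(sess.run(tf.shape(b)))  # [4, 3]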
@@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
   return MakeShapeFromPartialTensorShape(partial_shape, out);
 }
 
-// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
-  const Tensor* t = input_tensor(idx);
-  if (t == nullptr) {
-    *out = UnknownDim();
-    return Status::OK();
-  }
+Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
+  // Caller must ensure that <t> is not NULL.
   const int rank = t->dims();
   if (rank != 0) {
     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
   }
 
-  int64 val;
   if (t->dtype() == DT_INT32) {
-    val = t->scalar<int32>()();
+    *val = t->scalar<int32>()();
+    return Status::OK();
   } else if (t->dtype() == DT_INT64) {
-    val = t->scalar<int64>()();
+    *val = t->scalar<int64>()();
+    return Status::OK();
   } else {
     return errors::InvalidArgument(
         "Scalar input for dim size must be int32 or int64");
   }
+}
+
+// Returns a new dimension whose value is given by a scalar input tensor.
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
+  int64 val;
+  const Tensor* t = input_tensor(idx);
+  if (t == nullptr) {
+    *out = UnknownDim();
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
   if (val < 0) {
     return errors::InvalidArgument("Dimension size, given by scalar input ",
                                    idx, ", must be non-negative but is ", val);
@@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
   return Status::OK();
 }
 
+Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
+    int idx, int input_rank, DimensionHandle* out) {
+  int64 val;
+  const Tensor* t = input_tensor(idx);
+  if (t == nullptr) {
+    *out = UnknownDim();
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
+  if (val < 0) {
+    if (input_rank < 0) {
+      *out = UnknownDim();
+      return Status::OK();
+    } else if (val + input_rank < 0) {
+      return errors::InvalidArgument("Dimension size, given by scalar input ",
+                                     val, " must be in range [-", input_rank,
+                                     ", ", input_rank, ")");
+    } else {
+      val += input_rank;
+    }
+  } else if (input_rank >= 0 && val >= input_rank) {
+    return errors::InvalidArgument("Dimension size, given by scalar input ",
+                                   val, " must be in range [-", input_rank,
+                                   ", ", input_rank, ")");
+  }
+  *out = MakeDim(val);
+  return Status::OK();
+}
+
 Status InferenceContext::Divide(DimensionHandle dividend,
                                 DimensionOrConstant divisor,
                                 bool evenly_divisible, DimensionHandle* out) {
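Note (illustrative, not part of the commit): the normalization rule implemented by MakeDimForScalarInputWithNegativeIndexing above can be sketched in Python as follows, where a negative rank stands for an unknown input rank:

  def normalize_split_dim(val, rank):
    """Sketch of the negative-indexing rule; None means the dim stays unknown."""
    if val < 0:
      if rank < 0:
        return None  # rank unknown, so a negative index cannot be resolved yet
      if val + rank < 0:
        raise ValueError("must be in range [-%d, %d)" % (rank, rank))
      return val + rank  # wrap around, e.g. -1 -> rank - 1
    if rank >= 0 and val >= rank:
      raise ValueError("must be in range [-%d, %d)" % (rank, rank))
    return val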
@@ -401,11 +401,24 @@ class InferenceContext {
 
   inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
 
+  // Returns in <val> a scalar value from an input tensor <t>. The input tensor
+  // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
+  // input tensor is not NULL.
+  Status GetScalarFromTensor(const Tensor* t, int64* val);
+
   // Returns a new dimension whose value is given by a scalar input tensor.
   // The input tensor must be in host memory, since it is dereferenced to get
   // the value.
   Status MakeDimForScalarInput(int idx, DimensionHandle* out);
 
+  // Returns a new dimension whose value is given by a scalar input tensor.
+  // This allows for a negative input dimension given the rank of a separate
+  // tensor. This rank can be negative if unknown.
+  // The input tensor must be in host memory, since it is dereferenced to get
+  // the value.
+  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
+                                                   DimensionHandle* out);
+
   // Look up the attr for the NodeDef being evaluated with name attr_name and
   // set *value to its value. If no attr with attr_name is found in def(), or
   // the attr does not have a matching type, a non-ok status will be returned.
@@ -46,15 +46,18 @@ class SplitOpBase : public OpKernel {
   explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}
 
   void ComputeEasyCases(OpKernelContext* context, bool* done) {
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = num_outputs();
 
     OP_REQUIRES(
         context, 0 <= split_dim && split_dim < input_shape.dims(),
-        errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
-                                input_shape.dims(), "), but got ", split_dim));
+        errors::InvalidArgument("-input rank(-", input.dims(),
+                                ") <= split_dim < input rank (", input.dims(),
+                                "), but got ", split_dim_orig));
 
     OP_REQUIRES(
         context, num_split > 0,
@@ -129,10 +132,12 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -204,15 +209,16 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = Base::num_outputs();
     OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
                                          std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("Split on GPU requires input size "
                                         "< max int32"));
 
     int32 prefix_dim_size;
     int32 split_dim_size;
     int32 suffix_dim_size;
@@ -260,10 +266,12 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = Base::num_outputs();
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -55,7 +55,9 @@ class SplitVOpBase : public OpKernel {
     const TensorShape& input_shape = input.shape();
     const Tensor& split_tensor = context->input(1);
 
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     OP_REQUIRES(
         context,
@@ -79,8 +81,9 @@ class SplitVOpBase : public OpKernel {
 
     OP_REQUIRES(
         context, 0 <= split_dim && split_dim < input.dims(),
-        errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
-                                input.dims(), "), but got ", split_dim));
+        errors::InvalidArgument("-input rank(-", input.dims(),
+                                ") <= split_dim < input rank (", input.dims(),
+                                "), but got ", split_dim_orig));
 
     Tlen input_size_split_dim = input_shape.dim_size(split_dim);
 
@@ -187,7 +190,9 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -257,7 +262,9 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
     OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
                                          std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("Split on GPU requires input size "
@@ -456,9 +456,10 @@ REGISTER_OP("Split")
     .Attr("T: type")
     .SetShapeFn([](InferenceContext* c) {
       DimensionHandle split_dimension;
-      TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
-      int num_split = c->num_outputs();
       ShapeHandle input = c->input(1);
+      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+          0, c->Rank(input), &split_dimension));
+      int num_split = c->num_outputs();
       ShapeHandle out;
       if (!c->ValueKnown(split_dimension)) {
         if (c->RankKnown(input)) {
@@ -484,7 +485,7 @@ REGISTER_OP("Split")
 Splits a tensor into `num_split` tensors along one dimension.
 
 split_dim: 0-D. The dimension along which to split. Must be in the range
-  `[0, rank(value))`.
+  `[-rank(value), rank(value))`.
 num_split: The number of ways to split. Must evenly divide
   `value.shape[split_dim]`.
 value: The tensor to split.
@@ -503,9 +504,10 @@ REGISTER_OP("SplitV")
     .Attr("Tlen: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       DimensionHandle split_dimension;
-      TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
-      int32 num_outputs = c->num_outputs();
       ShapeHandle input = c->input(0);
+      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+          2, c->Rank(input), &split_dimension));
+      int32 num_outputs = c->num_outputs();
       int32 rank = c->Rank(input);
       ShapeHandle output_shape;
       const Tensor* size_splits = c->input_tensor(1);
@@ -594,7 +596,7 @@ size_splits: list containing the sizes of each output tensor along the split
              dimension. Must sum to the dimension of value along split_dim.
              Can contain one -1 indicating that dimension is to be inferred.
 split_dim: 0-D. The dimension along which to split. Must be in the range
-  `[0, rank(value))`.
+  `[-rank(value), rank(value))`.
 output: Tensors whose shape matches that of `value`
         except along `split_dim`, where their sizes are
         `size_splits[i]`.
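Note (illustrative, not part of the commit): with the documented range relaxed to `[-rank(value), rank(value))`, SplitV-style calls with explicit size_splits also accept a negative axis. A minimal sketch, again assuming the TF 1.x Python API:

  import tensorflow as tf

  value = tf.zeros([5, 12])
  # size_splits must sum to value.shape[axis]; axis=-1 is the last dimension.
  parts = tf.split(value, num_or_size_splits=[3, 7, 2], axis=-1)
  print([p.get_shape().as_list() for p in parts])  # [[5, 3], [5, 7], [5, 2]]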
@@ -993,6 +993,9 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
   // If the rank is known, we know the rank of each output.
   INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
 
+  // split_dim is unknown but other inputs are known.
+  INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");
+
   // split_dim is known.
   Tensor split_dim = test::AsTensor<int32>({1, 2});
   op.input_tensors[0] = &split_dim;
@@ -1004,6 +1007,26 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
   INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
   INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
               "?;[1,5]");
+
+  // split_dim too large.
+  split_dim = test::AsScalar<int32>(3);
+  INFER_ERROR(
+      "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
+      "?;[1,4,8]");
+
+  // Negative split_dim.
+  split_dim = test::AsScalar<int32>(-1);
+  INFER_OK(op, "?;?", "?;?");
+  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
+  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
+  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
+  INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
+  split_dim = test::AsScalar<int32>(-2);
+  INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
+  split_dim = test::AsScalar<int32>(-4);
+  INFER_ERROR(
+      "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
+      "?;[1,4,8]");
 }
 
 TEST(ArrayOpsTest, Tile_ShapeFn) {
@@ -53,20 +53,21 @@ class SplitOpTest(test.TestCase):
     model_input = array_ops.placeholder(dtypes.float32)
     inp = np.zeros((1, 10))
     # check that we still fail at runtime if the shapes were unknown
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       with self.assertRaises(errors_impl.InvalidArgumentError):
         sess.run(array_ops.split(model_input, [4]), {model_input: inp})
 
     # test that we can pass a scalar Tensor as num_splits
-    with self.test_session(use_gpu=False) as sess:
-      result = sess.run(
-          array_ops.split(
-              array_ops.ones([4, 4]),
-              num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
-              axis=0))
+    for axis in [0, -2]:
+      with self.test_session(use_gpu=True) as sess:
+        result = sess.run(
+            array_ops.split(
+                array_ops.ones([4, 4]),
+                num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
+                axis=axis))
 
-    self.assertEqual(result[0].shape, (2, 4))
-    self.assertEqual(result[1].shape, (2, 4))
+      self.assertEqual(result[0].shape, (2, 4))
+      self.assertEqual(result[1].shape, (2, 4))
 
     # test that none split dimensions remain, even if we don't know how
     # the split_dim will be split, but we do know the axis
@@ -80,7 +81,7 @@ class SplitOpTest(test.TestCase):
     model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
     result = array_ops.split(model_input2, [2, 2], axis=0)[0]
 
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       sess.run(result, feed_dict={model_input2: np.ones([4, 2])})
 
   def testExplicitNum(self):
@@ -116,7 +117,7 @@ class SplitOpTest(test.TestCase):
   def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(1, 5, size=5)
-    split_dim = np.random.randint(0, 5)
+    split_dim = np.random.randint(-5, 5)
     if large_num_splits:
       num_split = np.random.randint(16, 25)
     else:
@@ -180,12 +181,13 @@ class SplitOpTest(test.TestCase):
     self.assertAllEqual(result[:, 1:4], inp_grads[1])
 
   def testOutputShape(self):
-    with self.test_session(use_gpu=True):
-      tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
-      size_splits = [3, 7, 2]
-      outputs = array_ops.split(tensor, size_splits, 1)
-      for i, output in enumerate(outputs):
-        self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
+    for axis in [1, -1]:
+      with self.test_session(use_gpu=True):
+        tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
+        size_splits = [3, 7, 2]
+        outputs = array_ops.split(tensor, size_splits, axis)
+        for i, output in enumerate(outputs):
+          self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
 
   def _compare(self, x, dim, num):
     np_ans = np.split(x, num, dim)
@@ -246,7 +248,7 @@ class SplitOpTest(test.TestCase):
   def _RunAndVerify(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(0, 5, size=5)
-    split_dim = np.random.randint(0, 5)
+    split_dim = np.random.randint(-5, 5)
    if large_num_splits:
      num_split = np.random.randint(9, 15)
    else:
@@ -295,6 +297,10 @@ class SplitOpTest(test.TestCase):
     with self.assertRaises(ValueError):
       array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2)
 
+    # split dim less than -(rank of input)
+    with self.assertRaises(ValueError):
+      array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)
+
     # num_split does not evenly divide the size in split_dim.
     with self.assertRaisesRegexp(ValueError, "should evenly divide"):
       array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0)
@@ -1196,7 +1196,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
       evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
       split dimension must match that of the `value`.
     axis: A 0-D `int32` `Tensor`. The dimension along which to split.
-      Must be in the range `[0, rank(value))`. Defaults to 0.
+      Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
     num: Optional, used to specify the number of outputs when it cannot be
       inferred from the shape of `size_splits`.
     name: A name for the operation (optional).
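Note (illustrative, not part of the commit): as the updated docstring states, an axis outside `[-rank(value), rank(value))` is still rejected; mirroring the new test above, a rank-2 value with axis=-3 fails at graph-construction time (assuming the TF 1.x API):

  import tensorflow as tf

  try:
    tf.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)
  except ValueError as err:
    print("rejected:", err)  # valid range for a rank-2 value is [-2, 2)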