Support negative axis for Split op

PiperOrigin-RevId: 157628162

This commit is contained in:
parent bc236cfc3b
commit 0b8070253d
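After this change, the split_dim (axis) input of Split and SplitV may be negative and is interpreted relative to the rank of the value being split, matching Python-style indexing. A minimal usage sketch of the behavior this commit enables (TF 1.x session API, illustrative only, not part of the commit):

    import tensorflow as tf

    value = tf.ones([4, 6])
    # axis=-1 now refers to the last dimension, i.e. the same as axis=1 here.
    pieces = tf.split(value, num_or_size_splits=3, axis=-1)
    with tf.Session() as sess:
      outs = sess.run(pieces)
    assert all(o.shape == (4, 2) for o in outs)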
@@ -637,27 +637,34 @@ Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
   return MakeShapeFromPartialTensorShape(partial_shape, out);
 }
 
-// Returns a new dimension whose value is given by a scalar input tensor.
-Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
-  const Tensor* t = input_tensor(idx);
-  if (t == nullptr) {
-    *out = UnknownDim();
-    return Status::OK();
-  }
+Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
+  // Caller must ensure that <t> is not NULL.
   const int rank = t->dims();
   if (rank != 0) {
     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
   }
 
-  int64 val;
   if (t->dtype() == DT_INT32) {
-    val = t->scalar<int32>()();
+    *val = t->scalar<int32>()();
+    return Status::OK();
   } else if (t->dtype() == DT_INT64) {
-    val = t->scalar<int64>()();
+    *val = t->scalar<int64>()();
+    return Status::OK();
  } else {
     return errors::InvalidArgument(
         "Scalar input for dim size must be int32 or int64");
   }
+}
+
+// Returns a new dimension whose value is given by a scalar input tensor.
+Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
+  int64 val;
+  const Tensor* t = input_tensor(idx);
+  if (t == nullptr) {
+    *out = UnknownDim();
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
   if (val < 0) {
     return errors::InvalidArgument("Dimension size, given by scalar input ",
                                    idx, ", must be non-negative but is ", val);
@@ -666,6 +673,35 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
   return Status::OK();
 }
 
+Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
+    int idx, int input_rank, DimensionHandle* out) {
+  int64 val;
+  const Tensor* t = input_tensor(idx);
+  if (t == nullptr) {
+    *out = UnknownDim();
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
+  if (val < 0) {
+    if (input_rank < 0) {
+      *out = UnknownDim();
+      return Status::OK();
+    } else if (val + input_rank < 0) {
+      return errors::InvalidArgument("Dimension size, given by scalar input ",
+                                     val, " must be in range [-", input_rank,
+                                     ", ", input_rank, ")");
+    } else {
+      val += input_rank;
+    }
+  } else if (input_rank >= 0 && val >= input_rank) {
+    return errors::InvalidArgument("Dimension size, given by scalar input ",
+                                   val, " must be in range [-", input_rank,
+                                   ", ", input_rank, ")");
+  }
+  *out = MakeDim(val);
+  return Status::OK();
+}
+
 Status InferenceContext::Divide(DimensionHandle dividend,
                                 DimensionOrConstant divisor,
                                 bool evenly_divisible, DimensionHandle* out) {
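The range handling in the new MakeDimForScalarInputWithNegativeIndexing helper can be summarized as: a negative value is wrapped by adding input_rank, an unknown rank (encoded as a negative input_rank) yields an unknown dimension, and anything outside [-input_rank, input_rank) is rejected. Below is a hypothetical Python mirror of the C++ logic above, for illustration only; the name normalize_split_dim is not part of the commit.

    def normalize_split_dim(val, input_rank):
      """Returns the normalized dimension index, or None for "unknown"."""
      if val < 0:
        if input_rank < 0:
          return None  # rank unknown, so the dimension stays unknown
        if val + input_rank < 0:
          raise ValueError(
              "Dimension size, given by scalar input %d must be in range "
              "[-%d, %d)" % (val, input_rank, input_rank))
        return val + input_rank
      if input_rank >= 0 and val >= input_rank:
        raise ValueError(
            "Dimension size, given by scalar input %d must be in range "
            "[-%d, %d)" % (val, input_rank, input_rank))
      return val

    assert normalize_split_dim(-1, 3) == 2      # last axis of a rank-3 tensor
    assert normalize_split_dim(1, 3) == 1       # non-negative values pass through
    assert normalize_split_dim(-1, -1) is None  # rank unknown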
@@ -401,11 +401,24 @@ class InferenceContext {
 
   inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
 
+  // Returns in <val> a scalar value from an input tensor <t>. The input tensor
+  // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
+  // input tensor is not NULL.
+  Status GetScalarFromTensor(const Tensor* t, int64* val);
+
   // Returns a new dimension whose value is given by a scalar input tensor.
   // The input tensor must be in host memory, since it is dereferenced to get
   // the value.
   Status MakeDimForScalarInput(int idx, DimensionHandle* out);
 
+  // Returns a new dimension whose value is given by a scalar input tensor.
+  // This allows for a negative input dimension given the rank of a separate
+  // tensor. This rank can be negative if unknown.
+  // The input tensor must be in host memory, since it is dereferenced to get
+  // the value.
+  Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
+                                                   DimensionHandle* out);
+
   // Look up the attr for the NodeDef being evaluated with name attr_name and
   // set *value to its value. If no attr with attr_name is found in def(), or
   // the attr does not have a matching type, a non-ok status will be returned.
@@ -46,15 +46,18 @@ class SplitOpBase : public OpKernel {
   explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}
 
   void ComputeEasyCases(OpKernelContext* context, bool* done) {
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = num_outputs();
 
     OP_REQUIRES(
         context, 0 <= split_dim && split_dim < input_shape.dims(),
-        errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
-                                input_shape.dims(), "), but got ", split_dim));
+        errors::InvalidArgument("-input rank(-", input.dims(),
+                                ") <= split_dim < input rank (", input.dims(),
+                                "), but got ", split_dim_orig));
 
     OP_REQUIRES(
         context, num_split > 0,
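In the kernels, the same convention is implemented inline: a negative split_dim is shifted by the input rank once, before the existing bounds check; the check runs against the normalized value, and the error message reports the original split_dim_orig. A tiny hedged sketch of that normalization (plain Python, not TF code):

    rank = 3
    for split_dim_orig in (-3, -1, 0, 2):
      split_dim = split_dim_orig + rank if split_dim_orig < 0 else split_dim_orig
      assert 0 <= split_dim < rank  # mirrors the OP_REQUIRES bounds check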
@@ -129,10 +132,12 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -204,15 +209,16 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = Base::num_outputs();
     OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
                                          std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("Split on GPU requires input size "
                                         "< max int32"));
 
     int32 prefix_dim_size;
     int32 split_dim_size;
     int32 suffix_dim_size;
@@ -260,10 +266,12 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
     if (!context->status().ok() || done) {
       return;
     }
-    const int32 split_dim = context->input(0).flat<int32>()(0);
-    const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
+    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
+    const int32 num_split = Base::num_outputs();
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -55,7 +55,9 @@ class SplitVOpBase : public OpKernel {
     const TensorShape& input_shape = input.shape();
     const Tensor& split_tensor = context->input(1);
 
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     OP_REQUIRES(
         context,
@@ -79,8 +81,9 @@ class SplitVOpBase : public OpKernel {
 
     OP_REQUIRES(
         context, 0 <= split_dim && split_dim < input.dims(),
-        errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
-                                input.dims(), "), but got ", split_dim));
+        errors::InvalidArgument("-input rank(-", input.dims(),
+                                ") <= split_dim < input rank (", input.dims(),
+                                "), but got ", split_dim_orig));
 
     Tlen input_size_split_dim = input_shape.dim_size(split_dim);
 
@@ -187,7 +190,9 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
 
     // Android also uses int32 indexing, so check here also.
     OP_REQUIRES(
@@ -257,7 +262,9 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
     const int32 num_split = Base::num_outputs();
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
-    const int32 split_dim = context->input(2).flat<int32>()(0);
+    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
+    const int32 split_dim =
+        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
     OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
                                          std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("Split on GPU requires input size "
@@ -456,9 +456,10 @@ REGISTER_OP("Split")
     .Attr("T: type")
     .SetShapeFn([](InferenceContext* c) {
       DimensionHandle split_dimension;
-      TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
-      int num_split = c->num_outputs();
       ShapeHandle input = c->input(1);
+      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+          0, c->Rank(input), &split_dimension));
+      int num_split = c->num_outputs();
       ShapeHandle out;
       if (!c->ValueKnown(split_dimension)) {
         if (c->RankKnown(input)) {
@@ -484,7 +485,7 @@ REGISTER_OP("Split")
 Splits a tensor into `num_split` tensors along one dimension.
 
 split_dim: 0-D. The dimension along which to split. Must be in the range
-  `[0, rank(value))`.
+  `[-rank(value), rank(value))`.
 num_split: The number of ways to split. Must evenly divide
   `value.shape[split_dim]`.
 value: The tensor to split.
@@ -503,9 +504,10 @@ REGISTER_OP("SplitV")
     .Attr("Tlen: {int32, int64} = DT_INT64")
     .SetShapeFn([](InferenceContext* c) {
       DimensionHandle split_dimension;
-      TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
-      int32 num_outputs = c->num_outputs();
       ShapeHandle input = c->input(0);
+      TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing(
+          2, c->Rank(input), &split_dimension));
+      int32 num_outputs = c->num_outputs();
       int32 rank = c->Rank(input);
       ShapeHandle output_shape;
       const Tensor* size_splits = c->input_tensor(1);
@@ -594,7 +596,7 @@ size_splits: list containing the sizes of each output tensor along the split
   dimension. Must sum to the dimension of value along split_dim.
   Can contain one -1 indicating that dimension is to be inferred.
 split_dim: 0-D. The dimension along which to split. Must be in the range
-  `[0, rank(value))`.
+  `[-rank(value), rank(value))`.
 output: Tensors whose shape matches that of `value`
   except along `split_dim`, where their sizes are
   `size_splits[i]`.
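Because the registered shape functions now pass c->Rank(input) to MakeDimForScalarInputWithNegativeIndexing, a statically known negative axis produces fully inferred output shapes at graph-construction time. A hedged illustration (TF 1.x graph mode assumed, not part of the commit):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[2, 3, 4])
    a, b = tf.split(x, num_or_size_splits=2, axis=-1)       # even split (Split)
    p, q = tf.split(x, num_or_size_splits=[1, 3], axis=-1)  # ragged split (SplitV)
    print(a.get_shape())  # (2, 3, 2), inferred without running the graph
    print(p.get_shape())  # (2, 3, 1)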
@@ -993,6 +993,9 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
   // If the rank is known, we know the rank of each output.
   INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
 
+  // split_dim is unknown but other inputs are known.
+  INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");
+
   // split_dim is known.
   Tensor split_dim = test::AsTensor<int32>({1, 2});
   op.input_tensors[0] = &split_dim;
@@ -1004,6 +1007,26 @@ TEST(ArrayOpsTest, Split_ShapeFn) {
   INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
   INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
               "?;[1,5]");
+
+  // split_dim too large.
+  split_dim = test::AsScalar<int32>(3);
+  INFER_ERROR(
+      "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
+      "?;[1,4,8]");
+
+  // Negative split_dim.
+  split_dim = test::AsScalar<int32>(-1);
+  INFER_OK(op, "?;?", "?;?");
+  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
+  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
+  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
+  INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
+  split_dim = test::AsScalar<int32>(-2);
+  INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
+  split_dim = test::AsScalar<int32>(-4);
+  INFER_ERROR(
+      "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
+      "?;[1,4,8]");
 }
 
 TEST(ArrayOpsTest, Tile_ShapeFn) {
@@ -53,20 +53,21 @@ class SplitOpTest(test.TestCase):
     model_input = array_ops.placeholder(dtypes.float32)
     inp = np.zeros((1, 10))
     # check that we still fail at runtime if the shapes were unknown
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       with self.assertRaises(errors_impl.InvalidArgumentError):
         sess.run(array_ops.split(model_input, [4]), {model_input: inp})
 
     # test that we can pass a scalar Tensor as num_splits
-    with self.test_session(use_gpu=False) as sess:
-      result = sess.run(
-          array_ops.split(
-              array_ops.ones([4, 4]),
-              num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
-              axis=0))
+    for axis in [0, -2]:
+      with self.test_session(use_gpu=True) as sess:
+        result = sess.run(
+            array_ops.split(
+                array_ops.ones([4, 4]),
+                num_or_size_splits=array_ops.ones([2, 2]).get_shape()[1],
+                axis=axis))
 
     self.assertEqual(result[0].shape, (2, 4))
     self.assertEqual(result[1].shape, (2, 4))
 
     # test that none split dimensions remain, even if we don't know how
     # the split_dim will be split, but we do know the axis
@@ -80,7 +81,7 @@ class SplitOpTest(test.TestCase):
     model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
     result = array_ops.split(model_input2, [2, 2], axis=0)[0]
 
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       sess.run(result, feed_dict={model_input2: np.ones([4, 2])})
 
   def testExplicitNum(self):
@@ -116,7 +117,7 @@ class SplitOpTest(test.TestCase):
   def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(1, 5, size=5)
-    split_dim = np.random.randint(0, 5)
+    split_dim = np.random.randint(-5, 5)
     if large_num_splits:
       num_split = np.random.randint(16, 25)
     else:
@@ -180,12 +181,13 @@ class SplitOpTest(test.TestCase):
     self.assertAllEqual(result[:, 1:4], inp_grads[1])
 
   def testOutputShape(self):
-    with self.test_session(use_gpu=True):
-      tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
-      size_splits = [3, 7, 2]
-      outputs = array_ops.split(tensor, size_splits, 1)
-      for i, output in enumerate(outputs):
-        self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
+    for axis in [1, -1]:
+      with self.test_session(use_gpu=True):
+        tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
+        size_splits = [3, 7, 2]
+        outputs = array_ops.split(tensor, size_splits, axis)
+        for i, output in enumerate(outputs):
+          self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])
 
   def _compare(self, x, dim, num):
     np_ans = np.split(x, num, dim)
@@ -246,7 +248,7 @@ class SplitOpTest(test.TestCase):
   def _RunAndVerify(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(0, 5, size=5)
-    split_dim = np.random.randint(0, 5)
+    split_dim = np.random.randint(-5, 5)
     if large_num_splits:
       num_split = np.random.randint(9, 15)
     else:
@@ -295,6 +297,10 @@ class SplitOpTest(test.TestCase):
     with self.assertRaises(ValueError):
       array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2)
 
+    # split dim less than -(rank of input)
+    with self.assertRaises(ValueError):
+      array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)
+
     # num_split does not evenly divide the size in split_dim.
     with self.assertRaisesRegexp(ValueError, "should evenly divide"):
       array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0)
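As the new test above exercises, an axis below -rank(value) is now rejected at graph-construction time rather than being silently accepted. A small hedged sketch of what a caller sees (illustrative; the exact error text may differ):

    import tensorflow as tf

    try:
      # value has rank 2, so valid axes are [-2, 2); -3 is out of range.
      tf.split(value=[[0, 1], [2, 3]], num_or_size_splits=2, axis=-3)
    except ValueError as e:
      print("rejected:", e)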
@@ -1196,7 +1196,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
       evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
       split dimension must match that of the `value`.
     axis: A 0-D `int32` `Tensor`. The dimension along which to split.
-      Must be in the range `[0, rank(value))`. Defaults to 0.
+      Must be in the range `[-rank(value), rank(value))`. Defaults to 0.
     num: Optional, used to specify the number of outputs when it cannot be
       inferred from the shape of `size_splits`.
     name: A name for the operation (optional).