[XLA] Allow 0-sized slices in DynamicSlice and DynamicUpdateSlice; add tests.

PiperOrigin-RevId: 158015870
This commit is contained in:
Peter Hawkins 2017-06-05 08:12:29 -07:00 committed by TensorFlower Gardener
parent 48a4853ebe
commit 009789f742
2 changed files with 20 additions and 3 deletions

View File

@ -1103,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
const int64 input_dim_size = operand_shape.dimensions(dim);
const int64 slice_dim_size = slice_sizes[dim];
if (slice_dim_size <= 0) {
if (slice_dim_size < 0) {
return InvalidArgument("negative size index to dynamic slice: %lld",
slice_dim_size);
}
@ -1175,9 +1175,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
const int64 input_dim_size = operand_shape.dimensions(dim);
const int64 update_dim_size = update_shape.dimensions(dim);
if (update_dim_size <= 0) {
if (update_dim_size < 0) {
return InvalidArgument(
"size index %lld to dynamic update slice must be > 0",
"size index %lld to dynamic update slice must be >= 0",
update_dim_size);
}
if (update_dim_size > input_dim_size) {

View File

@ -58,6 +58,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4},
{6.0, 7.0, 0.0, 1.0});
// Zero element slice.
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {0}, {});
}
template <typename IndexT>
@ -75,6 +77,12 @@ class DynamicSliceTest : public ClientLibraryTestBase {
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
{1, 1}, {3, 3},
{{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}});
// Zero element slice: 2x0.
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
{0, 0}, {2, 0}, {{}, {}});
// Zero element slice: 0x2.
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
{0, 0}, {0, 2}, Array2D<float>(0, 2));
}
template <typename IndexT>
@ -200,6 +208,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
{8.0, 9.0, 10.0}, {6},
{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
// Zero-sized update.
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
{}, {2},
{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
// clang-format on
}
@ -226,6 +238,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
{{10.0f, 11.0f}}, {2, 2},
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
// Zero-sized update.
RunR2<IndexT>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
{{}}, {2, 1},
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
// clang-format on
}