mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] Allow 0-sized slices in DynamicSlice and DynamicUpdateSlice; add tests.
PiperOrigin-RevId: 158015870
This commit is contained in:
parent
48a4853ebe
commit
009789f742
|
|
@ -1103,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||
for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
|
||||
const int64 input_dim_size = operand_shape.dimensions(dim);
|
||||
const int64 slice_dim_size = slice_sizes[dim];
|
||||
if (slice_dim_size <= 0) {
|
||||
if (slice_dim_size < 0) {
|
||||
return InvalidArgument("negative size index to dynamic slice: %lld",
|
||||
slice_dim_size);
|
||||
}
|
||||
|
|
@ -1175,9 +1175,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||
for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
|
||||
const int64 input_dim_size = operand_shape.dimensions(dim);
|
||||
const int64 update_dim_size = update_shape.dimensions(dim);
|
||||
if (update_dim_size <= 0) {
|
||||
if (update_dim_size < 0) {
|
||||
return InvalidArgument(
|
||||
"size index %lld to dynamic update slice must be > 0",
|
||||
"size index %lld to dynamic update slice must be >= 0",
|
||||
update_dim_size);
|
||||
}
|
||||
if (update_dim_size > input_dim_size) {
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
|||
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
|
||||
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4},
|
||||
{6.0, 7.0, 0.0, 1.0});
|
||||
// Zero element slice.
|
||||
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {0}, {});
|
||||
}
|
||||
|
||||
template <typename IndexT>
|
||||
|
|
@ -75,6 +77,12 @@ class DynamicSliceTest : public ClientLibraryTestBase {
|
|||
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
|
||||
{1, 1}, {3, 3},
|
||||
{{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}});
|
||||
// Zero element slice: 2x0.
|
||||
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
|
||||
{0, 0}, {2, 0}, {{}, {}});
|
||||
// Zero element slice: 0x2.
|
||||
RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
|
||||
{0, 0}, {0, 2}, Array2D<float>(0, 2));
|
||||
}
|
||||
|
||||
template <typename IndexT>
|
||||
|
|
@ -200,6 +208,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
|
|||
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
|
||||
{8.0, 9.0, 10.0}, {6},
|
||||
{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
|
||||
// Zero-sized update.
|
||||
RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
|
||||
{}, {2},
|
||||
{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
@ -226,6 +238,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
|
|||
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
|
||||
{{10.0f, 11.0f}}, {2, 2},
|
||||
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
|
||||
// Zero-sized update.
|
||||
RunR2<IndexT>(
|
||||
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
|
||||
{{}}, {2, 1},
|
||||
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user