[caffe2] SliceOp axes indexing fixes. (#45432)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/45431

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45432

Reviewed By: albanD

Differential Revision: D24132547

Pulled By: dzhulgakov

fbshipit-source-id: d67f7a92d806fb8ac8fc8f522b251d3a8fb83037
This commit is contained in:
n-v-k 2020-10-06 13:14:58 -07:00 committed by Facebook GitHub Bot
parent 3fbddb92b1
commit c1af91a13a
3 changed files with 27 additions and 24 deletions

View File

@ -17,7 +17,7 @@ Produces a slice of the input tensor.
- Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments. - Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.
- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). - If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).
Github Links: Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc
@ -67,11 +67,11 @@ Y:
.Input( .Input(
1, 1,
"starts", "starts",
"(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data") "(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)")
.Input( .Input(
2, 2,
"ends", "ends",
"(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data") "(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)")
.Arg("starts", "(*Tuple(int)*): list of starting indices") .Arg("starts", "(*Tuple(int)*): list of starting indices")
.Arg("ends", "(*Tuple(int)*): list of ending indices") .Arg("ends", "(*Tuple(int)*): list of ending indices")
.TensorInferenceFunction([](const OperatorDef& def, .TensorInferenceFunction([](const OperatorDef& def,
@ -90,9 +90,10 @@ Y:
for (int i = 0; i < data.dims_size(); ++i) { for (int i = 0; i < data.dims_size(); ++i) {
if (i >= starts.size()) { if (i >= starts.size()) {
dst_sizes[i] = data.dims(i);
continue; continue;
} }
if (data.dims_size() > 0) { if (data.dims(i) > 0) {
auto start = starts[i]; auto start = starts[i];
auto end = ends[i]; auto end = ends[i];
if (start < 0) { if (start < 0) {

View File

@ -74,22 +74,23 @@ bool SliceImplGpu(
if (i >= starts.numel()) { if (i >= starts.numel()) {
starts_idx[i] = 0; starts_idx[i] = 0;
ends_idx[i] = data.size(i); ends_idx[i] = data.size(i);
dst_sizes[i] = data.size(i);
continue; continue;
} }
if (data.size(i) > 0) { if (data.size(i) > 0) {
auto start = starts_data[i]; auto start = starts_data[i];
auto end = ends_data[i]; auto end = ends_data[i];
if (start < 0) { if (start < 0) {
start = data.sizes()[i] + 1 + start; start = data.size(i) + 1 + start;
} }
if (end < 0) { if (end < 0) {
end = data.sizes()[i] + 1 + end; end = data.size(i) + 1 + end;
} }
if (start > data.sizes()[i]) { if (start > data.size(i)) {
start = data.sizes()[i]; start = data.size(i);
} }
if (end > data.sizes()[i]) { if (end > data.size(i)) {
end = data.sizes()[i]; end = data.size(i);
} }
CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(start, 0);
CAFFE_ENFORCE_GE(end, 0); CAFFE_ENFORCE_GE(end, 0);
@ -115,7 +116,7 @@ bool SliceImplGpu(
// for now only supports slicing in 1 dimension // for now only supports slicing in 1 dimension
int dim = -1; int dim = -1;
for (int i = 0; i < data.dim(); ++i) { for (int i = 0; i < data.dim(); ++i) {
if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) {
CAFFE_ENFORCE_EQ( CAFFE_ENFORCE_EQ(
dim, -1, "Currently only possible to slice in 1 dimension."); dim, -1, "Currently only possible to slice in 1 dimension.");
dim = i; dim = i;
@ -154,7 +155,7 @@ bool SliceImplGpu(
size_t src_nbytes = data.nbytes(); size_t src_nbytes = data.nbytes();
size_t dst_nbytes = output->nbytes(); size_t dst_nbytes = output->nbytes();
size_t src_block_size = unit * data.sizes()[dim]; size_t src_block_size = unit * data.size(dim);
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
size_t src_offset = unit * starts_idx[dim]; size_t src_offset = unit * starts_idx[dim];
@ -187,7 +188,7 @@ bool SliceImplGpu(
size_t dst_nbytes = gdata->nbytes(); size_t dst_nbytes = gdata->nbytes();
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
size_t dst_block_size = unit * data.sizes()[dim]; size_t dst_block_size = unit * data.size(dim);
size_t dst_offset = unit * starts_idx[dim]; size_t dst_offset = unit * starts_idx[dim];
if (num_blocks == 0 || dst_block_size == 0) { if (num_blocks == 0 || dst_block_size == 0) {

View File

@ -33,23 +33,24 @@ bool SliceImpl(
for (int i = 0; i < data.dim(); ++i) { for (int i = 0; i < data.dim(); ++i) {
if (i >= starts.numel()) { if (i >= starts.numel()) {
starts_idx[i] = 0; starts_idx[i] = 0;
ends_idx[i] = data.sizes()[i]; ends_idx[i] = data.size(i);
dst_sizes[i] = data.size(i);
continue; continue;
} }
if (data.sizes()[i] > 0) { if (data.size(i) > 0) {
auto start = starts_data[i]; auto start = starts_data[i];
auto end = ends_data[i]; auto end = ends_data[i];
if (start < 0) { if (start < 0) {
start = data.sizes()[i] + 1 + start; start = data.size(i) + 1 + start;
} }
if (end < 0) { if (end < 0) {
end = data.sizes()[i] + 1 + end; end = data.size(i) + 1 + end;
} }
if (start > data.sizes()[i]) { if (start > data.size(i)) {
start = data.sizes()[i]; start = data.size(i);
} }
if (end > data.sizes()[i]) { if (end > data.size(i)) {
end = data.sizes()[i]; end = data.size(i);
} }
CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(start, 0);
CAFFE_ENFORCE_GE(end, 0); CAFFE_ENFORCE_GE(end, 0);
@ -78,7 +79,7 @@ bool SliceImpl(
// for now only supports slicing in 1 dimension // for now only supports slicing in 1 dimension
int dim = -1; int dim = -1;
for (int i = 0; i < data.dim(); ++i) { for (int i = 0; i < data.dim(); ++i) {
if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) {
CAFFE_ENFORCE_EQ( CAFFE_ENFORCE_EQ(
dim, -1, "Currently only possible to slice in 1 dimension."); dim, -1, "Currently only possible to slice in 1 dimension.");
dim = i; dim = i;
@ -117,7 +118,7 @@ bool SliceImpl(
size_t src_nbytes = data.nbytes(); size_t src_nbytes = data.nbytes();
size_t dst_nbytes = output->nbytes(); size_t dst_nbytes = output->nbytes();
size_t src_block_size = unit * data.sizes()[dim]; size_t src_block_size = unit * data.size(dim);
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
size_t src_offset = unit * starts_idx[dim]; size_t src_offset = unit * starts_idx[dim];
@ -155,7 +156,7 @@ bool SliceImpl(
size_t dst_nbytes = gdata->nbytes(); size_t dst_nbytes = gdata->nbytes();
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
size_t dst_block_size = unit * data.sizes()[dim]; size_t dst_block_size = unit * data.size(dim);
size_t dst_offset = unit * starts_idx[dim]; size_t dst_offset = unit * starts_idx[dim];
if (num_blocks == 0 || dst_block_size == 0) { if (num_blocks == 0 || dst_block_size == 0) {