diff --git a/caffe2/operators/slice_op.cc b/caffe2/operators/slice_op.cc index 7acf854ba9d..f9fd3930326 100644 --- a/caffe2/operators/slice_op.cc +++ b/caffe2/operators/slice_op.cc @@ -17,7 +17,7 @@ Produces a slice of the input tensor. - Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments. -- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). +- If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element). Github Links: - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc @@ -67,11 +67,11 @@ Y: .Input( 1, "starts", - "(*Tensor``*): 1D tensor of start-indices for each dimension of data") + "(*Tensor``*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Input( 2, "ends", - "(*Tensor``*): 1D tensor of end-indices for each dimension of data") + "(*Tensor``*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)") .Arg("starts", "(*Tuple(int)*): list of starting indices") .Arg("ends", "(*Tuple(int)*): list of ending indices") .TensorInferenceFunction([](const OperatorDef& def, @@ -90,9 +90,10 @@ Y: for (int i = 0; i < data.dims_size(); ++i) { if (i >= starts.size()) { + dst_sizes[i] = data.dims(i); continue; } - if (data.dims_size() > 0) { + if (data.dims(i) > 0) { auto start = starts[i]; auto end = ends[i]; if (start < 0) { diff --git a/caffe2/operators/slice_op.cu b/caffe2/operators/slice_op.cu index 7a843fee3a5..184385310c9 100644 --- a/caffe2/operators/slice_op.cu +++ b/caffe2/operators/slice_op.cu @@ -74,22 +74,23 @@ bool SliceImplGpu( if (i >= starts.numel()) { starts_idx[i] = 0; ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -115,7 +116,7 @@ bool SliceImplGpu( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; @@ -154,7 +155,7 @@ bool SliceImplGpu( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -187,7 +188,7 @@ bool SliceImplGpu( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) { diff --git a/caffe2/operators/slice_op.h b/caffe2/operators/slice_op.h index 8d1990e54c3..9706472315b 100644 --- a/caffe2/operators/slice_op.h +++ b/caffe2/operators/slice_op.h @@ -33,23 +33,24 @@ bool SliceImpl( for (int i = 0; i < data.dim(); ++i) { if (i >= starts.numel()) { starts_idx[i] = 0; - ends_idx[i] = data.sizes()[i]; + ends_idx[i] = data.size(i); + dst_sizes[i] = data.size(i); continue; } - if (data.sizes()[i] > 0) { + if (data.size(i) > 0) { auto start = starts_data[i]; auto end = ends_data[i]; if (start < 0) { - start = data.sizes()[i] + 1 + start; + start = data.size(i) + 1 + start; } if (end < 0) { - end = data.sizes()[i] + 1 + end; + end = data.size(i) + 1 + end; } - if (start > data.sizes()[i]) { - start = data.sizes()[i]; + if (start > data.size(i)) { + start = data.size(i); } - if (end > data.sizes()[i]) { - end = data.sizes()[i]; + if (end > data.size(i)) { + end = data.size(i); } CAFFE_ENFORCE_GE(start, 0); CAFFE_ENFORCE_GE(end, 0); @@ -78,7 +79,7 @@ bool SliceImpl( // for now only supports slicing in 1 dimension int dim = -1; for (int i = 0; i < data.dim(); ++i) { - if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) { + if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) { CAFFE_ENFORCE_EQ( dim, -1, "Currently only possible to slice in 1 dimension."); dim = i; @@ -117,7 +118,7 @@ bool SliceImpl( size_t src_nbytes = data.nbytes(); size_t dst_nbytes = output->nbytes(); - size_t src_block_size = unit * data.sizes()[dim]; + size_t src_block_size = unit * data.size(dim); size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]); size_t src_offset = unit * starts_idx[dim]; @@ -155,7 +156,7 @@ bool SliceImpl( size_t dst_nbytes = gdata->nbytes(); size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]); - size_t dst_block_size = unit * data.sizes()[dim]; + size_t dst_block_size = unit * data.size(dim); size_t dst_offset = unit * starts_idx[dim]; if (num_blocks == 0 || dst_block_size == 0) {