mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[caffe2] SliceOp axes indexing fixes. (#45432)
Summary: Fixes https://github.com/pytorch/pytorch/issues/45431 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45432 Reviewed By: albanD Differential Revision: D24132547 Pulled By: dzhulgakov fbshipit-source-id: d67f7a92d806fb8ac8fc8f522b251d3a8fb83037
This commit is contained in:
parent
3fbddb92b1
commit
c1af91a13a
|
|
@ -17,7 +17,7 @@ Produces a slice of the input tensor.
|
||||||
|
|
||||||
- Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.
|
- Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.
|
||||||
|
|
||||||
- If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).
|
- If a negative value is passed for any of the start or end indices, it represents |value| - 1 elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).
|
||||||
|
|
||||||
Github Links:
|
Github Links:
|
||||||
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc
|
||||||
|
|
@ -67,11 +67,11 @@ Y:
|
||||||
.Input(
|
.Input(
|
||||||
1,
|
1,
|
||||||
"starts",
|
"starts",
|
||||||
"(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data")
|
"(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data (dimensions following the sliced one might be omitted)")
|
||||||
.Input(
|
.Input(
|
||||||
2,
|
2,
|
||||||
"ends",
|
"ends",
|
||||||
"(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data")
|
"(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data (dimensions following the sliced one might be omitted)")
|
||||||
.Arg("starts", "(*Tuple(int)*): list of starting indices")
|
.Arg("starts", "(*Tuple(int)*): list of starting indices")
|
||||||
.Arg("ends", "(*Tuple(int)*): list of ending indices")
|
.Arg("ends", "(*Tuple(int)*): list of ending indices")
|
||||||
.TensorInferenceFunction([](const OperatorDef& def,
|
.TensorInferenceFunction([](const OperatorDef& def,
|
||||||
|
|
@ -90,9 +90,10 @@ Y:
|
||||||
|
|
||||||
for (int i = 0; i < data.dims_size(); ++i) {
|
for (int i = 0; i < data.dims_size(); ++i) {
|
||||||
if (i >= starts.size()) {
|
if (i >= starts.size()) {
|
||||||
|
dst_sizes[i] = data.dims(i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (data.dims_size() > 0) {
|
if (data.dims(i) > 0) {
|
||||||
auto start = starts[i];
|
auto start = starts[i];
|
||||||
auto end = ends[i];
|
auto end = ends[i];
|
||||||
if (start < 0) {
|
if (start < 0) {
|
||||||
|
|
|
||||||
|
|
@ -74,22 +74,23 @@ bool SliceImplGpu(
|
||||||
if (i >= starts.numel()) {
|
if (i >= starts.numel()) {
|
||||||
starts_idx[i] = 0;
|
starts_idx[i] = 0;
|
||||||
ends_idx[i] = data.size(i);
|
ends_idx[i] = data.size(i);
|
||||||
|
dst_sizes[i] = data.size(i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (data.size(i) > 0) {
|
if (data.size(i) > 0) {
|
||||||
auto start = starts_data[i];
|
auto start = starts_data[i];
|
||||||
auto end = ends_data[i];
|
auto end = ends_data[i];
|
||||||
if (start < 0) {
|
if (start < 0) {
|
||||||
start = data.sizes()[i] + 1 + start;
|
start = data.size(i) + 1 + start;
|
||||||
}
|
}
|
||||||
if (end < 0) {
|
if (end < 0) {
|
||||||
end = data.sizes()[i] + 1 + end;
|
end = data.size(i) + 1 + end;
|
||||||
}
|
}
|
||||||
if (start > data.sizes()[i]) {
|
if (start > data.size(i)) {
|
||||||
start = data.sizes()[i];
|
start = data.size(i);
|
||||||
}
|
}
|
||||||
if (end > data.sizes()[i]) {
|
if (end > data.size(i)) {
|
||||||
end = data.sizes()[i];
|
end = data.size(i);
|
||||||
}
|
}
|
||||||
CAFFE_ENFORCE_GE(start, 0);
|
CAFFE_ENFORCE_GE(start, 0);
|
||||||
CAFFE_ENFORCE_GE(end, 0);
|
CAFFE_ENFORCE_GE(end, 0);
|
||||||
|
|
@ -115,7 +116,7 @@ bool SliceImplGpu(
|
||||||
// for now only supports slicing in 1 dimension
|
// for now only supports slicing in 1 dimension
|
||||||
int dim = -1;
|
int dim = -1;
|
||||||
for (int i = 0; i < data.dim(); ++i) {
|
for (int i = 0; i < data.dim(); ++i) {
|
||||||
if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) {
|
if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) {
|
||||||
CAFFE_ENFORCE_EQ(
|
CAFFE_ENFORCE_EQ(
|
||||||
dim, -1, "Currently only possible to slice in 1 dimension.");
|
dim, -1, "Currently only possible to slice in 1 dimension.");
|
||||||
dim = i;
|
dim = i;
|
||||||
|
|
@ -154,7 +155,7 @@ bool SliceImplGpu(
|
||||||
size_t src_nbytes = data.nbytes();
|
size_t src_nbytes = data.nbytes();
|
||||||
size_t dst_nbytes = output->nbytes();
|
size_t dst_nbytes = output->nbytes();
|
||||||
|
|
||||||
size_t src_block_size = unit * data.sizes()[dim];
|
size_t src_block_size = unit * data.size(dim);
|
||||||
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
||||||
size_t src_offset = unit * starts_idx[dim];
|
size_t src_offset = unit * starts_idx[dim];
|
||||||
|
|
||||||
|
|
@ -187,7 +188,7 @@ bool SliceImplGpu(
|
||||||
size_t dst_nbytes = gdata->nbytes();
|
size_t dst_nbytes = gdata->nbytes();
|
||||||
|
|
||||||
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
||||||
size_t dst_block_size = unit * data.sizes()[dim];
|
size_t dst_block_size = unit * data.size(dim);
|
||||||
size_t dst_offset = unit * starts_idx[dim];
|
size_t dst_offset = unit * starts_idx[dim];
|
||||||
|
|
||||||
if (num_blocks == 0 || dst_block_size == 0) {
|
if (num_blocks == 0 || dst_block_size == 0) {
|
||||||
|
|
|
||||||
|
|
@ -33,23 +33,24 @@ bool SliceImpl(
|
||||||
for (int i = 0; i < data.dim(); ++i) {
|
for (int i = 0; i < data.dim(); ++i) {
|
||||||
if (i >= starts.numel()) {
|
if (i >= starts.numel()) {
|
||||||
starts_idx[i] = 0;
|
starts_idx[i] = 0;
|
||||||
ends_idx[i] = data.sizes()[i];
|
ends_idx[i] = data.size(i);
|
||||||
|
dst_sizes[i] = data.size(i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (data.sizes()[i] > 0) {
|
if (data.size(i) > 0) {
|
||||||
auto start = starts_data[i];
|
auto start = starts_data[i];
|
||||||
auto end = ends_data[i];
|
auto end = ends_data[i];
|
||||||
if (start < 0) {
|
if (start < 0) {
|
||||||
start = data.sizes()[i] + 1 + start;
|
start = data.size(i) + 1 + start;
|
||||||
}
|
}
|
||||||
if (end < 0) {
|
if (end < 0) {
|
||||||
end = data.sizes()[i] + 1 + end;
|
end = data.size(i) + 1 + end;
|
||||||
}
|
}
|
||||||
if (start > data.sizes()[i]) {
|
if (start > data.size(i)) {
|
||||||
start = data.sizes()[i];
|
start = data.size(i);
|
||||||
}
|
}
|
||||||
if (end > data.sizes()[i]) {
|
if (end > data.size(i)) {
|
||||||
end = data.sizes()[i];
|
end = data.size(i);
|
||||||
}
|
}
|
||||||
CAFFE_ENFORCE_GE(start, 0);
|
CAFFE_ENFORCE_GE(start, 0);
|
||||||
CAFFE_ENFORCE_GE(end, 0);
|
CAFFE_ENFORCE_GE(end, 0);
|
||||||
|
|
@ -78,7 +79,7 @@ bool SliceImpl(
|
||||||
// for now only supports slicing in 1 dimension
|
// for now only supports slicing in 1 dimension
|
||||||
int dim = -1;
|
int dim = -1;
|
||||||
for (int i = 0; i < data.dim(); ++i) {
|
for (int i = 0; i < data.dim(); ++i) {
|
||||||
if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) {
|
if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) {
|
||||||
CAFFE_ENFORCE_EQ(
|
CAFFE_ENFORCE_EQ(
|
||||||
dim, -1, "Currently only possible to slice in 1 dimension.");
|
dim, -1, "Currently only possible to slice in 1 dimension.");
|
||||||
dim = i;
|
dim = i;
|
||||||
|
|
@ -117,7 +118,7 @@ bool SliceImpl(
|
||||||
size_t src_nbytes = data.nbytes();
|
size_t src_nbytes = data.nbytes();
|
||||||
size_t dst_nbytes = output->nbytes();
|
size_t dst_nbytes = output->nbytes();
|
||||||
|
|
||||||
size_t src_block_size = unit * data.sizes()[dim];
|
size_t src_block_size = unit * data.size(dim);
|
||||||
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
||||||
size_t src_offset = unit * starts_idx[dim];
|
size_t src_offset = unit * starts_idx[dim];
|
||||||
|
|
||||||
|
|
@ -155,7 +156,7 @@ bool SliceImpl(
|
||||||
size_t dst_nbytes = gdata->nbytes();
|
size_t dst_nbytes = gdata->nbytes();
|
||||||
|
|
||||||
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
||||||
size_t dst_block_size = unit * data.sizes()[dim];
|
size_t dst_block_size = unit * data.size(dim);
|
||||||
size_t dst_offset = unit * starts_idx[dim];
|
size_t dst_offset = unit * starts_idx[dim];
|
||||||
|
|
||||||
if (num_blocks == 0 || dst_block_size == 0) {
|
if (num_blocks == 0 || dst_block_size == 0) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user