Implementing negative striding for python lists

ghstack-source-id: c2736c648c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33019
This commit is contained in:
Alexander Golynski 2020-02-11 09:36:59 -08:00
parent 9857d9b4cd
commit 989de7a0f8
8 changed files with 131 additions and 30 deletions

View File

@ -706,6 +706,11 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
auto strides = self.strides().vec();
// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
// INT64_MAX stands for default value.
if (start == INT64_MAX) {
start = 0;
}
if (start < 0) {
start += sizes[dim];
}

View File

@ -2488,7 +2488,7 @@
device_guard: False
supports_named_tensor: True
- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int start=9223372036854775807, int end=9223372036854775807, int step=1) -> Tensor(a)
variants: function, method
device_guard: False
supports_named_tensor: True

View File

@ -21,9 +21,9 @@ white_list = [
# We export some functions and classes for test_jit.py directly from libtorch.so,
# it's not important to have BC for them
('_TorchScriptTesting.*', datetime.date(9999, 1, 1)),
('split_with_sizes', datetime.date(2020, 2, 1)),
('linear_relu_dynamic_fp16', datetime.date(2020, 2, 5)),
('aten::join', datetime.date(2020, 2, 10)),
('aten::slice', datetime.date(2020, 3, 1)),
]

View File

@ -5908,6 +5908,9 @@ a")
numel = torch.tensor(size).prod().item()
return torch.arange(numel).view(size)
def consec_list(size):
return list(range(size))
def check_indexing(indexing, tensor):
template = dedent("""
def func(x):
@ -5929,6 +5932,17 @@ a")
self._check_code(template.format(indexing), "func", [tensor, value1, value2])
# Torchscript assumes type Tensor by default, so we need this explicit
# declaration.
def check_indexing_list_int(indexing, list):
template = dedent("""
def func(x):
# type: (List[int]) -> Any
return x{}
""")
self._check_code(template.format(indexing), "func", [list])
# basic slices
check_indexing('[0]', consec((3, 3)))
check_indexing('[1]', consec((3, 3), 10))
@ -5981,6 +5995,49 @@ a")
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
# positive striding
check_indexing_list_int('[0]', consec_list(6))
check_indexing_list_int('[1]', consec_list(7))
check_indexing_list_int('[2]', consec_list(8))
check_indexing_list_int('[2]', consec_list(9))
check_indexing_list_int('[-1]', consec_list(10))
check_indexing_list_int('[0:2]', consec_list(11))
check_indexing_list_int('[1:-1]', consec_list(12))
check_indexing_list_int('[-3:-1]', consec_list(13))
check_indexing_list_int('[1:]', consec_list(15))
check_indexing_list_int('[:1]', consec_list(16))
check_indexing_list_int('[:]', consec_list(17))
check_indexing_list_int('[::]', consec_list(0))
check_indexing_list_int('[1000::]', consec_list(0))
check_indexing_list_int('[:1000:]', consec_list(0))
# negative striding
check_indexing_list_int('[::-1]', consec_list(7))
check_indexing_list_int('[:3:-1]', consec_list(7))
check_indexing_list_int('[3::-1]', consec_list(7))
check_indexing_list_int('[1000::-1]', consec_list(7))
check_indexing_list_int('[3:0:-1]', consec_list(7))
check_indexing_list_int('[3:-1000:-1]', consec_list(7))
check_indexing_list_int('[0:0:-1]', consec_list(7))
check_indexing_list_int('[0:-1000:-1]', consec_list(7))
# only step is specified
check_indexing_list_int('[::-1]', consec_list(0))
check_indexing_list_int('[::-1]', consec_list(7))
check_indexing_list_int('[::-2]', consec_list(7))
check_indexing_list_int('[::2]', consec_list(7))
check_indexing_list_int('[::42]', consec_list(7))
check_indexing_list_int('[::-42]', consec_list(7))
check_indexing_list_int('[::42]', consec_list(0))
check_indexing_list_int('[::-42]', consec_list(0))
check_indexing_list_int('[::9223372036854775807]', consec_list(42))
check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
check_indexing_list_int('[::0]', consec_list(42))
def test_index_ellipses(self):
vals = [":", 1, None]
for _ in range(100):

View File

@ -773,7 +773,7 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh()
- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- name: slice.Tensor(Tensor(a) self, int dim=0, int start=9223372036854775807, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)

View File

@ -1975,35 +1975,75 @@ int listMulIntRight(Stack& stack) {
return 0;
}
// Stolen (with appropriate modifications) from cpython repo
// Objects/sliceobject.c with comment:
// this is harder to get right than you might think
//
// This adjusts indexes according to python list semantics and returns number
// of elements in the resulting list.
static int64_t PySlice_AdjustIndices(
int64_t length, int64_t* start, int64_t* stop, int64_t step) {
TORCH_CHECK(step != 0, "List slice should have non-zero step")
TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")
// Comes from PySlice_Unpack.
if (*start == INT64_MAX) {
*start = (step < 0) ? INT64_MAX : 0;
}
if (*stop == INT64_MAX) {
*stop = (step < 0) ? INT64_MIN : INT64_MAX;
}
// Comes from PySlice_AdjustIndices.
if (*start < 0) {
*start += length;
if (*start < 0) {
*start = (step < 0) ? -1 : 0;
}
} else if (*start >= length) {
*start = (step < 0) ? length - 1 : length;
}
if (*stop < 0) {
*stop += length;
if (*stop < 0) {
*stop = (step < 0) ? -1 : 0;
}
} else if (*stop >= length) {
*stop = (step < 0) ? length - 1 : length;
}
if (step < 0) {
if (*stop < *start) {
return (*start - *stop - 1) / (-step) + 1;
}
} else {
if (*start < *stop) {
return (*stop - *start - 1) / step + 1;
}
}
return 0;
}
template <typename T>
int listSlice(Stack& stack) {
int64_t step = pop(stack).to<int64_t>();
int64_t end = pop(stack).to<int64_t>();
int64_t stop = pop(stack).to<int64_t>();
int64_t start = pop(stack).to<int64_t>();
c10::List<T> list = pop(stack).to<c10::List<T>>();
c10::List<T> list = pop(stack).to<c10::List<T>>();
const int64_t list_size = list.size();
// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));
c10::List<T> sliced_list = make_result_list<T>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return 0;
}
const int64_t num_values =
PySlice_AdjustIndices(list_size, &start, &stop, step);
sliced_list.reserve(num_values);
sliced_list.reserve(normalized_end - normalized_start);
for (auto i = normalized_start; i < normalized_end;) {
int i = start;
for (int j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
push(stack, std::move(sliced_list));
return 0;
}
@ -2580,9 +2620,11 @@ RegisterOperators reg2({
"[] b) -> " decl_type "[]", \
listInplaceAdd<c_type::value_type>, \
aliasAnalysisFromSchema()), \
/* INT64_MAX=9223372036854775807 represents unspeficied value, aka Py_None */ \
Operator( \
"aten::slice(" decl_type \
"[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
"[] l, int start=9223372036854775807, \
int end=9223372036854775807, int step=1) -> " decl_type \
"[]", \
listSlice<c_type::value_type>, \
aliasAnalysisFromSchema()), \

View File

@ -2937,17 +2937,14 @@ struct to_ir {
// aten::slice, we should separate it from this function.
if (dim) {
AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
args.emplace_back(dim);
} else {
AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
}
args.emplace_back(loc, "begin", emitExpr(Expr(slice.startOr(0))));
args.emplace_back(loc, "start", emitExpr(Expr(slice.startOr(INT64_MAX))));
const auto has_end = slice.end().present();
if (has_end) {
args.emplace_back(loc, "end", emitExpr(Expr(slice.end().get())));
}
args.emplace_back(loc, "end", emitExpr(Expr(slice.endOr(INT64_MAX))));
if (input->type()->cast<TupleType>()) {
auto has_step = slice.step().present();
if (has_step) {

View File

@ -893,15 +893,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
@ -915,7 +915,7 @@ struct SliceExpr : public Expr {
}
private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};