Implementing negative striding for python lists
ghstack-source-id: c2736c648c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33019
This commit is contained in:
parent
9857d9b4cd
commit
989de7a0f8
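At a user level, this change lets TorchScript list slicing accept negative steps with ordinary Python semantics. A minimal sketch of the new behaviour (assuming a build that includes this change; the function name `rev` is just for illustration):

from typing import List

import torch

@torch.jit.script
def rev(xs: List[int]) -> List[int]:
    # A negative step on a list slice now compiles and matches Python semantics.
    return xs[::-1]

print(rev([0, 1, 2, 3]))  # [3, 2, 1, 0]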
@@ -706,6 +706,11 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
   auto strides = self.strides().vec();
   // TODO: support negative strides
   TORCH_CHECK(step > 0, "slice step must be positive");
+
+  // INT64_MAX stands for default value.
+  if (start == INT64_MAX) {
+    start = 0;
+  }
   if (start < 0) {
     start += sizes[dim];
   }
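For contrast, a small illustrative sketch (not from the patch): aten::slice still rejects a negative step for tensors, as the TODO above notes, so reversing a tensor along a dimension still goes through torch.flip.

import torch

x = torch.arange(6)
try:
    x[::-1]                      # still rejected: negative steps are not supported for tensor indexing
except (RuntimeError, ValueError) as err:
    print(err)
print(torch.flip(x, dims=[0]))   # tensor([5, 4, 3, 2, 1, 0])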
@@ -2488,7 +2488,7 @@
   device_guard: False
   supports_named_tensor: True
 
-- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
+- func: slice.Tensor(Tensor(a) self, int dim=0, int start=9223372036854775807, int end=9223372036854775807, int step=1) -> Tensor(a)
   variants: function, method
   device_guard: False
   supports_named_tensor: True
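The default for start moves from 0 to the INT64_MAX sentinel ("not specified") because, once negative steps are in play, an omitted start means "from the end", which an explicit 0 cannot express. A plain-Python illustration (not from the patch):

xs = list(range(5))
assert xs[::-1] == [4, 3, 2, 1, 0]   # omitted start with step -1 begins at the last element
assert xs[0::-1] == [0]              # an explicit start of 0 behaves very differently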
@@ -21,9 +21,9 @@ white_list = [
     # We export some functions and classes for test_jit.py directly from libtorch.so,
     # it's not important to have BC for them
     ('_TorchScriptTesting.*', datetime.date(9999, 1, 1)),
     ('split_with_sizes', datetime.date(2020, 2, 1)),
     ('linear_relu_dynamic_fp16', datetime.date(2020, 2, 5)),
     ('aten::join', datetime.date(2020, 2, 10)),
+    ('aten::slice', datetime.date(2020, 3, 1)),
 ]
@@ -5908,6 +5908,9 @@ a")
             numel = torch.tensor(size).prod().item()
             return torch.arange(numel).view(size)
 
+        def consec_list(size):
+            return list(range(size))
+
         def check_indexing(indexing, tensor):
             template = dedent("""
                 def func(x):
@@ -5929,6 +5932,17 @@ a")
 
             self._check_code(template.format(indexing), "func", [tensor, value1, value2])
 
+        # Torchscript assumes type Tensor by default, so we need this explicit
+        # declaration.
+        def check_indexing_list_int(indexing, list):
+            template = dedent("""
+                def func(x):
+                    # type: (List[int]) -> Any
+                    return x{}
+            """)
+
+            self._check_code(template.format(indexing), "func", [list])
+
         # basic slices
         check_indexing('[0]', consec((3, 3)))
         check_indexing('[1]', consec((3, 3), 10))
@@ -5981,6 +5995,49 @@ a")
         check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
         check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
 
+        # positive striding
+        check_indexing_list_int('[0]', consec_list(6))
+        check_indexing_list_int('[1]', consec_list(7))
+        check_indexing_list_int('[2]', consec_list(8))
+        check_indexing_list_int('[2]', consec_list(9))
+        check_indexing_list_int('[-1]', consec_list(10))
+        check_indexing_list_int('[0:2]', consec_list(11))
+        check_indexing_list_int('[1:-1]', consec_list(12))
+        check_indexing_list_int('[-3:-1]', consec_list(13))
+        check_indexing_list_int('[1:]', consec_list(15))
+        check_indexing_list_int('[:1]', consec_list(16))
+        check_indexing_list_int('[:]', consec_list(17))
+        check_indexing_list_int('[::]', consec_list(0))
+        check_indexing_list_int('[1000::]', consec_list(0))
+        check_indexing_list_int('[:1000:]', consec_list(0))
+
+        # negative striding
+        check_indexing_list_int('[::-1]', consec_list(7))
+        check_indexing_list_int('[:3:-1]', consec_list(7))
+        check_indexing_list_int('[3::-1]', consec_list(7))
+        check_indexing_list_int('[1000::-1]', consec_list(7))
+        check_indexing_list_int('[3:0:-1]', consec_list(7))
+        check_indexing_list_int('[3:-1000:-1]', consec_list(7))
+        check_indexing_list_int('[0:0:-1]', consec_list(7))
+        check_indexing_list_int('[0:-1000:-1]', consec_list(7))
+
+        # only step is specified
+        check_indexing_list_int('[::-1]', consec_list(0))
+        check_indexing_list_int('[::-1]', consec_list(7))
+        check_indexing_list_int('[::-2]', consec_list(7))
+        check_indexing_list_int('[::2]', consec_list(7))
+        check_indexing_list_int('[::42]', consec_list(7))
+        check_indexing_list_int('[::-42]', consec_list(7))
+        check_indexing_list_int('[::42]', consec_list(0))
+        check_indexing_list_int('[::-42]', consec_list(0))
+        check_indexing_list_int('[::9223372036854775807]', consec_list(42))
+        check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
+        with self.assertRaisesRegex(RuntimeError, "out of bounds"):
+            check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
+        with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
+            check_indexing_list_int('[::0]', consec_list(42))
+
     def test_index_ellipses(self):
         vals = [":", 1, None]
         for _ in range(100):
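The expected results for these cases are ordinary Python list slicing semantics; a few reference values (illustrative, not part of the test file):

xs = list(range(7))
assert xs[::-1] == [6, 5, 4, 3, 2, 1, 0]
assert xs[3:0:-1] == [3, 2, 1]
assert xs[3:-1000:-1] == [3, 2, 1, 0]   # an out-of-range stop is clamped
assert xs[0:0:-1] == []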
@@ -773,7 +773,7 @@
 - name: sinh(Tensor self) -> Tensor
   self: grad * self.cosh()
 
-- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
+- name: slice.Tensor(Tensor(a) self, int dim=0, int start=9223372036854775807, int end=9223372036854775807, int step=1) -> Tensor(a)
   self: slice_backward(grad, self.sizes(), dim, start, end, step)
 
 - name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
@@ -1975,35 +1975,75 @@ int listMulIntRight(Stack& stack) {
   return 0;
 }
 
+// Stolen (with appropriate modifications) from cpython repo
+// Objects/sliceobject.c with comment:
+// this is harder to get right than you might think
+//
+// This adjusts indexes according to python list semantics and returns number
+// of elements in the resulting list.
+static int64_t PySlice_AdjustIndices(
+    int64_t length, int64_t* start, int64_t* stop, int64_t step) {
+  TORCH_CHECK(step != 0, "List slice should have non-zero step")
+  TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")
+
+  // Comes from PySlice_Unpack.
+  if (*start == INT64_MAX) {
+    *start = (step < 0) ? INT64_MAX : 0;
+  }
+  if (*stop == INT64_MAX) {
+    *stop = (step < 0) ? INT64_MIN : INT64_MAX;
+  }
+
+  // Comes from PySlice_AdjustIndices.
+  if (*start < 0) {
+    *start += length;
+    if (*start < 0) {
+      *start = (step < 0) ? -1 : 0;
+    }
+  } else if (*start >= length) {
+    *start = (step < 0) ? length - 1 : length;
+  }
+
+  if (*stop < 0) {
+    *stop += length;
+    if (*stop < 0) {
+      *stop = (step < 0) ? -1 : 0;
+    }
+  } else if (*stop >= length) {
+    *stop = (step < 0) ? length - 1 : length;
+  }
+
+  if (step < 0) {
+    if (*stop < *start) {
+      return (*start - *stop - 1) / (-step) + 1;
+    }
+  } else {
+    if (*start < *stop) {
+      return (*stop - *start - 1) / step + 1;
+    }
+  }
+  return 0;
+}
+
 template <typename T>
 int listSlice(Stack& stack) {
   int64_t step = pop(stack).to<int64_t>();
-  int64_t end = pop(stack).to<int64_t>();
+  int64_t stop = pop(stack).to<int64_t>();
   int64_t start = pop(stack).to<int64_t>();
   c10::List<T> list = pop(stack).to<c10::List<T>>();
 
   const int64_t list_size = list.size();
 
-  // clamp start and end to the bounds of the list
-  const auto normalized_start =
-      std::max((int64_t)0, normalizeIndex(start, list_size));
-  const auto normalized_end =
-      std::min(list_size, normalizeIndex(end, list_size));
-
   c10::List<T> sliced_list = make_result_list<T>(list.elementType());
-  if (normalized_end <= normalized_start) {
-    // early exit if the slice is trivially empty
-    push(stack, std::move(sliced_list));
-    return 0;
-  }
-
-  sliced_list.reserve(normalized_end - normalized_start);
+  const int64_t num_values =
+      PySlice_AdjustIndices(list_size, &start, &stop, step);
+  sliced_list.reserve(num_values);
 
-  for (auto i = normalized_start; i < normalized_end; ++i) {
+  int i = start;
+  for (int j = 0; j < num_values; ++j) {
     sliced_list.push_back(list.get(i));
+    i += step;
   }
 
   push(stack, std::move(sliced_list));
   return 0;
 }
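The helper mirrors CPython's PySlice_AdjustIndices; the same clamping and element count are exposed in Python through slice.indices(length), which makes a handy cross-check (illustrative snippet, not part of the patch):

length = 7
s = slice(None, None, -1)                  # the "[::-1]" case
start, stop, step = s.indices(length)      # -> (6, -1, -1)
num_values = (start - stop - 1) // (-step) + 1 if stop < start else 0
assert num_values == 7
assert list(range(length))[s] == [6, 5, 4, 3, 2, 1, 0]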
@@ -2580,9 +2620,11 @@ RegisterOperators reg2({
           "[] b) -> " decl_type "[]", \
           listInplaceAdd<c_type::value_type>, \
           aliasAnalysisFromSchema()), \
+      /* INT64_MAX=9223372036854775807 represents unspecified value, aka Py_None */ \
       Operator( \
           "aten::slice(" decl_type \
-          "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
+          "[] l, int start=9223372036854775807, \
+           int end=9223372036854775807, int step=1) -> " decl_type \
           "[]", \
           listSlice<c_type::value_type>, \
           aliasAnalysisFromSchema()), \
@@ -2937,17 +2937,14 @@ struct to_ir {
     // aten::slice, we should separate it from this function.
     if (dim) {
       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
 
       args.emplace_back(dim);
     } else {
       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
     }
 
-    args.emplace_back(loc, "begin", emitExpr(Expr(slice.startOr(0))));
-    const auto has_end = slice.end().present();
-    if (has_end) {
-      args.emplace_back(loc, "end", emitExpr(Expr(slice.end().get())));
-    }
+    args.emplace_back(loc, "start", emitExpr(Expr(slice.startOr(INT64_MAX))));
+    args.emplace_back(loc, "end", emitExpr(Expr(slice.endOr(INT64_MAX))));
     if (input->type()->cast<TupleType>()) {
       auto has_step = slice.step().present();
       if (has_step) {
@@ -893,15 +893,15 @@ struct SliceExpr : public Expr {
   Maybe<Expr> step() const {
     return Maybe<Expr>(subtree(2));
   }
-  Expr startOr(int alternative) const {
+  Expr startOr(int64_t alternative) const {
     const auto startOption = start();
     return startOption.present() ? startOption.get() : createInt(alternative);
   }
-  Expr endOr(int alternative) const {
+  Expr endOr(int64_t alternative) const {
     const auto endOption = end();
     return endOption.present() ? endOption.get() : createInt(alternative);
   }
-  Expr stepOr(int alternative) const {
+  Expr stepOr(int64_t alternative) const {
     const auto stepOption = step();
     return stepOption.present() ? stepOption.get() : createInt(alternative);
   }
@@ -915,7 +915,7 @@ struct SliceExpr : public Expr {
   }
 
  private:
-  Expr createInt(int value) const {
+  Expr createInt(int64_t value) const {
     return Expr(Const::create(range(), c10::to_string(value)));
   }
 };
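These helpers widen from int to int64_t because the "unspecified" sentinel used throughout this change, 9223372036854775807, is INT64_MAX and does not fit in a 32-bit int; a quick check (illustrative only):

INT64_MAX = 9223372036854775807
assert INT64_MAX == 2**63 - 1
assert INT64_MAX > 2**31 - 1   # overflows a 32-bit int, hence int64_t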