Fix indexing for overrides. (#49324)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/46277
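To illustrate the user-visible effect, here is a minimal sketch mirroring the regression tests added in this PR: with this change applied, an object that implements `__torch_function__` is honored when it appears as an index (or as the assigned value) in `Tensor.__getitem__` / `Tensor.__setitem__`, rather than failing with an indexing error as described in the linked issue.

```python
import torch

class A:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        # Intercepts Tensor.__getitem__ / Tensor.__setitem__ calls involving A.
        return -1

t = torch.tensor([5])
assert t[A()] == -1   # now dispatches to A.__torch_function__
t[A()] = 1            # __setitem__ dispatches too; the override's result is discarded,
assert t.item() == 5  # so t is left unchanged
```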

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49324

Reviewed By: mruberry

Differential Revision: D25959334

Pulled By: ezyang

fbshipit-source-id: bac48b8ffee89d10aa04c004de2b53b4e54a96c2
Author: Hameer Abbasi
Date: 2021-01-20 11:23:04 -08:00
Committed by: Facebook GitHub Bot
Parent: 16faabe7f0
Commit: cf1882adeb

4 changed files with 116 additions and 10 deletions

test/test_overrides.py

@@ -878,5 +878,72 @@ class TestWrapTorchFunction(TestCase):
         self.assertEqual(f(A()), -1)
 
+class TestIndexing(TestCase):
+    """ Regression tests for gh-46277 """
+
+    def test_getitem(self):
+        class A:
+            @classmethod
+            def __torch_function__(cls, func, types, args, kwargs=None):
+                return -1
+
+        t = torch.tensor([5])
+        self.assertEqual(t[A()], -1)
+        self.assertEqual(t, torch.tensor([5]))
+
+    def test_getitem_subclass(self):
+        class A(torch.Tensor):
+            @classmethod
+            def __torch_function__(cls, func, types, args, kwargs=None):
+                return -1
+
+        t = torch.tensor([5])
+        self.assertEqual(t[A()], -1)
+        self.assertEqual(t[5, A()], -1)
+        self.assertEqual(t, torch.tensor([5]))
+
+    def test_setitem(self):
+        triggered = set()
+
+        class A:
+            @classmethod
+            def __torch_function__(cls, func, types, args, kwargs=None):
+                triggered.add(func)
+                return -1
+
+        t = torch.tensor([5])
+        t[A()] = 1
+        t[5, A()] = 1
+        self.assertIn(Tensor.__setitem__, triggered)
+        self.assertEqual(t, torch.tensor([5]))
+
+    def test_setitem_val(self):
+        triggered = set()
+
+        class A:
+            @classmethod
+            def __torch_function__(cls, func, types, args, kwargs=None):
+                triggered.add(func)
+                return -1
+
+        t = torch.tensor([5])
+        t[0] = A()
+        self.assertIn(Tensor.__setitem__, triggered)
+        self.assertEqual(t, torch.tensor([5]))
+
+    def test_setitem_subclass(self):
+        triggered = set()
+
+        class A(torch.Tensor):
+            @classmethod
+            def __torch_function__(cls, func, types, args, kwargs=None):
+                triggered.add(func)
+                return -1
+
+        t = torch.tensor([5])
+        t[A()] = 1
+        t[5, A()] = 1
+        self.assertIn(Tensor.__setitem__, triggered)
+        self.assertEqual(t, torch.tensor([5]))
+
 if __name__ == '__main__':
     run_tests()

torch/csrc/autograd/python_variable_indexing.cpp

@@ -55,10 +55,12 @@ Py_ssize_t THPVariable_length(PyObject* self) {
 
 static inline int64_t count_specified_dimensions(PyObject* index) {
   // Count the number of indexed dimensions (everything but ellipsis and None)
+  // -1 is a sentinel for __torch_function__
   int64_t count = 0;
   auto size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
   for (Py_ssize_t i = 0; i < size; i++) {
     PyObject* obj = PyTuple_GET_ITEM(index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
+    if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj)) return -1;
     if (THPVariable_Check(obj)) {
       auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
       const auto& var_scalar_type = var.scalar_type();
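The effect of the sentinel above: as soon as any index element other than a plain `torch.Tensor` carries a `__torch_function__` override, `count_specified_dimensions` reports -1 and the indexing entry points below fall back to `handle_torch_function_indexing`. A small illustration, with this change applied (the `MyIndex` class is made up for this sketch):

```python
import torch

class MyIndex:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        return "intercepted"

t = torch.arange(6).reshape(2, 3)
# The tuple index contains a __torch_function__ object, so native slicing
# is skipped and the override is called instead.
assert t[0, MyIndex()] == "intercepted"
```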
@@ -135,10 +137,10 @@ static inline Variable applySlicing(
     variable_list& outIndices,
     bool is_tracing,
     const at::Device& self_device,
-    const IntArrayRef& self_sizes) {
+    const IntArrayRef& self_sizes,
+    int64_t specified_dims) {
   int64_t size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
   int64_t dim = 0;
-  int64_t specified_dims = count_specified_dimensions(index);
 
   if (specified_dims > (int64_t)self_sizes.size()) {
     throw IndexError("too many indices for tensor of dimension %d", (int)(self_sizes.size()));
@@ -267,9 +269,8 @@ static inline THPObjectPtr wrapTuple(PyObject* index) {
 // indexing is needed, it calls C++ `at::indexing::dispatch_index`.
 PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
   HANDLE_TH_ERRORS
-  if (check_has_torch_function(self)) {
-    py::tuple args_ = py::make_tuple(py::handle(index));
-    return handle_torch_function(self, "__getitem__", args_.ptr());
+  if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+    return handle_torch_function_indexing(self, index);
   }
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
@@ -311,8 +312,12 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
   THPObjectPtr holder = wrapTuple(index);
 
   variable_list variableIndices;
+  int64_t specified_dims = count_specified_dimensions(holder.get());
+  if (specified_dims == -1) {
+    return handle_torch_function_indexing(self, index);
+  }
   Variable sliced = applySlicing(
-      self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_.sizes());
+      self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_.sizes(), specified_dims);
   if (variableIndices.empty()) {
     if (sliced.is_same(self_)) {
       // ensure we return a shallow copy for things like x[...]
@@ -344,9 +349,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
   if (py_value == nullptr) {
     throw TypeError("Tensor does not support deleting items");
   }
-  if (check_has_torch_function(self)) {
-    py::tuple args_ = py::make_tuple(py::handle(index), py::handle(py_value));
-    py::object ret = py::reinterpret_steal<py::object>(handle_torch_function(self, "__setitem__", args_.ptr()));
+  if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
+      (!THPVariable_CheckExact(py_value) && check_has_torch_function(py_value))) {
+    py::object ret = py::reinterpret_steal<py::object>(
+      handle_torch_function_indexing(self, index, py_value)
+    );
     return 0;
   }
@@ -406,8 +413,15 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
   THPObjectPtr holder = wrapTuple(index);
 
   variable_list variableIndices;
+  int64_t specified_dims = count_specified_dimensions(holder.get());
+  if (specified_dims == -1) {
+    py::object val = py::reinterpret_steal<py::object>(
+      handle_torch_function_indexing(self, index, py_value)
+    );
+    return 0;
+  }
   Variable sliced = applySlicing(
-      self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_device, self_.sizes());
+      self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_device, self_.sizes(), specified_dims);
   if (variableIndices.empty()) {
     at::indexing::copy_to(sliced, value);
     return 0;
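Note that `THPVariable_setitem` now also checks the value operand, so the override fires even when only the right-hand side of the assignment defines `__torch_function__`, as `test_setitem_val` above exercises. A minimal sketch of that case:

```python
import torch
from torch import Tensor

triggered = set()

class A:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        triggered.add(func)
        return -1

t = torch.tensor([5])
t[0] = A()  # only the assigned value carries the override
assert Tensor.__setitem__ in triggered
assert t.item() == 5  # the override short-circuits the actual write
```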

torch/csrc/utils/python_arg_parser.cpp

@@ -262,6 +262,28 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
   return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name);
 }
 
+auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val) -> PyObject* {
+  const char *func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
+  py::object index_tup;
+  if (PyTuple_Check(index)) {
+    index_tup = py::reinterpret_borrow<py::object>(index);
+  }
+  else {
+    index_tup = py::make_tuple(py::handle(index));
+  }
+  std::vector<py::handle> overridable_args;
+  is_tensor_and_append_overloaded(self, &overridable_args);
+  Py_ssize_t size = PyTuple_GET_SIZE(index_tup.ptr());
+  for (Py_ssize_t i = 0; i < size; i++) {
+    PyObject *obj = PyTuple_GetItem(index_tup.ptr(), i);
+    is_tensor_and_append_overloaded(obj, &overridable_args);
+  }
+  if (val != nullptr) is_tensor_and_append_overloaded(val, &overridable_args);
+  py::object func = PyObject_FastGetAttrString(THPVariableClass, (char *)func_name);
+  py::object args = (val == nullptr) ? py::make_tuple(py::handle(self), py::handle(index)) : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
+  return handle_torch_function_no_python_arg_parser(overridable_args, args.ptr(), nullptr, func_name, func.ptr(), "torch.Tensor");
+}
+
 /*
  * obj has a __torch_function__ implementation and may either be a
  * subclass of Tensor or a Tensor-like duck type. We may need to
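As the `py::make_tuple` calls above show, the `func` handed to an override through this helper is the unbound `Tensor.__getitem__` or `Tensor.__setitem__`, and `args` is `(self, index)` or `(self, index, val)`. A short sketch of what an override observes:

```python
import torch
from torch import Tensor

seen = []

class A:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        seen.append(func)
        return -1

t = torch.tensor([5])
t[A()]       # dispatches with func == Tensor.__getitem__, args == (t, <A instance>)
t[A()] = 1   # dispatches with func == Tensor.__setitem__, args == (t, <A instance>, 1)
assert seen == [Tensor.__getitem__, Tensor.__setitem__]
```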

torch/csrc/utils/python_arg_parser.h

@@ -743,6 +743,9 @@ auto handle_torch_function_getter(THPVariable* self, const std::string& property
 // Used for setters of Tensor properties.
 auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int;
 
+// Used for __getitem__ and __setitem__
+auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val=nullptr) -> PyObject*;
+
 /*
  * Check if the input obj is Tensor type, including its subclass, or overloaded
  * type. If the type defines __torch_function__, it also returns true.