mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix indexing for overrides. (#49324)
Summary: Fixes https://github.com/pytorch/pytorch/issues/46277 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49324 Reviewed By: mruberry Differential Revision: D25959334 Pulled By: ezyang fbshipit-source-id: bac48b8ffee89d10aa04c004de2b53b4e54a96c2
This commit is contained in:
parent
16faabe7f0
commit
cf1882adeb
|
|
@ -878,5 +878,72 @@ class TestWrapTorchFunction(TestCase):
|
|||
|
||||
self.assertEqual(f(A()), -1)
|
||||
|
||||
class TestIndexing(TestCase):
|
||||
""" Regression tests for gh-46277 """
|
||||
def test_getitem(self):
|
||||
class A:
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
return -1
|
||||
|
||||
t = torch.tensor([5])
|
||||
self.assertEqual(t[A()], -1)
|
||||
self.assertEqual(t, torch.tensor([5]))
|
||||
|
||||
def test_getitem_subclass(self):
|
||||
class A(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
return -1
|
||||
|
||||
t = torch.tensor([5])
|
||||
self.assertEqual(t[A()], -1)
|
||||
self.assertEqual(t[5, A()], -1)
|
||||
self.assertEqual(t, torch.tensor([5]))
|
||||
|
||||
def test_setitem(self):
|
||||
triggered = set()
|
||||
|
||||
class A:
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
triggered.add(func)
|
||||
return -1
|
||||
|
||||
t = torch.tensor([5])
|
||||
t[A()] = 1
|
||||
t[5, A()] = 1
|
||||
self.assertIn(Tensor.__setitem__, triggered)
|
||||
self.assertEqual(t, torch.tensor([5]))
|
||||
|
||||
def test_setitem_val(self):
|
||||
triggered = set()
|
||||
|
||||
class A:
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
triggered.add(func)
|
||||
return -1
|
||||
|
||||
t = torch.tensor([5])
|
||||
t[0] = A()
|
||||
self.assertIn(Tensor.__setitem__, triggered)
|
||||
self.assertEqual(t, torch.tensor([5]))
|
||||
|
||||
def test_setitem_subclass(self):
|
||||
triggered = set()
|
||||
|
||||
class A(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args, kwargs=None):
|
||||
triggered.add(func)
|
||||
return -1
|
||||
|
||||
t = torch.tensor([5])
|
||||
t[A()] = 1
|
||||
t[5, A()] = 1
|
||||
self.assertIn(Tensor.__setitem__, triggered)
|
||||
self.assertEqual(t, torch.tensor([5]))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -55,10 +55,12 @@ Py_ssize_t THPVariable_length(PyObject* self) {
|
|||
|
||||
static inline int64_t count_specified_dimensions(PyObject* index) {
|
||||
// Count the number of indexed dimensions (everything but ellipsis and None)
|
||||
// -1 is a sentinel for __torch_function__
|
||||
int64_t count = 0;
|
||||
auto size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
for (Py_ssize_t i = 0; i < size; i++) {
|
||||
PyObject* obj = PyTuple_GET_ITEM(index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj)) return -1;
|
||||
if (THPVariable_Check(obj)) {
|
||||
auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
|
||||
const auto& var_scalar_type = var.scalar_type();
|
||||
|
|
@ -135,10 +137,10 @@ static inline Variable applySlicing(
|
|||
variable_list& outIndices,
|
||||
bool is_tracing,
|
||||
const at::Device& self_device,
|
||||
const IntArrayRef& self_sizes) {
|
||||
const IntArrayRef& self_sizes,
|
||||
int64_t specified_dims) {
|
||||
int64_t size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
int64_t dim = 0;
|
||||
int64_t specified_dims = count_specified_dimensions(index);
|
||||
|
||||
if (specified_dims > (int64_t)self_sizes.size()) {
|
||||
throw IndexError("too many indices for tensor of dimension %d", (int)(self_sizes.size()));
|
||||
|
|
@ -267,9 +269,8 @@ static inline THPObjectPtr wrapTuple(PyObject* index) {
|
|||
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
|
||||
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (check_has_torch_function(self)) {
|
||||
py::tuple args_ = py::make_tuple(py::handle(index));
|
||||
return handle_torch_function(self, "__getitem__", args_.ptr());
|
||||
if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
|
||||
return handle_torch_function_indexing(self, index);
|
||||
}
|
||||
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
|
||||
OptionalDeviceGuard device_guard(device_of(self_));
|
||||
|
|
@ -311,8 +312,12 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
|
|||
THPObjectPtr holder = wrapTuple(index);
|
||||
|
||||
variable_list variableIndices;
|
||||
int64_t specified_dims = count_specified_dimensions(holder.get());
|
||||
if (specified_dims == -1) {
|
||||
return handle_torch_function_indexing(self, index);
|
||||
}
|
||||
Variable sliced = applySlicing(
|
||||
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_.sizes());
|
||||
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_.sizes(), specified_dims);
|
||||
if (variableIndices.empty()) {
|
||||
if (sliced.is_same(self_)) {
|
||||
// ensure we return a shallow copy for things like x[...]
|
||||
|
|
@ -344,9 +349,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
|
|||
if (py_value == nullptr) {
|
||||
throw TypeError("Tensor does not support deleting items");
|
||||
}
|
||||
if (check_has_torch_function(self)) {
|
||||
py::tuple args_ = py::make_tuple(py::handle(index), py::handle(py_value));
|
||||
py::object ret = py::reinterpret_steal<py::object>(handle_torch_function(self, "__setitem__", args_.ptr()));
|
||||
if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
|
||||
(!THPVariable_CheckExact(py_value) && check_has_torch_function(py_value))) {
|
||||
py::object ret = py::reinterpret_steal<py::object>(
|
||||
handle_torch_function_indexing(self, index, py_value)
|
||||
);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -406,8 +413,15 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
|
|||
THPObjectPtr holder = wrapTuple(index);
|
||||
|
||||
variable_list variableIndices;
|
||||
int64_t specified_dims = count_specified_dimensions(holder.get());
|
||||
if (specified_dims == -1) {
|
||||
py::object val = py::reinterpret_steal<py::object>(
|
||||
handle_torch_function_indexing(self, index, py_value)
|
||||
);
|
||||
return 0;
|
||||
}
|
||||
Variable sliced = applySlicing(
|
||||
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_device, self_.sizes());
|
||||
self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_device, self_.sizes(), specified_dims);
|
||||
if (variableIndices.empty()) {
|
||||
at::indexing::copy_to(sliced, value);
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -262,6 +262,28 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
|
|||
return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name);
|
||||
}
|
||||
|
||||
auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val) -> PyObject* {
|
||||
const char *func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
|
||||
py::object index_tup;
|
||||
if (PyTuple_Check(index)) {
|
||||
index_tup = py::reinterpret_borrow<py::object>(index);
|
||||
}
|
||||
else {
|
||||
index_tup = py::make_tuple(py::handle(index));
|
||||
}
|
||||
std::vector<py::handle> overridable_args;
|
||||
is_tensor_and_append_overloaded(self, &overridable_args);
|
||||
Py_ssize_t size = PyTuple_GET_SIZE(index_tup.ptr());
|
||||
for (Py_ssize_t i = 0; i < size; i++) {
|
||||
PyObject *obj = PyTuple_GetItem(index_tup.ptr(), i);
|
||||
is_tensor_and_append_overloaded(obj, &overridable_args);
|
||||
}
|
||||
if (val != nullptr) is_tensor_and_append_overloaded(val, &overridable_args);
|
||||
py::object func = PyObject_FastGetAttrString(THPVariableClass, (char *)func_name);
|
||||
py::object args = (val == nullptr) ? py::make_tuple(py::handle(self), py::handle(index)) : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
|
||||
return handle_torch_function_no_python_arg_parser(overridable_args, args.ptr(), nullptr, func_name, func.ptr(), "torch.Tensor");
|
||||
}
|
||||
|
||||
/*
|
||||
* obj has a __torch_function__ implementation and may either be a
|
||||
* subclass of Tensor or a Tensor-like duck type. We may need to
|
||||
|
|
|
|||
|
|
@ -743,6 +743,9 @@ auto handle_torch_function_getter(THPVariable* self, const std::string& property
|
|||
// Used for setters of Tensor properties.
|
||||
auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int;
|
||||
|
||||
// Used for __getitem__ and __setitem__
|
||||
auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val=nullptr) -> PyObject*;
|
||||
|
||||
/*
|
||||
* Check if the input obj is Tensor type, including its subclass, or overloaded
|
||||
* type. If the type defines __torch_function__, it also returns true.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user