Implement subclass priority for __torch_dispatch__ (#63411)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63411

In order to get this behavior, you have to use append_overloaded,
which I forgot to use in the previous implementation.  I exposed
an internal helper function, append_overloaded_arg, which is more
appropriate for dispatch to Python, where we already know that an
argument is definitely a Tensor (so that check no longer needs to
be performed).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D30374489

Pulled By: ezyang

fbshipit-source-id: 43b08c00d1958c9b26d82a025d19f0b67bb85590
This commit is contained in:
Edward Yang 2021-08-18 07:45:45 -07:00 committed by Facebook GitHub Bot
parent 061b36e2f5
commit c508433617
3 changed files with 47 additions and 3 deletions

View File

@ -246,6 +246,39 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
x.data.add_(2)
self.assertEqual(cur_vc, x._version)
def test_subclass_priority(self) -> None:
class ErrorA(RuntimeError):
pass
class ErrorB(RuntimeError):
pass
# The big tests for code coverage are test_precedence_semantics in
# test_overrides.py; this is just to make sure it is wired up at all
# correctly for __torch_dispatch__
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorA
class B(A):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorB
self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))
def test_format(self) -> None:
x = LoggingTensor(torch.ones(1))
s1 = str(x)

View File

@ -1562,7 +1562,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
if (ivalue.isTensor()) {
const auto& tensor = ivalue.toTensor();
if (isPythonTensor(tensor)) {
overloaded_args.emplace_back(py::cast(tensor));
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
}
} else if (ivalue.isList()) {
const auto& list = ivalue.toListRef();
@ -1571,7 +1571,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
if (nv.isTensor()) {
const auto& tensor = nv.toTensor();
if (isPythonTensor(tensor)) {
overloaded_args.emplace_back(py::cast(tensor));
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
}
}
}
@ -1620,7 +1620,8 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
// TODO: fix the constness of target
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
overloaded_args.emplace_back(self_p);
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
append_overloaded_arg(&overloaded_args, self_p.ptr());
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());

View File

@ -810,4 +810,14 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
*/
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error);
/* Given an argument that is definitely a tensor and is definitely overloaded,
* append it to the overloaded arguments list. Use this instead of
* is_tensor_and_append_overloaded in situations where you have a PyObject
* and you know it definitely is a Tensor and it is definitely overloaded.
*
* 'overloaded_args': the vector to append the overloaded args
* 'obj': the input tensor that is overloaded
*/
void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
} // namespace torch