mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Implement subclass priority for __torch_dispatch__ (#63411)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63411 In order to get this behavior, you have to use append_overloaded, which I forgot to use in the previous implementation. I exposed an internal helper function which is more appropriate for dispatch to Python where we know that an argument is definitely a Tensor (and this test no longer needs to be done). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D30374489 Pulled By: ezyang fbshipit-source-id: 43b08c00d1958c9b26d82a025d19f0b67bb85590
This commit is contained in:
parent
061b36e2f5
commit
c508433617
|
|
@ -246,6 +246,39 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
|
|||
x.data.add_(2)
|
||||
self.assertEqual(cur_vc, x._version)
|
||||
|
||||
def test_subclass_priority(self) -> None:
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
|
||||
class ErrorB(RuntimeError):
|
||||
pass
|
||||
|
||||
# The big tests for code coverage are test_precedence_semantics in
|
||||
# test_overrides.py; this is just to make sure it is wired up at all
|
||||
# correctly for __torch_dispatch__
|
||||
class A(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise ErrorA
|
||||
|
||||
class B(A):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise ErrorB
|
||||
|
||||
self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
|
||||
self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
|
||||
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
|
||||
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))
|
||||
|
||||
def test_format(self) -> None:
|
||||
x = LoggingTensor(torch.ones(1))
|
||||
s1 = str(x)
|
||||
|
|
|
|||
|
|
@ -1562,7 +1562,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
|||
if (ivalue.isTensor()) {
|
||||
const auto& tensor = ivalue.toTensor();
|
||||
if (isPythonTensor(tensor)) {
|
||||
overloaded_args.emplace_back(py::cast(tensor));
|
||||
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||
}
|
||||
} else if (ivalue.isList()) {
|
||||
const auto& list = ivalue.toListRef();
|
||||
|
|
@ -1571,7 +1571,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
|||
if (nv.isTensor()) {
|
||||
const auto& tensor = nv.toTensor();
|
||||
if (isPythonTensor(tensor)) {
|
||||
overloaded_args.emplace_back(py::cast(tensor));
|
||||
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1620,7 +1620,8 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
|
|||
// TODO: fix the constness of target
|
||||
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
|
||||
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
|
||||
overloaded_args.emplace_back(self_p);
|
||||
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
|
||||
append_overloaded_arg(&overloaded_args, self_p.ptr());
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
|
||||
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
|
||||
|
||||
|
|
|
|||
|
|
@ -810,4 +810,14 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
|
|||
*/
|
||||
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error);
|
||||
|
||||
/* Given an argument that is definitely a tensor and is definitely overloaded,
|
||||
* append it to the overloaded arguments list. Use this instead of
|
||||
* is_tensor_and_append_overloaded in situations where you have a PyObject
|
||||
* and you know it definitely is a Tensor and it is definitely overloaded.
|
||||
*
|
||||
* 'overloaded_args': the vector to append the overloaded args
|
||||
* 'obj': the input tensor that is overloaded
|
||||
*/
|
||||
void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
|
||||
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user