mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Implement subclass priority for __torch_dispatch__ (#63411)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63411 In order to get this behavior, you have to use append_overloaded, which I forgot to use in the previous implementation. I exposed an internal helper function which is more appropriate for dispatch to Python where we know that an argument is definitely a Tensor (and this test no longer needs to be done). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D30374489 Pulled By: ezyang fbshipit-source-id: 43b08c00d1958c9b26d82a025d19f0b67bb85590
This commit is contained in:
parent
061b36e2f5
commit
c508433617
|
|
@ -246,6 +246,39 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
|
||||||
x.data.add_(2)
|
x.data.add_(2)
|
||||||
self.assertEqual(cur_vc, x._version)
|
self.assertEqual(cur_vc, x._version)
|
||||||
|
|
||||||
|
def test_subclass_priority(self) -> None:
|
||||||
|
class ErrorA(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ErrorB(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The big tests for code coverage are test_precedence_semantics in
|
||||||
|
# test_overrides.py; this is just to make sure it is wired up at all
|
||||||
|
# correctly for __torch_dispatch__
|
||||||
|
class A(torch.Tensor):
|
||||||
|
@staticmethod
|
||||||
|
def __new__(cls, elem):
|
||||||
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
|
raise ErrorA
|
||||||
|
|
||||||
|
class B(A):
|
||||||
|
@staticmethod
|
||||||
|
def __new__(cls, elem):
|
||||||
|
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
|
raise ErrorB
|
||||||
|
|
||||||
|
self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))))
|
||||||
|
self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))))
|
||||||
|
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))))
|
||||||
|
self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))))
|
||||||
|
|
||||||
def test_format(self) -> None:
|
def test_format(self) -> None:
|
||||||
x = LoggingTensor(torch.ones(1))
|
x = LoggingTensor(torch.ones(1))
|
||||||
s1 = str(x)
|
s1 = str(x)
|
||||||
|
|
|
||||||
|
|
@ -1562,7 +1562,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
||||||
if (ivalue.isTensor()) {
|
if (ivalue.isTensor()) {
|
||||||
const auto& tensor = ivalue.toTensor();
|
const auto& tensor = ivalue.toTensor();
|
||||||
if (isPythonTensor(tensor)) {
|
if (isPythonTensor(tensor)) {
|
||||||
overloaded_args.emplace_back(py::cast(tensor));
|
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||||
}
|
}
|
||||||
} else if (ivalue.isList()) {
|
} else if (ivalue.isList()) {
|
||||||
const auto& list = ivalue.toListRef();
|
const auto& list = ivalue.toListRef();
|
||||||
|
|
@ -1571,7 +1571,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
||||||
if (nv.isTensor()) {
|
if (nv.isTensor()) {
|
||||||
const auto& tensor = nv.toTensor();
|
const auto& tensor = nv.toTensor();
|
||||||
if (isPythonTensor(tensor)) {
|
if (isPythonTensor(tensor)) {
|
||||||
overloaded_args.emplace_back(py::cast(tensor));
|
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1620,7 +1620,8 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
|
||||||
// TODO: fix the constness of target
|
// TODO: fix the constness of target
|
||||||
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
|
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
|
||||||
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
|
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
|
||||||
overloaded_args.emplace_back(self_p);
|
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
|
||||||
|
append_overloaded_arg(&overloaded_args, self_p.ptr());
|
||||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
|
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
|
||||||
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
|
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -810,4 +810,14 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
|
||||||
*/
|
*/
|
||||||
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error);
|
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error);
|
||||||
|
|
||||||
|
/* Given an argument that is definitely a tensor and is definitely overloaded,
|
||||||
|
* append it to the overloaded arguments list. Use this instead of
|
||||||
|
* is_tensor_and_append_overloaded in situations where you have a PyObject
|
||||||
|
* and you know it definitely is a Tensor and it is definitely overloaded.
|
||||||
|
*
|
||||||
|
* 'overloaded_args': the vector to append the overloaded args
|
||||||
|
* 'obj': the input tensor that is overloaded
|
||||||
|
*/
|
||||||
|
void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
|
||||||
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user