diff --git a/test/test_autograd.py b/test/test_autograd.py index 79880a1d628..e20e8b18eba 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -79,6 +79,7 @@ from torch.testing._internal.common_utils import ( from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.utils.cpp_extension import load_inline from torch.utils.hooks import RemovableHandle @@ -9508,6 +9509,59 @@ for shape in [(1,), ()]: t3 = torch.rand(2, requires_grad=True) t4 = torch.rand(2, requires_grad=True) + # Ensure we properly detect all types of Nodes here + # C++ Node + t1 = t1.mul(2) + + # Python custom Function + class Foo(Function): + @staticmethod + def forward(ctx, a): + return a.clone() + + @staticmethod + def backward(ctx, gO): + return gO + + t2 = Foo.apply(t2) + + # C++ Node + t3 = torch._C._functions.UndefinedGrad()(t3) + + # C++ Custom Op + cpp_source = """ +struct CustomOpAutogradFunction : public torch::autograd::Function { + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& x) { + return x.clone(); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_output) { + return grad_output; + } +}; + +torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { + return CustomOpAutogradFunction::apply(x); +} + +TORCH_LIBRARY(test_autograd_cpp_node, m) { + m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); +} + """ + + module = load_inline( + name="test_autograd_cpp_node", + cpp_sources=cpp_source, + functions="custom_op_backed_by_autograd_fn", + verbose=True, + ) + + t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4) + res = [None] * 4 count = [0] diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 66a5a152ec7..6c86e2136de 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -20,8 +20,7 @@ using namespace torch::autograd; -namespace torch { -namespace autograd { +namespace torch::autograd { namespace { @@ -227,6 +226,7 @@ PyTypeObject* _initFunctionPyTypeObject( const char* name, PyGetSetDef* function_properties, PyMethodDef* function_methods) { + type.ob_base = {PyObject_HEAD_INIT(nullptr) 0}; // NOLINTNEXTLINE(misc-redundant-expression) type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC; type.tp_name = name; @@ -251,15 +251,17 @@ static std::unordered_set cpp_function_types_set; struct DefaultFunctionType { DefaultFunctionType() : type() { _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr); - Py_INCREF(&type); } PyTypeObject type; }; -PyObject* functionToPyObject(const std::shared_ptr& cdata) { +PyTypeObject* get_default_type() { static DefaultFunctionType default_type; + return &(default_type.type); +} +PyObject* functionToPyObject(const std::shared_ptr& cdata) { if (!cdata) { Py_RETURN_NONE; } @@ -278,7 +280,7 @@ PyObject* functionToPyObject(const std::shared_ptr& cdata) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) PyTypeObject* type; if (it == cpp_function_types_map.end()) { - type = &default_type.type; + type = get_default_type(); } else { type = (PyTypeObject*)it->second.get(); } @@ -305,6 +307,9 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) { bool THPCppFunction_Check(PyObject* obj) { THPObjectPtr type = THPObjectPtr(PyObject_Type(obj)); + if ((PyTypeObject*)type.get() == get_default_type()) { + return true; + } if (cpp_function_types_set.find((PyTypeObject*)type.get()) == cpp_function_types_set.end()) { return false; @@ -374,5 +379,4 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) { return handle; } -} // namespace autograd -} // namespace torch +} // namespace torch::autograd