Fix cpp node instance check (#125875)

Mostly visible when calling multi_grad_hook and thus using this to test it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125875
Approved by: https://github.com/jackiexu1992, https://github.com/ezyang
This commit is contained in:
albanD 2024-05-11 21:31:10 +00:00 committed by PyTorch MergeBot
parent 07d6ab5aa2
commit 6ffc94fa62
2 changed files with 65 additions and 7 deletions

View File

@ -79,6 +79,7 @@ from torch.testing._internal.common_utils import (
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch.utils.cpp_extension import load_inline
from torch.utils.hooks import RemovableHandle
@ -9508,6 +9509,59 @@ for shape in [(1,), ()]:
t3 = torch.rand(2, requires_grad=True)
t4 = torch.rand(2, requires_grad=True)
# Ensure we properly detect all types of Nodes here
# C++ Node
t1 = t1.mul(2)
# Python custom Function
class Foo(Function):
@staticmethod
def forward(ctx, a):
return a.clone()
@staticmethod
def backward(ctx, gO):
return gO
t2 = Foo.apply(t2)
# C++ Node
t3 = torch._C._functions.UndefinedGrad()(t3)
# C++ Custom Op
cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& x) {
return x.clone();
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext *ctx,
torch::autograd::variable_list grad_output) {
return grad_output;
}
};
torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
return CustomOpAutogradFunction::apply(x);
}
TORCH_LIBRARY(test_autograd_cpp_node, m) {
m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
"""
module = load_inline(
name="test_autograd_cpp_node",
cpp_sources=cpp_source,
functions="custom_op_backed_by_autograd_fn",
verbose=True,
)
t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
res = [None] * 4
count = [0]

View File

@ -20,8 +20,7 @@
using namespace torch::autograd;
namespace torch {
namespace autograd {
namespace torch::autograd {
namespace {
@ -227,6 +226,7 @@ PyTypeObject* _initFunctionPyTypeObject(
const char* name,
PyGetSetDef* function_properties,
PyMethodDef* function_methods) {
type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
// NOLINTNEXTLINE(misc-redundant-expression)
type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
type.tp_name = name;
@ -251,15 +251,17 @@ static std::unordered_set<PyTypeObject*> cpp_function_types_set;
struct DefaultFunctionType {
DefaultFunctionType() : type() {
_initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
Py_INCREF(&type);
}
PyTypeObject type;
};
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
PyTypeObject* get_default_type() {
static DefaultFunctionType default_type;
return &(default_type.type);
}
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
if (!cdata) {
Py_RETURN_NONE;
}
@ -278,7 +280,7 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyTypeObject* type;
if (it == cpp_function_types_map.end()) {
type = &default_type.type;
type = get_default_type();
} else {
type = (PyTypeObject*)it->second.get();
}
@ -305,6 +307,9 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
bool THPCppFunction_Check(PyObject* obj) {
THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
if ((PyTypeObject*)type.get() == get_default_type()) {
return true;
}
if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
cpp_function_types_set.end()) {
return false;
@ -374,5 +379,4 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
return handle;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd