Fix cpp node instance check (#125875)
The broken check is mostly visible when calling multi_grad_hook, so that API is used to test the fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125875
Approved by: https://github.com/jackiexu1992, https://github.com/ezyang
This commit is contained in:
parent 07d6ab5aa2
commit 6ffc94fa62
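For context, torch.autograd.graph.register_multi_grad_hook registers one hook across several tensors and must classify each tensor's grad_fn node; per the commit title, the fix concerns how C++ nodes are detected in that classification. A minimal sketch of the API (illustrative, not part of this commit; tensor names are made up):

    import torch
    from torch.autograd.graph import register_multi_grad_hook

    # Two tensors whose grad_fns are C++-implemented autograd nodes.
    a = torch.rand(2, requires_grad=True).mul(2)
    b = torch.rand(2, requires_grad=True).clone()

    def hook(grads):
        # With the default mode="all", called once every grad is ready;
        # entries are None where no gradient was produced.
        print([g is not None for g in grads])

    handle = register_multi_grad_hook((a, b), hook)
    (a + b).sum().backward()
    handle.remove()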
@@ -79,6 +79,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torch.utils.cpp_extension import load_inline
 from torch.utils.hooks import RemovableHandle
@@ -9508,6 +9509,59 @@ for shape in [(1,), ()]:
         t3 = torch.rand(2, requires_grad=True)
         t4 = torch.rand(2, requires_grad=True)
 
+        # Ensure we properly detect all types of Nodes here
+        # C++ Node
+        t1 = t1.mul(2)
+
+        # Python custom Function
+        class Foo(Function):
+            @staticmethod
+            def forward(ctx, a):
+                return a.clone()
+
+            @staticmethod
+            def backward(ctx, gO):
+                return gO
+
+        t2 = Foo.apply(t2)
+
+        # C++ Node
+        t3 = torch._C._functions.UndefinedGrad()(t3)
+
+        # C++ Custom Op
+        cpp_source = """
+struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& x) {
+    return x.clone();
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext *ctx,
+      torch::autograd::variable_list grad_output) {
+    return grad_output;
+  }
+};
+
+torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
+  return CustomOpAutogradFunction::apply(x);
+}
+
+TORCH_LIBRARY(test_autograd_cpp_node, m) {
+    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
+}
+        """
+
+        module = load_inline(
+            name="test_autograd_cpp_node",
+            cpp_sources=cpp_source,
+            functions="custom_op_backed_by_autograd_fn",
+            verbose=True,
+        )
+
+        t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
+
         res = [None] * 4
         count = [0]
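A hedged sketch of how the remainder of the test presumably wires these four tensors up (the actual assertions lie past the end of this hunk; the hook body below is illustrative only):

    def hook(grads):
        count[0] += 1
        for i, g in enumerate(grads):
            res[i] = g is not None

    handle = torch.autograd.graph.register_multi_grad_hook(
        (t1, t2, t3, t4), hook
    )
    (t1.sum() + t2.sum() + t3.sum() + t4.sum()).backward()
    # The hook should fire exactly once even though the four grad_fns are
    # nodes of different kinds (generated C++ node, Python Function,
    # UndefinedGrad, C++ custom op).
    print(count[0], res)
    handle.remove()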
@@ -20,8 +20,7 @@
 
 using namespace torch::autograd;
 
-namespace torch {
-namespace autograd {
+namespace torch::autograd {
 
 namespace {
@@ -227,6 +226,7 @@ PyTypeObject* _initFunctionPyTypeObject(
     const char* name,
     PyGetSetDef* function_properties,
     PyMethodDef* function_methods) {
+  type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
   // NOLINTNEXTLINE(misc-redundant-expression)
   type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
   type.tp_name = name;
@@ -251,15 +251,17 @@ static std::unordered_set<PyTypeObject*> cpp_function_types_set;
 struct DefaultFunctionType {
   DefaultFunctionType() : type() {
     _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
     Py_INCREF(&type);
   }
 
   PyTypeObject type;
 };
 
-PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
+PyTypeObject* get_default_type() {
   static DefaultFunctionType default_type;
+  return &(default_type.type);
+}
 
+PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
   if (!cdata) {
     Py_RETURN_NONE;
   }
@@ -278,7 +280,7 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   PyTypeObject* type;
   if (it == cpp_function_types_map.end()) {
-    type = &default_type.type;
+    type = get_default_type();
   } else {
     type = (PyTypeObject*)it->second.get();
   }
@@ -305,6 +307,9 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
 
 bool THPCppFunction_Check(PyObject* obj) {
   THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
+  if ((PyTypeObject*)type.get() == get_default_type()) {
+    return true;
+  }
   if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
       cpp_function_types_set.end()) {
     return false;
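Illustrative of what this added branch catches (not from the commit): a node with no per-op Python wrapper registered through registerCppFunction, such as UndefinedGrad from the test above, is exposed via the generic CppFunction type, so the set lookup alone reported it as not a C++ node:

    import torch

    t = torch.rand(2, requires_grad=True)
    t = torch._C._functions.UndefinedGrad()(t)

    # Backed by the default "CppFunction" wrapper rather than one of the
    # per-op types in cpp_function_types_set, hence the explicit
    # get_default_type() comparison added above.
    print(type(t.grad_fn))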
@@ -374,5 +379,4 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
   return handle;
 }
 
-} // namespace autograd
-} // namespace torch
+} // namespace torch::autograd