pytorch/torch/csrc/autograd/python_cpp_function.h
soulitzer a876432aea Expose torch._will_engine_execute_node (#84773)
Addresses: https://github.com/pytorch/pytorch/issues/83617

This PR adds a way to query the TLS graph task's exec_info, which is a map from each Node to a bool indicating whether it will be executed in the current backward pass (as determined by the inputs= argument to .grad or .backward).
- this works with both custom Function nodes and normal codegened nodes
-  to be able to verify whether the pyobject passed is an actual node, we now store pointers to PyTypeObjects into a set on registration.
- error out when .backward is called without inputs= to avoid silently returning True

Alternatives:
- not sure if it is possible to bind to Python from a raw pointer to Node. At least we wouldn't be able to use existing logic, and the Python object should only hold a weak reference to the Node.
- other solutions to the motivating issue seem to require more extensive modification to the engine

See the issue linked for an example of usage
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84773
Approved by: https://github.com/albanD
2022-09-28 20:13:52 +00:00

103 lines
4.1 KiB
C++

#pragma once
#include <torch/csrc/python_headers.h>
#include <memory>
#include <typeinfo>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/utils/object_ptr.h>
namespace torch {
namespace autograd {
// Python object layout used to expose a C++ autograd Node to Python.
struct THPCppFunction {
  // Standard CPython object header; CppFunction_pynew casts the PyObject*
  // returned by tp_alloc to this struct.
  PyObject_HEAD
  // The wrapped node. It is constructed with placement new in
  // CppFunction_pynew rather than by a C++ constructor of this struct.
  std::shared_ptr<Node> cdata;
};
// tp_new implementation for THPCppFunction-based Python types.
//
// Allocates an instance through `type->tp_alloc`, then builds the wrapped
// Node by calling a default-constructed `Ctor` with `args`. `kwds` is not
// read. Returns the new object, or nullptr when allocation fails or when the
// resulting `cdata` is empty. Exceptions escaping `Ctor` are dealt with by
// the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS pair.
template <typename Ctor>
PyObject* CppFunction_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  THPObjectPtr instance(type->tp_alloc(type, 0));
  if (!instance) {
    return nullptr;
  }
  auto* wrapper = reinterpret_cast<THPCppFunction*>(instance.get());
  HANDLE_TH_ERRORS
  // The storage came from tp_alloc, so the shared_ptr member has to be
  // constructed in place.
  new (&wrapper->cdata) std::shared_ptr<Node>(Ctor()(args));
  END_HANDLE_TH_ERRORS
  // Ownership is released to the caller only when a node was produced;
  // otherwise `instance` keeps it and nullptr is returned.
  return wrapper->cdata ? instance.release() : nullptr;
}
// Default method-table entries for THPCppFunction-based types, in the
// {name, function, flags, doc} shape of a PyMethodDef initializer:
//   _register_hook_dict, register_hook, register_prehook -> METH_O
//   name                                                 -> METH_NOARGS
// Every doc slot is nullptr. The expansion has no trailing comma and no
// terminating sentinel entry, so the table it is spliced into must supply
// its own.
#define THP_FUNCTION_DEFAULT_METHODS \
{(char*)"_register_hook_dict", \
THPCppFunction_register_hook_dict, \
METH_O, \
nullptr}, \
{(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \
{(char*)"register_prehook", \
THPCppFunction_register_prehook, \
METH_O, \
nullptr}, \
{ \
(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr \
}
// Default property-table entries for THPCppFunction-based types, in the
// {name, getter, setter, doc, closure} shape of a PyGetSetDef initializer:
// next_functions, requires_grad and metadata. Each entry provides only a
// getter (cast to `getter`); the setter, doc and closure slots are nullptr.
// The expansion has no trailing comma and no terminating sentinel entry, so
// the table it is spliced into must supply its own.
#define THP_FUNCTION_DEFAULT_PROPERTIES \
{(char*)"next_functions", \
(getter)THPCppFunction_next_functions, \
nullptr, \
nullptr, \
nullptr}, \
{(char*)"requires_grad", \
(getter)THPCppFunction_requires_grad, \
nullptr, \
nullptr, \
nullptr}, \
{ \
(char*)"metadata", (getter)THPCppFunction_metadata, nullptr, nullptr, \
nullptr \
}
// Property getters installed by THP_FUNCTION_DEFAULT_PROPERTIES (through a
// cast to `getter`). Definitions are not in this header.
// NOTE(review): THPCppFunction_next_functions names its second parameter
// `PyObject* hook`, unlike its siblings' `void* _unused`. Since it is used as
// a getter it presumably receives the (unused) getset closure too -- confirm
// against the definition before relying on it.
PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook);
PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused);
PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* _unused);
// Method implementations installed by THP_FUNCTION_DEFAULT_METHODS. The first
// three are registered as METH_O (one argument besides self); `name` is
// registered as METH_NOARGS.
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
// Non-template part of Python type setup; defined outside this header.
// Takes the type object to set up, its name, and the property (PyGetSetDef)
// and method (PyMethodDef) tables to use. createForwardFunctionPyTypeObject
// passes its own arguments straight through and returns this function's
// result unchanged.
PyTypeObject* _initFunctionPyTypeObject(
PyTypeObject& type,
const char* name,
PyGetSetDef* function_properties,
PyMethodDef* function_methods);
// Hook-registration helpers that take the Node directly instead of a Python
// wrapper object. Definitions are not in this header.
// NOTE(review): presumably the shared implementations behind the
// register_hook / register_prehook methods -- confirm in the .cpp.
PyObject* registerFunctionHook(Node& fn, PyObject* hook);
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
// Prepares `type` as the Python type for a C++ autograd function whose
// instances are created via `Ctor`.
//
// Points the type's tp_new slot at CppFunction_pynew<Ctor>, then hands the
// type, its name and the optional property/method tables (both default to
// nullptr) to _initFunctionPyTypeObject. Returns whatever that call returns.
template <typename Ctor>
PyTypeObject* createForwardFunctionPyTypeObject(
    PyTypeObject& type,
    const char* name,
    PyGetSetDef* function_properties = nullptr,
    PyMethodDef* function_methods = nullptr) {
  // tp_new is set before the remaining initialization runs.
  type.tp_new = &CppFunction_pynew<Ctor>;
  PyTypeObject* const initialized = _initFunctionPyTypeObject(
      type, name, function_properties, function_methods);
  return initialized;
}
// Registers `pytype` as the Python type associated with the C++ type
// identified by `type`. Definition is not in this header.
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
// Produces a Python object for the given Node.
// NOTE(review): presumably looks up the type registered through
// registerCppFunction -- confirm in the .cpp.
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
// Reports whether `obj` is a THPCppFunction.
// NOTE(review): presumably checks obj's type against the registered
// PyTypeObjects -- confirm in the .cpp.
bool THPCppFunction_Check(PyObject* obj);
} // namespace autograd
} // namespace torch