This is useful for splitting `autograd.grad` so it runs in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        # Leaf tensor: a no-op view materializes the AccumulateGrad node,
        # which we can then read off the view's next_functions.
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


def reverse_closure(roots, target_nodes):
    # Recurse until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes


# Wrapper object so we can take weak references into the reverse graph
class Holder():
    def __init__(self, node):
        self.node = node


# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs


def get_param_groups(inputs, params):
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for param in params:
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
                param_group = existing
            else:
                param_groups[input_node] = param_group
    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups


def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):
            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(param_group["intermediates"])
                    param_group["grads"][i] = grad_inputs
                return hook
            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups


def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # return grads in the original order weights were provided in
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)
```
</details>
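Before the end-to-end demo, here is a minimal, self-contained sketch of the new capability the helpers above rely on: passing a `GradientEdge` as an *output* to `torch.autograd.grad`, so the weight half of backward can be restarted later from a saved intermediate. The two-layer setup and the use of `torch.autograd.graph.get_gradient_edge` are illustrative assumptions, not part of the code above.

```python
import torch
import torch.nn as nn
from torch.autograd.graph import get_gradient_edge

lin1, lin2 = nn.Linear(4, 4), nn.Linear(4, 4)
x = torch.randn(4, requires_grad=True)

mid = lin1(x)           # intermediate we will restart backward from
out = lin2(mid).sum()

# Part 1: gradients w.r.t. the input and the intermediate only.
gx, gmid = torch.autograd.grad(out, (x, mid), retain_graph=True)

# Part 2: restart backward at the intermediate's gradient edge to recover the
# first layer's weight grads, feeding in the gradient captured in part 1.
edge = get_gradient_edge(mid)
gw, gb = torch.autograd.grad((edge,), (lin1.weight, lin1.bias), grad_outputs=(gmid,))

# Sanity check against a single full backward.
ref_gw, ref_gb = torch.autograd.grad(lin2(lin1(x)).sum(), (lin1.weight, lin1.bias))
print(torch.allclose(gw, ref_gw), torch.allclose(gb, ref_gb))
```

The demo below exercises the same idea through `compute_grads_only_inputs2` and `compute_grads_only_weights2` on a two-layer model, with a `TorchDispatchMode` that logs every dispatched op: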
```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))


class LoggingTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs


print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(
        out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),)
    )

for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
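The remainder of the page is the C++ source for the Python bindings of C++ autograd nodes (`THPCppFunction`), which expose exactly the pieces the helper above leans on: `metadata`, `next_functions`, `register_prehook`, `register_hook`, `name()`, and friends. As a quick, illustrative sketch of how those bindings look from Python (the linear layer and tensor sizes here are made up for the example):

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 3)
out = lin(torch.randn(3, requires_grad=True)).sum()
node = out.grad_fn  # a C++ node exposed through the THPCppFunction bindings below

print(node.name())             # e.g. "SumBackward0"
print(node.next_functions)     # tuple of (node, input_nr) edges into the graph
node.metadata["tag"] = "demo"  # per-node dict, used above to store reverse edges
handle = node.register_prehook(lambda grad_outputs: print("prehook:", grad_outputs))
out.backward()
handle.remove()
```

The full listing follows.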
```cpp
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_cpp_function.h>

#include <torch/csrc/python_headers.h>
#include <cstdio>
#include <memory>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>

#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

using namespace torch::autograd;

namespace torch::autograd {

namespace {

PyObject* THPCppFunction_call(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  if (kwargs && PyDict_Size(kwargs) != 0) {
    return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported");
  }

  auto num_inputs = PyTuple_GET_SIZE(args);
  auto num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs();
  if (num_inputs != num_inputs_required) {
    return PyErr_Format(
        PyExc_TypeError,
        "expected %d arguments, got %d instead",
        num_inputs_required,
        num_inputs);
  }
  variable_list vars(num_inputs);
  for (int i = 0; i != num_inputs; ++i) {
    PyObject* arg = PyTuple_GET_ITEM(args, i);
    if (arg == Py_None) {
      continue;
    }
    if (!THPVariable_Check(arg)) {
      return PyErr_Format(PyExc_TypeError, "argument %d is not a Variable", i);
    }
    vars[i] = THPVariable_Unpack(arg);
  }

  variable_list output;

  HANDLE_TH_ERRORS {
    pybind11::gil_scoped_release nogil;
    output = (*((THPCppFunction*)self)->cdata)(std::move(vars));
  }
  END_HANDLE_TH_ERRORS

  auto num_outputs = output.size();
  if (num_outputs == 1) {
    // assume we want to unpack one element tuples for now
    return THPVariable_Wrap(output[0]);
  }

  THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(num_outputs)));
  for (size_t i = 0; i != num_outputs; ++i) {
    PyTuple_SET_ITEM(tuple.get(), i, THPVariable_Wrap(output[i]));
  }
  return tuple.release();
}

int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
  if ((((THPCppFunction*)self)->cdata).use_count() == 1) {
    // The fields traversed below are owned by the cpp grad_fn, which we own a
    // reference to. We should only traverse them, however, if we are the only
    // owner of the grad_fn, otherwise we risk prematurely gc'ing the grad_fn.
    //
    // See: https://github.com/pytorch/pytorch/issues/102174
    auto& fn = *((THPCppFunction*)self)->cdata;
    for (const auto& hook : fn.tensor_pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    // NOTE [retains_grad_hook PyObject traversal]
    // In theory this shouldn't be necessary, because retains_grad_hooks should
    // not contain any PyFunctionTensorPreHooks. The alternative is to have a
    // check that actually guarantees this.
    for (const auto& pair : fn.retains_grad_hooks()) {
      if (auto pyhook =
              dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : fn.pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : fn.post_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
  }
  return 0;
}

int THPCppFunction_clear(PyObject* self) {
  auto f = (THPCppFunction*)self;
  // Remove the weak ref of the c++ object if it exists
  if (f->cdata) {
    f->cdata->set_pyobj(nullptr);
  }
  f->cdata.reset();
  return 0;
}

void THPCppFunction_dealloc(PyObject* self) {
  PyObject_GC_UnTrack(self);
  THPCppFunction_clear(self);
  ((THPCppFunction*)self)->cdata.~shared_ptr();
  Py_TYPE(self)->tp_free(self);
}

} // namespace

PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused) {
  auto cdata = reinterpret_cast<const THPCppFunction*>(self)->cdata;
  const auto num_next = cdata->num_outputs();
  THPObjectPtr py_functions(PyTuple_New(num_next));
  if (!py_functions)
    return nullptr;
  for (const auto i : c10::irange(num_next)) {
    auto& c_tuple = cdata->next_edge(i);
    THPObjectPtr tuple(PyTuple_New(2));
    if (!tuple)
      return nullptr;
    PyObject* py_fn = functionToPyObject(c_tuple.function);
    if (!py_fn)
      return nullptr;
    PyTuple_SET_ITEM(tuple.get(), 0, py_fn);
    PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr);
    if (!py_idx)
      return nullptr;
    PyTuple_SET_ITEM(tuple.get(), 1, py_idx);
    PyTuple_SET_ITEM(py_functions.get(), i, tuple.release());
  }
  return py_functions.release();
}

PyObject* THPCppFunction_metadata(PyObject* self, void* _unused) {
  auto* metadata =
      static_cast<PyAnomalyMetadata*>(
          reinterpret_cast<THPCppFunction*>(self)->cdata->metadata())
          ->dict();

  Py_XINCREF(metadata);
  return metadata;
}

PyObject* THPCppFunction_requires_grad(PyObject* self, void* unused) {
  Py_RETURN_TRUE;
}

PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) {
  if (!THPVariable_Check(_var)) {
    return PyErr_Format(
        PyExc_TypeError, "_register_hook_dict expected a variable");
  }
  auto var = (THPVariable*)_var;
  auto& fn = *((THPCppFunction*)self)->cdata;
  fn.add_tensor_pre_hook(std::make_unique<PyFunctionTensorPreHook>(
      var->backward_hooks, THPVariable_Unpack(var).output_nr()));
  Py_RETURN_NONE;
}

PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) {
  auto& fn = *((THPCppFunction*)self)->cdata;
  return registerFunctionHook(fn, hook);
}

PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook) {
  auto& fn = *((THPCppFunction*)self)->cdata;
  return registerFunctionPreHook(fn, hook);
}

PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) {
  auto& fn = *((THPCppFunction*)self)->cdata;
  return THPUtils_packString(fn.name());
}

PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs) {
  auto& fn = *((THPCppFunction*)self)->cdata;
  return THPUtils_packUInt64(fn.sequence_nr());
}

PyObject* THPCppFunction_set_sequence_nr(
    PyObject* self,
    PyObject* sequence_nr) {
  HANDLE_TH_ERRORS
  auto& fn = *((THPCppFunction*)self)->cdata;
  fn.set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
  HANDLE_TH_ERRORS;
  auto& fn = *((THPCppFunction*)self)->cdata;
  const auto num_inputs =
      fn.num_inputs(); // Assuming there's a method to get the number of inputs
  THPObjectPtr list(PyTuple_New(num_inputs));
  if (!list) {
    return nullptr;
  }
  for (size_t i = 0; i < num_inputs; ++i) {
    const auto& metadata = fn.input_metadata(i);
    THPObjectPtr item(py::cast(metadata).release().ptr());
    if (!item) {
      return nullptr;
    }
    PyTuple_SET_ITEM(list.get(), i, item.release());
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyMethodDef default_methods[] = {
    THP_FUNCTION_DEFAULT_METHODS,
    {nullptr}};

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef default_properties[] = {
    THP_FUNCTION_DEFAULT_PROPERTIES,
    {nullptr}};

PyTypeObject* _initFunctionPyTypeObject(
    PyTypeObject& type,
    const char* name,
    PyGetSetDef* function_properties,
    PyMethodDef* function_methods) {
  type.ob_base = {PyObject_HEAD_INIT(nullptr) 0};
  // NOLINTNEXTLINE(misc-redundant-expression)
  type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC;
  type.tp_name = name;
  type.tp_basicsize = sizeof(THPCppFunction);
  type.tp_call = THPCppFunction_call;
  type.tp_methods = function_methods ? function_methods : default_methods;
  type.tp_getset =
      function_properties ? function_properties : default_properties;
  type.tp_dealloc = THPCppFunction_dealloc;
  type.tp_traverse = THPCppFunction_traverse;
  type.tp_clear = THPCppFunction_clear;
  if (PyType_Ready(&type) < 0) {
    auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
    throw std::runtime_error(msg);
  }
  return &type;
}

static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map;
static std::unordered_set<PyTypeObject*> cpp_function_types_set;

struct DefaultFunctionType {
  DefaultFunctionType() : type() {
    _initFunctionPyTypeObject(type, "CppFunction", nullptr, nullptr);
  }

  PyTypeObject type;
};

PyTypeObject* get_default_type() {
  static DefaultFunctionType default_type;
  return &(default_type.type);
}

PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
  if (!cdata) {
    Py_RETURN_NONE;
  }

  if (auto pfw = dynamic_cast<PyNode*>(cdata.get())) {
    PyObject* obj = pfw->obj;
    Py_INCREF(obj);
    return obj;
  }

  if (cdata->pyobj()) {
    Py_INCREF(cdata->pyobj());
  } else {
    auto& fn = *cdata;
    auto it = cpp_function_types_map.find(std::type_index(typeid(fn)));
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    PyTypeObject* type;
    if (it == cpp_function_types_map.end()) {
      type = get_default_type();
    } else {
      type = (PyTypeObject*)it->second.get();
    }

    THPObjectPtr obj(type->tp_alloc(type, 0));
    if (!obj)
      return nullptr;
    THPCppFunction* f = (THPCppFunction*)obj.get();
    new (&f->cdata) std::shared_ptr<Node>(cdata);

    // No INCREF here as we only have a weak reference
    cdata->set_pyobj(obj.release());
  }

  return cdata->pyobj();
}

void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
  Py_INCREF((PyObject*)pytype);
  cpp_function_types_map[std::type_index(type)] =
      THPObjectPtr((PyObject*)pytype);
  cpp_function_types_set.insert(pytype);
}

bool THPCppFunction_Check(PyObject* obj) {
  THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
  if ((PyTypeObject*)type.get() == get_default_type()) {
    return true;
  }
  if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
      cpp_function_types_set.end()) {
    return false;
  } else {
    return true;
  }
}

PyObject* callRegisterFn(PyObject* dict, PyObject* hook) {
  THPObjectPtr register_fn(
      PyObject_GetAttrString(THPFunctionClass, "_register_hook"));
  if (!register_fn) {
    return nullptr;
  }
  THPObjectPtr res(
      PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr));
  if (!res) {
    return nullptr;
  }
  return res.release();
}

PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
  PyObject* dict = Py_None;
  for (const auto& hook : fn.post_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
      dict = pyhook->dict;
      break;
    }
  }
  THPObjectPtr res{callRegisterFn(dict, hook)};
  if (!res) {
    return nullptr;
  }
  if (dict == Py_None) {
    dict = PyTuple_GET_ITEM(res.get(), 0);
    fn.add_post_hook(std::make_unique<PyFunctionPostHook>(dict));
  }

  PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
  Py_INCREF(handle);
  return handle;
}

// This is almost a copy of the function above except post -> pre
PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
  PyObject* dict = Py_None;
  for (const auto& hook : fn.pre_hooks()) {
    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
      dict = pyhook->dict;
      break;
    }
  }
  THPObjectPtr res{callRegisterFn(dict, hook)};
  if (!res) {
    return nullptr;
  }
  if (dict == Py_None) {
    dict = PyTuple_GET_ITEM(res.get(), 0);
    fn.add_pre_hook(std::make_unique<PyFunctionPreHook>(dict));
  }

  PyObject* handle = PyTuple_GET_ITEM(res.get(), 1);
  Py_INCREF(handle);
  return handle;
}

} // namespace torch::autograd
```