// Mirror of https://github.com/zebrajr/pytorch.git
// (synced 2025-12-06 12:20:52 +01:00; 43 lines, 1.2 KiB, C++)
#include "torch/csrc/autograd/python_hook.h"
|
|
|
|
#include "THP.h"
|
|
#include "torch/csrc/autograd/python_variable.h"
|
|
#include "torch/csrc/utils/auto_gil.h"
|
|
#include "torch/csrc/utils/object_ptr.h"
|
|
#include "torch/csrc/Exceptions.h"
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
// Stores `dict` in the hook object and increments its Python reference count;
// the matching Py_DECREF happens in ~PyGradHook().
PyGradHook::PyGradHook(PyObject* dict)
  : dict(dict)
{
  Py_INCREF(this->dict);
}
|
|
|
|
// Releases the reference to `dict` taken by the constructor. An AutoGIL guard
// is created first so that it is alive while the reference count is
// decremented (presumably ensuring the GIL is held; see auto_gil.h).
PyGradHook::~PyGradHook() {
  AutoGIL gil_guard;
  Py_DECREF(this->dict);
}
|
|
|
|
// Runs every hook stored in `dict` on a gradient and returns the result.
//
// `_grad` is wrapped into a Python object and handed to each hook in turn:
//   - a hook returning None leaves the current gradient unchanged;
//   - a hook returning an instance of THPVariableClass replaces the gradient
//     seen by the remaining hooks;
//   - any other return value sets a Python TypeError.
//
// Throws python_error whenever a Python-level call fails (wrapping the
// gradient, snapshotting the hooks, calling a hook, or the isinstance check)
// or a hook returns an unsupported type.
auto PyGradHook::operator()(const std::shared_ptr<Variable>& _grad) -> std::shared_ptr<Variable> {
  AutoGIL gil;

  THPObjectPtr grad = THPVariable_Wrap(_grad);
  if (!grad) throw python_error();

  // Iterate over a snapshot of the hooks rather than the live dict. Hooks are
  // arbitrary Python code and may add or remove hooks while running; mutating
  // a dict during PyDict_Next iteration is not allowed, and the borrowed
  // `value` pointer could be freed under us. The list returned by
  // PyDict_Values owns a strong reference to every hook until we are done.
  THPObjectPtr hooks = PyDict_Values(dict);
  if (!hooks) throw python_error();

  const Py_ssize_t num_hooks = PyList_GET_SIZE(hooks.get());
  for (Py_ssize_t idx = 0; idx < num_hooks; ++idx) {
    // Borrowed reference, kept alive by `hooks`.
    PyObject* hook = PyList_GET_ITEM(hooks.get(), idx);

    THPObjectPtr res = PyObject_CallFunctionObjArgs(hook, grad.get(), nullptr);
    if (!res) throw python_error();
    if (res.get() == Py_None) continue;

    // PyObject_IsInstance returns -1 on failure; that must not be mistaken
    // for "is a Variable", otherwise the cast below would be invalid.
    const int is_variable = PyObject_IsInstance(res.get(), THPVariableClass);
    if (is_variable < 0) throw python_error();
    if (is_variable == 0) {
      PyErr_Format(PyExc_TypeError, "expected Variable, but hook returned '%s'",
          THPUtils_typename(res.get()));
      throw python_error();
    }
    grad = std::move(res);
  }
  return reinterpret_cast<THPVariable*>(grad.get())->cdata;
}
|
|
|
|
}} // namespace torch::autograd
|