#include #include #include #include #include #include #include #include using torch::autograd::variable_list; using torch::autograd::Variable; static PyObject* wrap_variables(const variable_list& c_variables); static variable_list unwrap_variables(PyObject* py_variables); static std::string hook_name(PyObject* hook); static void check_result(PyObject* original, PyObject* result, PyObject* hook); static void check_single_result(PyObject* original, PyObject* result, PyObject* hook); namespace torch { namespace autograd { PyFunctionPreHook::PyFunctionPreHook(PyObject* dict, int value_idx) : dict(dict) , value_idx(value_idx) { Py_INCREF(dict); } PyFunctionPreHook::~PyFunctionPreHook() { pybind11::gil_scoped_acquire gil; Py_DECREF(dict); } auto PyFunctionPreHook::operator()(const variable_list& values) -> variable_list { pybind11::gil_scoped_acquire gil; THPObjectPtr value(THPVariable_Wrap(values.at(value_idx))); if (!value) throw python_error(); PyObject *key, *hook; Py_ssize_t pos = 0; while (PyDict_Next(dict, &pos, &key, &hook)) { THPObjectPtr res(PyObject_CallFunctionObjArgs(hook, value.get(), nullptr)); if (!res) throw python_error(); if (res == Py_None) continue; check_single_result(value.get(), res.get(), hook); value = std::move(res); } variable_list results(values); if (value != Py_None) results[value_idx] = ((THPVariable*)value.get())->cdata; return results; } PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) { Py_INCREF(dict); } PyFunctionPostHook::~PyFunctionPostHook() { pybind11::gil_scoped_acquire gil; Py_DECREF(dict); } auto PyFunctionPostHook::operator()( const variable_list& _outputs, /* grad_inputs */ const variable_list& _inputs /* grad_outputs */) -> variable_list { pybind11::gil_scoped_acquire gil; THPObjectPtr outputs(wrap_variables(_outputs)); THPObjectPtr inputs(wrap_variables(_inputs)); PyObject *key, *hook; Py_ssize_t pos = 0; while (PyDict_Next(dict, &pos, &key, &hook)) { THPObjectPtr res(PyObject_CallFunctionObjArgs( hook, outputs.get(), inputs.get(), nullptr)); if (!res) throw python_error(); if (res == Py_None) continue; check_result(outputs, res, hook); outputs = std::move(res); } return unwrap_variables(outputs.get()); } }} // namespace torch::autograd static PyObject *wrap_variables(const variable_list& c_variables) { size_t num_vars = c_variables.size(); THPObjectPtr tuple(PyTuple_New(num_vars)); if (!tuple) throw python_error(); for (size_t i = 0; i < num_vars; ++i) { THPObjectPtr var(THPVariable_Wrap(c_variables[i])); if (!var) throw python_error(); PyTuple_SET_ITEM(tuple.get(), i, var.release()); } return tuple.release(); } static variable_list unwrap_variables(PyObject* py_variables) { variable_list results(PyTuple_GET_SIZE(py_variables)); for (size_t i = 0; i < results.size(); i++) { PyObject* item = PyTuple_GET_ITEM(py_variables, i); if (item == Py_None) { continue; } else if (THPVariable_Check(item)) { results[i] = ((THPVariable*)item)->cdata; } else { // this should never happen, but just in case... std::stringstream ss; ss << "expected variable but got " << Py_TYPE(item)->tp_name; throw std::runtime_error(ss.str()); } } return results; } static void check_result(PyObject* prev, PyObject* result, PyObject* hook) { if (!PyTuple_Check(result)) { PyErr_Format(PyExc_TypeError, "expected tuple, but hook returned '%s'", THPUtils_typename(result)); throw python_error(); } auto prev_size = PyTuple_GET_SIZE(prev); auto result_size = PyTuple_GET_SIZE(result); if (prev_size != result_size) { std::stringstream ss; auto name = hook_name(hook); ss << "hook '" << name << "' has returned an incorrect number "; ss << "of values (got " << result_size << ", but expected "; ss << prev_size << ")"; throw std::runtime_error(ss.str()); } for (auto i = 0; i < prev_size; i++) { check_single_result(PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook); } } static void check_single_result(PyObject* _original, PyObject* _result, PyObject* hook) { if (_result == Py_None) return; if (_original == Py_None) { throw std::runtime_error("can't replace a None gradient with a non-None value"); } if (!PyObject_IsInstance(_result, THPVariableClass)) { PyErr_Format(PyExc_TypeError, "expected Variable, but hook returned '%s'", THPUtils_typename(_result)); throw python_error(); } auto& original = ((THPVariable*)_original)->cdata; auto& result = ((THPVariable*)_result)->cdata; torch::autograd::check_variable_result(original, result, hook_name(hook)); } static std::string hook_name(PyObject* hook) { if (PyObject_HasAttrString(hook, "__name__")) { THPObjectPtr name(PyObject_GetAttrString(hook, "__name__")); if (!name) throw python_error(); if (name && THPUtils_checkString(name.get())) { return THPUtils_unpackString(name.get()); } } return ""; }