#include "torch/csrc/autograd/python_function.h" #include #include #include #include #include #include "THP.h" #include "torch/csrc/autograd/functions/accumulate_grad.h" #include "torch/csrc/autograd/functions/basic_ops.h" #include "torch/csrc/autograd/functions/utils.h" #include "torch/csrc/autograd/python_cpp_function.h" #include "torch/csrc/autograd/python_hook.h" #include "torch/csrc/jit/tracer.h" #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/utils/auto_gil.h" #include "torch/csrc/utils/auto_gpu.h" #include "torch/csrc/Exceptions.h" #ifdef WITH_CUDA #include "cuda/AutoGPU.h" #endif using namespace torch; using namespace torch::autograd; using namespace torch::jit; PyObject *THPFunctionClass = NULL; PyObject *THPStochasticFunctionClass = NULL; PyObject *THPBatchNormBackwardBackwardFunction = NULL; #define THPFunction_assert(condition, ...) \ if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } /** * Call into Python to allocate and zero a tensor as per info. */ static PyObject* _allocate_grad_output(output_info_type& info, AutoGPU& gpu_guard) { // TODO: no need to do this for non-differentiable outputs PyObject *tensor_cls = std::get<0>(info); gpu_guard.setDevice(std::get<1>(info)); std::vector &sizes = std::get<2>(info); std::vector long_sizes(sizes.begin(), sizes.end()); THPObjectPtr grad_size(THPSize_New(long_sizes.size(), long_sizes.data())); if (!grad_size) throw python_error(); THPObjectPtr new_grad(PyObject_CallFunctionObjArgs(tensor_cls, grad_size.get(), NULL)); if (!new_grad) throw python_error(); THPObjectPtr result(PyObject_CallMethod(new_grad.get(), "zero_", "")); if (!result) throw python_error(); return new_grad.release(); } namespace torch { namespace autograd { auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { AutoGIL gil; THPObjectPtr pyInputs(PyTuple_New(inputs.size())); if (!pyInputs) throw python_error(); for (size_t i = 0; i != inputs.size(); ++i) { PyObject* input; if (inputs[i]) { input = createPyObject(inputs[i]->data); if (!input) throw python_error(); } else { input = Py_None; Py_INCREF(input); } PyTuple_SET_ITEM(pyInputs.get(), i, input); } THPObjectPtr r(PyObject_CallMethod( obj, "_do_backward", "OO", pyInputs.get(), Py_True)); if (!r) throw python_error(); auto num_outputs = PyTuple_GET_SIZE(r.get()); tensor_list tensor_results(num_outputs); for (int i = 0; i != num_outputs; ++i) { PyObject* obj = PyTuple_GET_ITEM(r.get(), i); if (obj != Py_None) { if (!THPModule_isTensor(obj)) { std::string msg("expected Tensor (got '"); msg += THPUtils_typename(obj); msg += "')'"; throw std::runtime_error(msg); } tensor_results[i] = createTensor(obj); } } // XXX: this might get requires_grad wrong - there's no way to figure out // if _do_backward didn't use ctx.saved_variables and as a result some // Variables might require grad, even if no args do. Unfortunately, this // leads to unexpected error messages ("no nodes require computing gradients"), // but I don't have a better idea. These functions would raise an error // in backward anyway. return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) { return std::make_shared(name() + " is not differentiable twice", std::move(f)); }); } // NOTE: this function is written in a way that assumes it's only called for backward; // it's used by engine.cpp. This is responsible for forwarding a call from // C++'s Function::apply to a Python method "apply". auto PyFunction::apply(const variable_list& inputs) -> variable_list { AutoGIL gil; AutoGPU _gpu_guard(-1); THPFunction* py_fn = (THPFunction*)obj; THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy")); if (_legacy == Py_True) { return legacy_apply(inputs); } // Massage a C++ variable_list into a Python arguments tuple auto num_inputs = inputs.size(); THPObjectPtr pyInputs(PyTuple_New(num_inputs)); if (!pyInputs) throw python_error(); auto& output_info = *py_fn->output_info; for (size_t i = 0; i < num_inputs; ++i) { PyObject* input; if (inputs[i]) { input = THPVariable_Wrap(inputs[i]); } else { THPObjectPtr tensor(_allocate_grad_output(output_info[i], _gpu_guard)); input = THPVariable_NewLeaf(tensor); } if (!input) throw python_error(); PyTuple_SET_ITEM(pyInputs.get(), i, input); } // TODO: theoretically we could take a shortcut here and call apply directly THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply")); if (!apply_fn) throw python_error(); THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get())); if (!r) throw python_error(); ensure_tuple(r); auto& is_variable_input = *py_fn->is_variable_input; int num_outputs = PyTuple_GET_SIZE(r.get()); int num_forward_inputs = is_variable_input.size(); // Returning too many results is ok, but only as long as they're all None. // Truncate the result tuple in that case. if (num_outputs > num_forward_inputs) { bool all_none = true; for (int i = num_forward_inputs; i < num_outputs; i++) { all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None; } if (all_none) { num_outputs = num_forward_inputs; r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs); if (!r) throw python_error(); } } // Now the number of gradients should match if (num_outputs != num_forward_inputs) { std::string msg("function "); msg += name() + " returned an incorrect number of gradients (expected "; msg += std::to_string(num_forward_inputs) + ", got " ; msg += std::to_string(num_outputs) + ")"; throw std::runtime_error(msg); } // Massage the Python results tuple back into a C++ variable_list variable_list results; results.reserve(num_outputs); for (int i = 0; i != num_outputs; ++i) { PyObject* output = PyTuple_GET_ITEM(r.get(), i); bool was_variable = is_variable_input[i]; if (!was_variable) { if (output != Py_None) { std::string msg("function "); msg += name() + " returned a gradient different than None at position "; msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable"; throw std::runtime_error(msg); } continue; } if (output != Py_None) { if (!THPVariable_Check(output)) { std::string msg("expected Variable or None (got "); msg += THPUtils_typename(output); msg += ")"; throw std::runtime_error(msg); } results.emplace_back(((THPVariable*)output)->cdata); } else { results.emplace_back(); } } return results; } auto PyFunction::releaseVariables() -> void { AutoGIL gil; auto f = (THPFunction*) obj; delete f->saved_variables; f->saved_variables = nullptr; f->has_freed_buffers = 1; } auto PyFunction::name() -> std::string { AutoGIL gil; auto f = (THPFunction*) obj; return std::string(Py_TYPE(f)->tp_name); } }} // namespace torch::autograd // Traverse and clear are required for supporting Python's GC cycle handling. static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) { for (auto& hook : self->cdata.pre_hooks) { if (auto pyhook = dynamic_cast(hook.get())) { Py_VISIT(pyhook->dict); } } for (auto& hook : self->cdata.post_hooks) { if (auto pyhook = dynamic_cast(hook.get())) { Py_VISIT(pyhook->dict); } } Py_VISIT(self->to_save); Py_VISIT(self->shared_pairs); Py_VISIT(self->non_differentiable); Py_VISIT(self->dirty_tensors); return 0; } static int THPFunction_clear(THPFunction *self) { self->cdata.num_inputs = 0; Py_CLEAR(self->needs_input_grad); Py_CLEAR(self->to_save); Py_CLEAR(self->shared_pairs); Py_CLEAR(self->non_differentiable); Py_CLEAR(self->dirty_tensors); auto saved_variables = self->saved_variables; self->saved_variables = NULL; delete saved_variables; auto output_info = self->output_info; self->output_info = NULL; delete output_info; auto is_variable_input = self->is_variable_input; self->is_variable_input = NULL; delete is_variable_input; // XXX: this will clear all hooks (not only Python ones) // I guess it's ok to leave it as is for now. auto pre_hooks = std::move(self->cdata.pre_hooks); auto post_hooks = std::move(self->cdata.post_hooks); return 0; } static void THPFunction_dealloc(THPFunction* self) { PyObject_GC_UnTrack(self); THPFunction_clear(self); self->cdata.~PyFunction(); Py_TYPE(self)->tp_free((PyObject*)self); } PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { PyObject* obj = type->tp_alloc(type, 0); if (!obj) return NULL; // Python zero-initializes the object memory, so there's no need to initialize // most fields THPFunction* self = (THPFunction*)obj; new (&self->cdata) PyFunction(obj); self->cdata.num_inputs = -1; self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass); return obj; } //////////////////////////////////////////////////////////////////////////////// // Forward //////////////////////////////////////////////////////////////////////////////// using t2var_type = std::unordered_map; // Bump the counters of all recorded dirty input tensors, adding each of them // into dirty_inputs. Also does some sanity checking. static void _mark_dirty(THPFunction *self, t2var_type &t2var, std::unordered_set &dirty_inputs) { // Increase versions of modified tensors if (!self->dirty_tensors) return; THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd " "internal error: dirty_tensors attribute is expected to be a tuple " "but is %s", THPUtils_typename(self->dirty_tensors)); Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors); for (int i = 0; i < num_dirty; i++) { PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i); dirty_inputs.insert(tensor); THPVariable *variable; try { variable = t2var.at(tensor); } catch (std::out_of_range &e) { THPFunction_assert(THPModule_isTensor(tensor), "mark_dirty can " "only accept tensors, but argument %d is of type %s", i, THPUtils_typename(tensor)); THPFunction_assert(false, "mark_dirty only accepts input tensors, but " "argument %d isn't one", i); } auto &v_counter = *variable->cdata->version_counter; THPFunction_assert(v_counter.var_refcnt() == 1, "in-place operations can be " "only used on variables that don't share storage with any other " "variables, but detected that there are %d objects sharing it", v_counter.var_refcnt()); v_counter++; } // We're not going to ever need this so let's remove references now Py_DECREF(self->dirty_tensors); self->dirty_tensors = NULL; } static void _transplant_var(Variable& var, const std::shared_ptr& fn, int output_nr, bool is_volatile) { if (is_volatile) { var.grad_fn = nullptr; var.requires_grad = false; var.is_volatile = true; var.output_nr = 0; } else { var.grad_fn = fn; var.requires_grad = fn->is_executable; var.is_volatile = is_volatile; var.output_nr = output_nr; } var.grad = nullptr; var.hooks.clear(); if (auto grad_acc_fn = var.grad_accumulator.lock()) { auto grad_acc = dynamic_cast(grad_acc_fn.get()); grad_acc->variable.reset(); grad_acc->variable_grad.reset(); } } // Given a Python tuple of raw output tensors (raw_output), set each of // the corresponding entries in a different Python tuple (outputs) with // these tensors wrapped with variables. We save the gradient function (self) // to the variable if the output is not volatile (is_volatile). // // There is a considerable amount of complexity to handle if the operation // that produced these output tensors is inplace. A mapping of *input* // tensors to variables (t2var) is used to test if this occurred, and // the set of dirty tensors (dirty_inputs) is used to figure out what to // do in this case. static void _wrap_outputs(THPFunction *self, t2var_type &t2var, std::unordered_set &dirty_inputs, PyObject *raw_output, PyObject *outputs, Node * this_expr, bool is_volatile) { // Either both or none should be true assert(!(((bool)this_expr) ^ GlobalTracingState.tracing())); bool is_tracing = this_expr && GlobalTracingState.tracing(); auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self); Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); if (self->cdata.is_executable) { self->output_info = new std::vector(); self->output_info->reserve(num_outputs); } for (int i = 0; i < num_outputs; i++) { PyObject *output = PyTuple_GET_ITEM(raw_output, i); THPVariable *output_var; auto it = t2var.find(output); if (it == t2var.end()) { // A completely new tensor - just wrap it and continue if (is_volatile) { output_var = (THPVariable*)THPVariable_NewVolatile(output); } else { output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); } } else { // If one of the outputs was also an input tensor it's a bit more complicated. THPVariable *input_var = it->second; auto& input_var_ = *input_var->cdata; if (input_var_.grad_fn) { Py_INCREF(input_var); output_var = input_var; // If it's not a leaf we want to move it in the graph so backprop // will be computed correctly, but only if it was modified. Otherwise // it's better to minimize the number of operations that mutate the graph. // grad_fn <- variable <- self ==> grad_fn <- self <- variable if (dirty_inputs.count(output) > 0) { _transplant_var(input_var_, cdata, i, is_volatile); } } else { // If the leaf Variable has been returned, we have to move it after the // current function to ensure the gradient is computed correctly. // There are two cases now: // 1. It has been modified in-place. If it didn't require_grad it's ok, // but if it does, then it's a clear error. // 2. It hasn't been modified. This means that it must have been // returned unchanged, and we can simply return a new Variable // referencing the same storage. if (dirty_inputs.count(output) > 0) { if (!input_var_.requires_grad) { Py_INCREF(input_var); output_var = input_var; _transplant_var(input_var_, cdata, i, is_volatile); } else { // input_var_.requires_grad throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation."); } } else { // An input has been returned, but it wasn't modified. It's better // not to move the Variable, because there are some legitimate cases // where making it non-leaf would break stuff (e.g. broadcast). Also, // returning the input Variable is not a good option either, // because if someone registers hooks on it, they will fire with grads // from all usages, not only from usages of this output. This is why // we'll return a copy and join their version counters. This has // a side-effect of making in-place ops on any of these Variables an // immediate error, but it would be raised anyway once someone // calls backward. if (is_volatile) { output_var = (THPVariable*)THPVariable_NewVolatile(output); } else { output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); } if (!output_var) throw python_error(); output_var->cdata->version_counter->join_with(*input_var->cdata->version_counter); } } } if (!output_var) throw python_error(); if (self->output_info) { auto& output_tensor = output_var->cdata->data; self->output_info->emplace_back( (PyObject *)getPyTypeObject(output_tensor), output_tensor.type().isCuda() ? output_tensor.get_device() : -1, output_tensor.sizes() ); } t2var[output] = output_var; if (is_tracing) { // NOTE: normally we don't add Select nodes when there's only a single // output, but Python nodes can't be optimized away, so we simplify the // code here. Node* sel = GlobalTracingState.current().appendNewNode