#include "torch/csrc/autograd/python_function.h" #include #include #include #include #include #include "THP.h" #include "torch/csrc/autograd/functions/accumulate_grad.h" #include "torch/csrc/autograd/functions/basic_ops.h" #include "torch/csrc/autograd/functions/utils.h" #include "torch/csrc/autograd/python_cpp_function.h" #include "torch/csrc/autograd/python_hook.h" #include "torch/csrc/jit/tracer.h" #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/utils/auto_gil.h" #include "torch/csrc/utils/auto_gpu.h" #include "torch/csrc/Exceptions.h" #ifdef WITH_CUDA #include "cuda/AutoGPU.h" #endif using namespace torch; using namespace torch::autograd; using namespace torch::jit; PyObject *THPFunctionClass = NULL; PyObject *THPStochasticFunctionClass = NULL; PyObject *THPBatchNormBackwardBackwardFunction = NULL; #define THPFunction_assert(condition, ...) \ if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } /** * Call into Python to allocate and zero a tensor as per info. */ static PyObject* _allocate_grad_output(output_info_type& info, AutoGPU& gpu_guard) { // TODO: no need to do this for non-differentiable outputs PyObject *tensor_cls = std::get<0>(info); gpu_guard.setDevice(std::get<1>(info)); std::vector &sizes = std::get<2>(info); std::vector long_sizes(sizes.begin(), sizes.end()); THPObjectPtr grad_size(THPSize_New(long_sizes.size(), long_sizes.data())); if (!grad_size) throw python_error(); THPObjectPtr new_grad(PyObject_CallFunctionObjArgs(tensor_cls, grad_size.get(), NULL)); if (!new_grad) throw python_error(); THPObjectPtr result(PyObject_CallMethod(new_grad.get(), "zero_", "")); if (!result) throw python_error(); return new_grad.release(); } namespace torch { namespace autograd { auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { AutoGIL gil; THPObjectPtr pyInputs(PyTuple_New(inputs.size())); if (!pyInputs) throw python_error(); for (size_t i = 0; i != inputs.size(); ++i) { PyObject* input; if (inputs[i]) { input = createPyObject(inputs[i]->data); if (!input) throw python_error(); } else { input = Py_None; Py_INCREF(input); } PyTuple_SET_ITEM(pyInputs.get(), i, input); } THPObjectPtr r(PyObject_CallMethod( obj, "_do_backward", "OO", pyInputs.get(), Py_True)); if (!r) throw python_error(); auto num_outputs = PyTuple_GET_SIZE(r.get()); tensor_list tensor_results(num_outputs); for (int i = 0; i != num_outputs; ++i) { PyObject* obj = PyTuple_GET_ITEM(r.get(), i); if (obj != Py_None) { if (!THPModule_isTensor(obj)) { std::string msg("expected Tensor (got '"); msg += THPUtils_typename(obj); msg += "')'"; throw std::runtime_error(msg); } tensor_results[i] = createTensor(obj); } } // XXX: this might get requires_grad wrong - there's no way to figure out // if _do_backward didn't use ctx.saved_variables and as a result some // Variables might require grad, even if no args do. Unfortunately, this // leads to unexpected error messages ("no nodes require computing gradients"), // but I don't have a better idea. These functions would raise an error // in backward anyway. return wrap_outputs(inputs, std::move(tensor_results), [this](FunctionFlags &&f) { return std::make_shared(name() + " is not differentiable twice", std::move(f)); }); } // NOTE: this function is written in a way that assumes it's only called for backward; // it's used by engine.cpp. This is responsible for forwarding a call from // C++'s Function::apply to a Python method "apply". 
// NOTE: this function is written in a way that assumes it's only called for backward;
// it's used by engine.cpp. This is responsible for forwarding a call from
// C++'s Function::apply to a Python method "apply".
auto PyFunction::apply(const variable_list& inputs) -> variable_list {
  AutoGIL gil;
  AutoGPU _gpu_guard(-1);
  THPFunction* py_fn = (THPFunction*)obj;

  THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
  if (_legacy == Py_True) {
    return legacy_apply(inputs);
  }

  // Massage a C++ variable_list into a Python arguments tuple
  auto num_inputs = inputs.size();
  THPObjectPtr pyInputs(PyTuple_New(num_inputs));
  if (!pyInputs) throw python_error();
  auto& output_info = *py_fn->output_info;
  for (size_t i = 0; i < num_inputs; ++i) {
    PyObject* input;
    if (inputs[i]) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      THPObjectPtr tensor(_allocate_grad_output(output_info[i], _gpu_guard));
      input = THPVariable_NewLeaf(tensor);
    }
    if (!input) throw python_error();
    PyTuple_SET_ITEM(pyInputs.get(), i, input);
  }

  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn) throw python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r) throw python_error();
  ensure_tuple(r);

  auto& is_variable_input = *py_fn->is_variable_input;
  int num_outputs = PyTuple_GET_SIZE(r.get());
  int num_forward_inputs = is_variable_input.size();
  // Returning too many results is ok, but only as long as they're all None.
  // Truncate the result tuple in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (int i = num_forward_inputs; i < num_outputs; i++) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r) throw python_error();
    }
  }

  // Now the number of gradients should match
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // Massage the Python results tuple back into a C++ variable_list
  variable_list results;
  results.reserve(num_outputs);
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(r.get(), i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output != Py_None) {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(((THPVariable*)output)->cdata);
    } else {
      results.emplace_back();
    }
  }

  return results;
}

auto PyFunction::releaseVariables() -> void {
  AutoGIL gil;
  auto f = (THPFunction*) obj;
  delete f->saved_variables;
  f->saved_variables = nullptr;
  f->has_freed_buffers = 1;
}

auto PyFunction::name() -> std::string {
  AutoGIL gil;
  auto f = (THPFunction*) obj;
  return std::string(Py_TYPE(f)->tp_name);
}

auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> {
  return THPFunction_asFunction((THPFunction*)obj);
}

}} // namespace torch::autograd
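// For orientation, the new-style counterpart (hypothetical user code): the
// checks in PyFunction::apply enforce the contract of the Python backward -
// one gradient per forward input, with None in every position whose forward
// input wasn't a Variable:
//
//   class Scale(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, x, scale):  # scale is a plain number
//           ctx.scale = scale
//           return x * scale
//
//       @staticmethod
//       def backward(ctx, grad_output):
//           return grad_output * ctx.scale, None  # None matches "scale"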
// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg)
{
  for (auto& hook : self->cdata.pre_hooks) {
    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  for (auto& hook : self->cdata.post_hooks) {
    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
      Py_VISIT(pyhook->dict);
    }
  }
  Py_VISIT(self->to_save);
  Py_VISIT(self->shared_pairs);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  return 0;
}

static int THPFunction_clear(THPFunction *self)
{
  self->cdata.num_inputs = 0;

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->shared_pairs);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);

  auto saved_variables = self->saved_variables;
  self->saved_variables = NULL;
  delete saved_variables;

  auto output_info = self->output_info;
  self->output_info = NULL;
  delete output_info;

  auto is_variable_input = self->is_variable_input;
  self->is_variable_input = NULL;
  delete is_variable_input;

  // XXX: this will clear all hooks (not only Python ones)
  // I guess it's ok to leave it as is for now.
  auto pre_hooks = std::move(self->cdata.pre_hooks);
  auto post_hooks = std::move(self->cdata.post_hooks);

  return 0;
}

static void THPFunction_dealloc(THPFunction* self)
{
  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  self->cdata.~PyFunction();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  PyObject* obj = type->tp_alloc(type, 0);
  if (!obj) return NULL;
  // Python zero-initializes the object memory, so there's no need to initialize
  // most fields
  THPFunction* self = (THPFunction*)obj;
  new (&self->cdata) PyFunction(obj);
  self->cdata.num_inputs = -1;
  self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
  return obj;
}

////////////////////////////////////////////////////////////////////////////////
// Forward
////////////////////////////////////////////////////////////////////////////////

using t2var_type = std::unordered_map<PyObject *, THPVariable *>;
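// A rough picture of what t2var holds while a forward call is processed:
// every raw tensor's PyObject* maps to the THPVariable* wrapping it, so for
// a call like MyFunc.apply(v1, v2) it is approximately
//   { v1.data -> v1, v2.data -> v2 }
// with outputs added as they are wrapped in _wrap_outputs below. This is what
// lets mark_dirty / save_for_backward / mark_shared_storages resolve the
// Variable behind a plain tensor the user handed back.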
// Bump the counters of all recorded dirty input tensors, adding each of them
// into dirty_inputs. Also does some sanity checking.
static void _mark_dirty(THPFunction *self, t2var_type &t2var,
        std::unordered_set<PyObject *> &dirty_inputs)
{
  // Increase versions of modified tensors
  if (!self->dirty_tensors) return;

  THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd "
      "internal error: dirty_tensors attribute is expected to be a tuple "
      "but is %s", THPUtils_typename(self->dirty_tensors));
  Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
  for (int i = 0; i < num_dirty; i++) {
    PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i);
    dirty_inputs.insert(tensor);
    THPVariable *variable;
    try {
      variable = t2var.at(tensor);
    } catch (std::out_of_range &e) {
      THPFunction_assert(THPModule_isTensor(tensor), "mark_dirty can "
          "only accept tensors, but argument %d is of type %s", i,
          THPUtils_typename(tensor));
      THPFunction_assert(false, "mark_dirty only accepts input tensors, but "
          "argument %d isn't one", i);
    }
    auto &v_counter = *variable->cdata->version_counter;
    THPFunction_assert(v_counter.var_refcnt() == 1, "in-place operations can "
        "only be used on variables that don't share storage with any other "
        "variables, but detected that there are %d objects sharing it",
        v_counter.var_refcnt());
    v_counter++;
  }
  // We're not going to ever need this so let's remove references now
  Py_DECREF(self->dirty_tensors);
  self->dirty_tensors = NULL;
}

static void _transplant_var(Variable& var, const std::shared_ptr<Function>& fn, int output_nr, bool is_volatile)
{
  if (is_volatile) {
    var.grad_fn = nullptr;
    var.requires_grad = false;
    var.is_volatile = true;
    var.output_nr = 0;
  } else {
    var.grad_fn = fn;
    var.requires_grad = fn->is_executable;
    var.is_volatile = is_volatile;
    var.output_nr = output_nr;
  }
  var.grad = nullptr;
  var.hooks.clear();
  if (auto grad_acc_fn = var.grad_accumulator.lock()) {
    auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
    grad_acc->variable.reset();
    grad_acc->variable_grad.reset();
  }
}
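// The rewiring done by _transplant_var is easiest to see on a concrete case
// (informal sketch): if forward returns one of its own non-leaf inputs v that
// it modified in place, v's history changes from  prev_fn <- v  to
// prev_fn <- this_fn <- v, so backward flows through this function exactly
// once. For volatile outputs all history is dropped instead, since volatile
// Variables never participate in backward.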
// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables. We save the gradient function (self)
// to the variable if the output is not volatile (is_volatile).
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace. A mapping of *input*
// tensors to variables (t2var) is used to test if this occurred, and
// the set of dirty tensors (dirty_inputs) is used to figure out what to
// do in this case.
static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
    std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
    PyObject *outputs, bool is_volatile)
{
  auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
  if (self->cdata.is_executable) {
    self->output_info = new std::vector<output_info_type>();
    self->output_info->reserve(num_outputs);
  }
  for (int i = 0; i < num_outputs; i++) {
    PyObject *output = PyTuple_GET_ITEM(raw_output, i);
    THPVariable *output_var;
    auto it = t2var.find(output);
    if (it == t2var.end()) {
      // A completely new tensor - just wrap it and continue
      if (is_volatile) {
        output_var = (THPVariable*)THPVariable_NewVolatile(output);
      } else {
        output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
      }
    } else {
      // If one of the outputs was also an input tensor it's a bit more complicated.
      THPVariable *input_var = it->second;
      auto& input_var_ = *input_var->cdata;
      if (input_var_.grad_fn) {
        Py_INCREF(input_var);
        output_var = input_var;
        // If it's not a leaf we want to move it in the graph so backprop
        // will be computed correctly, but only if it was modified. Otherwise
        // it's better to minimize the number of operations that mutate the graph.
        // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
        if (dirty_inputs.count(output) > 0) {
          _transplant_var(input_var_, cdata, i, is_volatile);
        }
      } else {
        // If the leaf Variable has been returned, we have to move it after the
        // current function to ensure the gradient is computed correctly.
        // There are two cases now:
        // 1. It has been modified in-place. If it didn't require_grad it's ok,
        //    but if it does, then it's a clear error.
        // 2. It hasn't been modified. This means that it must have been
        //    returned unchanged, and we can simply return a new Variable
        //    referencing the same storage.
        if (dirty_inputs.count(output) > 0) {
          if (!input_var_.requires_grad) {
            Py_INCREF(input_var);
            output_var = input_var;
            _transplant_var(input_var_, cdata, i, is_volatile);
          } else { // input_var_.requires_grad
            throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
          }
        } else {
          // An input has been returned, but it wasn't modified. It's better
          // not to move the Variable, because there are some legitimate cases
          // where making it non-leaf would break stuff (e.g. broadcast). Also,
          // returning the input Variable is not a good option either,
          // because if someone registers hooks on it, they will fire with grads
          // from all usages, not only from usages of this output. This is why
          // we'll return a copy and join their version counters. This has
          // a side-effect of making in-place ops on any of these Variables an
          // immediate error, but it would be raised anyway once someone
          // calls backward.
          if (is_volatile) {
            output_var = (THPVariable*)THPVariable_NewVolatile(output);
          } else {
            output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
          }
          if (!output_var) throw python_error();
          output_var->cdata->version_counter->join_with(*input_var->cdata->version_counter);
        }
      }
    }
    if (!output_var) throw python_error();

    if (self->output_info) {
      auto& output_tensor = output_var->cdata->data;
      self->output_info->emplace_back(
        (PyObject *)getPyTypeObject(output_tensor),
        output_tensor.type().isCuda() ? output_tensor.get_device() : -1,
        output_tensor.sizes()
      );
    }
    t2var[output] = output_var;
    output_var->cdata->output_nr = i;
    PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var);
  }
}
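// Each output_info entry above records (Python tensor type, device index or
// -1 for CPU, sizes). PyFunction::apply replays these entries through
// _allocate_grad_output to materialize a zero-filled gradient for any output
// whose incoming grad is missing, e.g. an output that was never used
// downstream of the part of the graph being differentiated.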
// Save any variables that were requested by to_save
static void _save_variables(THPFunction* self, t2var_type &t2var)
{
  if (!self->to_save) return;

  THPFunction_assert(PyTuple_Check(self->to_save), "autograd internal "
      "error: to_save attribute is expected to be a tuple but is %s",
      THPUtils_typename(self->to_save));
  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
  self->saved_variables = new std::vector<SavedVariable>();
  self->saved_variables->reserve(num_saved);
  auto cdata_ptr = &self->cdata;
  for (int i = 0; i < num_saved; i++) {
    PyObject *tensor = PyTuple_GET_ITEM(self->to_save, i);
    if (tensor == Py_None) {
      self->saved_variables->emplace_back();
      continue;
    }

    THPVariable *variable;
    try {
      variable = t2var.at(tensor);
    } catch(std::out_of_range &e) {
      THPFunction_assert(THPModule_isTensor(tensor),
          "save_for_backward can only save tensors, but argument %d is of "
          "type %s", i, THPUtils_typename(tensor));
      THPFunction_assert(false, "save_for_backward can only save input or output "
          "tensors, but argument %d doesn't satisfy this condition", i);
    }

    self->saved_variables->emplace_back(variable->cdata->save(cdata_ptr));
  }
  // Free .to_save
  Py_DECREF(self->to_save);
  self->to_save = NULL;
}
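// Note on the SavedVariable mechanism used above (informal): each saved
// Variable snapshots its version counter, so unpacking it in backward after
// an in-place modification raises the familiar "variable needed for gradient
// computation has been modified" style of error, while releaseVariables()
// deleting this vector is what turns a second backward() without
// retain_variables into a "buffers have already been freed" error.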
static void _join_version_counters(THPFunction *self, t2var_type &t2var)
{
  if (!self->shared_pairs) return;

  THPFunction_assert(PyTuple_Check(self->shared_pairs), "autograd internal "
      "error: shared_pairs attribute is expected to be a tuple but is %s",
      THPUtils_typename(self->shared_pairs));
  Py_ssize_t num_shared = PyTuple_GET_SIZE(self->shared_pairs);
  for (int i = 0; i < num_shared; i++) {
    PyObject *shared_tuple = PyTuple_GET_ITEM(self->shared_pairs, i);
    THPFunction_assert(PyTuple_Check(shared_tuple), "mark_shared_storages "
        "accepts a number of pairs, but one of the arguments is of type %s",
        THPUtils_typename(shared_tuple));
    THPFunction_assert(PyTuple_GET_SIZE(shared_tuple) == 2,
        "mark_shared_storages accepts pairs, but argument %d is a tuple of "
        "%d elements", i, PyTuple_GET_SIZE(shared_tuple));

    // Now we're sure it's really a pair!
    THPVariable *v1, *v2;
    try {
      v1 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 0));
      v2 = t2var.at(PyTuple_GET_ITEM(shared_tuple, 1));
    } catch(std::out_of_range &e) {
      // One of the tuple items wasn't present in t2var, so there are two cases:
      // 1. it's not a tensor
      // 2. it's not an input nor an output
      PyObject *t1 = PyTuple_GET_ITEM(shared_tuple, 0);
      PyObject *t2 = PyTuple_GET_ITEM(shared_tuple, 1);
      THPFunction_assert(THPModule_isTensor(t1) && THPModule_isTensor(t2),
          "mark_shared_storages accepts pairs of tensors, but one of them "
          "contains %s and %s", THPUtils_typename(t1), THPUtils_typename(t2));
      THPFunction_assert(false, "mark_shared_storages only accepts pairs of input "
          "and output tensors, but argument %d doesn't satisfy this "
          "condition", i);
    }
    v2->cdata->version_counter->join_with(*v1->cdata->version_counter);
  }
  // Free .shared_pairs
  Py_DECREF(self->shared_pairs);
  self->shared_pairs = NULL;
}

// Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable)
static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var)
{
  if (!self->non_differentiable) return;

  THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd "
      "internal error: non_differentiable attribute is expected to be a "
      "tuple but is %s", THPUtils_typename(self->non_differentiable));
  Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
  for (int i = 0; i < num_nondiff; i++) {
    PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i);
    THPVariable *var;
    try {
      var = t2var.at(t);
      THPFunction_assert(var->cdata->grad_fn.get() == &self->cdata,
          "mark_non_differentiable only accepts output tensors, but "
          "argument %d isn't an output", i);
    } catch (std::out_of_range &e) {
      THPFunction_assert(THPModule_isTensor(t), "mark_non_differentiable "
          "only accepts tensor arguments, but got %s", THPUtils_typename(t));
      THPFunction_assert(false, "mark_non_differentiable only accepts function "
          "outputs");
    }
    var->cdata->requires_grad = 0;
  }
  Py_DECREF(self->non_differentiable);
  self->non_differentiable = NULL;
}

struct UnpackedInput {
  THPObjectPtr tensor_input;
  variable_list input_vars;
};

struct InputFlags {
  FunctionFlags flags;
  THPObjectPtr needs_input_grad;
  std::vector<bool> is_variable_input;
};
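// unpack_input (below) splits the mixed Python argument tuple into the
// tensor-only tuple handed to forward() and the Variable list used for flag
// computation. Roughly, for a call like MyFunc.apply(v1, 2, v2):
//   tensor_input     = (v1.data, 2, v2.data)
//   input_vars       = [v1, v2]
//   needs_input_grad = (v1.requires_grad, False, v2.requires_grad)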
template<bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
  UnpackedInput unpacked;
  InputFlags flags;

  auto num_args = PyTuple_GET_SIZE(args);
  unpacked.tensor_input = PyTuple_New(num_args);
  flags.needs_input_grad = PyTuple_New(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg = PyTuple_GET_ITEM(args, i);
    PyObject *new_arg;

    bool is_variable = THPVariable_Check(arg);
    flags.is_variable_input.push_back(is_variable);
    if (!is_variable) {
      if (enforce_variables) {
        THPUtils_setError("expected a Variable argument, but got %s",
            THPUtils_typename(arg));
        throw python_error();
      }
      Py_INCREF(arg);
      new_arg = arg;
      Py_INCREF(Py_False);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
    } else {
      THPVariable* variable = (THPVariable*)arg;
      new_arg = THPVariable_get_data(variable);
      unpacked.input_vars.push_back(variable->cdata);
      PyObject* needs_grad = variable->cdata->requires_grad ? Py_True : Py_False;
      Py_INCREF(needs_grad);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
    }
    PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg);
  }

  flags.flags = Function::flags(unpacked.input_vars);
  return std::make_pair(std::move(unpacked), std::move(flags));
}

PyObject* process_outputs(THPFunction* grad_fn, const UnpackedInput& unpacked,
                          THPObjectPtr&& raw_output, bool is_volatile) {
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs) throw python_error();

  grad_fn->cdata.num_inputs = num_outputs;

  // Initialize t2var map
  t2var_type t2var;
  for (auto& c_var : unpacked.input_vars) {
    THPVariable* py_var = (THPVariable*)c_var->pyobj;
    t2var.emplace(py_var->data, py_var);
  }

  std::unordered_set<PyObject *> dirty_inputs;
  _mark_dirty(grad_fn, t2var, dirty_inputs);
  _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile);
  _join_version_counters(grad_fn, t2var);
  if (grad_fn->cdata.is_executable) {
    _mark_non_differentiable(grad_fn, t2var);
    _save_variables(grad_fn, t2var);
  } else {
    // Remove unnecessary attributes
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = NULL;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = NULL;
  }

  // Unpack the output, unless .forward() returned a tuple
  if (unpack_output) {
    PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}

struct TraceInfo {
  bool is_tracing;
  std::shared_ptr<tracer::TracingState> tracing_state;
  bool is_backward_traceable;
  // NB: the template argument below was lost in transcription; it is assumed
  // here to be the handle type returned by tracer::EvalExitHook::registerHook.
  std::shared_ptr<tracer::EvalExitHook> eval_state;
};
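// Rough picture of the tracing flow implemented below (informal): when any
// input is currently traced and this op's backward cannot be traced
// symbolically, trace_wrap_inputs splices an Eval exit hook into the inputs
// so the trace treats the Python function as an opaque call; trace_create
// then records a PythonOp node (plus one Select per output) in the graph.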
static TraceInfo trace_wrap_inputs(THPObjectPtr& raw_inputs, InputFlags& flags, UnpackedInput& unpacked) {
  TraceInfo info;
  info.is_tracing = tracer::isTracing(unpacked.input_vars);
  if (!info.is_tracing) return info;

  info.tracing_state = tracer::getTracingState(unpacked.input_vars);
  // TODO: actually trace backward of some ops (e.g. Add)
  info.is_backward_traceable = false;
  if (!info.is_backward_traceable) {
    // NOTE: this modifies unpacked.input_vars
    info.eval_state = tracer::EvalExitHook::registerHook(info.tracing_state, unpacked.input_vars);

    // Make a copy of raw_inputs with new Variables
    int num_inputs = PyTuple_GET_SIZE(raw_inputs.get());
    THPObjectPtr updated_inputs(PyTuple_New(num_inputs));
    int vars_i = 0;
    for (int i = 0; i < num_inputs; ++i) {
      PyObject *obj;
      if (flags.is_variable_input[i]) {
        obj = THPVariable_Wrap(unpacked.input_vars[vars_i++]);
      } else {
        obj = PyTuple_GET_ITEM(raw_inputs.get(), i);
        Py_INCREF(obj);
      }
      PyTuple_SET_ITEM(updated_inputs.get(), i, obj);
    }

    // Replace inputs and recompute unpacked and flags
    std::tie(unpacked, flags) = unpack_input<false>(updated_inputs.get());
    raw_inputs = std::move(updated_inputs);
  }
  return info;
}

// This sort of reimplements unpack_input, but we have our own requirements
static void trace_create(TraceInfo& info, PyObject* op_obj,
        PyObject *input_objects, THPObjectPtr& output_objects,
        const variable_list& input_vars, const std::vector<bool>& is_variable_input) {
  if (!info.is_tracing) return;

  // Isolate C variable ptrs in a vector
  bool unpack_output = ensure_tuple(output_objects);
  variable_list output_vars;
  for (int i = 0; i < PyTuple_GET_SIZE(output_objects.get()); ++i) {
    THPVariable *var = (THPVariable*)PyTuple_GET_ITEM(output_objects.get(), i);
    output_vars.emplace_back(var->cdata);
  }

  // Save scalar args and the calling convention
  auto& tracing_state = info.tracing_state;
  auto& graph = tracing_state->graph;
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i);
    if (is_variable_input[i]) {
      arg_types.push_back('t');
    } else {
      arg_types.push_back('s');
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  // Note [getValueTrace can allocate nodes]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // When an input variable is not traced, we create a constant instruction
  // to represent it. This means that you must invoke getValueTrace() BEFORE
  // actually constructing the function that takes these variables as inputs.
  // If we do it the other order, the graph will be in the wrong topological
  // order.

  // See Note [getValueTrace can allocate nodes]
  std::vector<Node*> value_traces;
  value_traces.reserve(input_vars.size());
  for (auto& i : input_vars)
    value_traces.emplace_back(tracer::getValueTrace(tracing_state, i));

  // NB: this function is called only from THPFunction_apply, which is used only
  // when computing forward. All these functions are non-traceable by definition,
  // because they are implemented in terms of tensor operations. Hence, there's no
  // need for any conditionals in here and we can always create the node.

  // Construct the IR Node and its Selects
  Py_INCREF(op_obj);
  auto this_expr = graph->appendNewNode<PythonOp>(
    THPObjectPtr(op_obj),
    arg_types,
    false, // TODO: remove is_legacy
    std::move(scalar_args));
  for (auto t : value_traces)
    this_expr->addInput(t);

  int num_outputs = output_vars.size();
  for (int i = 0; i < num_outputs; ++i) {
    auto& output = output_vars[i];
    // NOTE: normally we don't add Select nodes when there's only a single
    // output, but Python nodes can't be optimized away, so we simplify the
    // code here.
    Node* sel = graph->appendNewNode