JIT cleanup (#7631)
Cleans up dead code in the JIT:

* Remove interpreter_autograd_function
* Remove Handles
* Remove HandleBuilder
* Remove creates_handles and tracing_autograd_python_function flags
* Remove unused var_args
* Fix submodules

parent e6f7e1807d
commit 286cd04a20
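
For orientation, the core API change, condensed from the graph- and tracer-header hunks below (an illustrative summary sketch, not itself part of the patch):

    // Before: graph construction and tracing carried flag state that only the
    // handle machinery consumed.
    Node* createPythonOp(THPObjectPtr&& pyobj, const std::string& cconv,
                         std::vector<VariableFlags>&& var_flags,
                         pyobj_list&& scalar_args,
                         bool tracing_autograd_python_function = true);
    Node* createCppOp(const std::shared_ptr<torch::autograd::Function>& fn,
                      std::vector<VariableFlags>&& var_flags);
    std::pair<std::shared_ptr<TracingState>, variable_list> enter(
        variable_list inputs, std::size_t num_stages, bool creates_handles);

    // After: the flags and the handle machinery are gone, and callers pass only
    // what the ops actually need.
    Node* createPythonOp(THPObjectPtr&& pyobj, const std::string& cconv,
                         pyobj_list&& scalar_args);
    Node* createCppOp(const std::shared_ptr<torch::autograd::Function>& fn);
    std::pair<std::shared_ptr<TracingState>, variable_list> enter(
        variable_list inputs, std::size_t num_stages);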

setup.py
@@ -648,7 +648,6 @@ main_sources = [
"torch/csrc/jit/export.cpp",
"torch/csrc/jit/import.cpp",
"torch/csrc/jit/autodiff.cpp",
-"torch/csrc/jit/interpreter_autograd_function.cpp",
"torch/csrc/jit/python_arg_flatten.cpp",
"torch/csrc/jit/variable_flags.cpp",
"torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",

@@ -1,4 +1,4 @@
graph(%0 : Double(2, 2)) {
-%1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
+%1 : Double(2, 2) = ^Dropout(0.6, True, False)(%0), scope: Dropout
return (%1);
}

@@ -206,7 +206,6 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/tracer_state.cpp
${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
${TORCH_SRC_DIR}/csrc/jit/type.cpp
-${TORCH_SRC_DIR}/csrc/jit/interpreter_autograd_function.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_SRC_DIR}/csrc/jit/import.cpp
${TORCH_SRC_DIR}/csrc/onnx/onnx.cpp

@@ -487,7 +487,6 @@ static PyObject* initModule() {
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));
-torch::autograd::initAutogradClosureBindings(module);
torch::jit::initJITBindings(module);
torch::onnx::initONNXBindings(module);
torch::autograd::initNNFunctions(module);

@@ -6,8 +6,6 @@ void THPAutograd_initFunctions();

namespace torch { namespace autograd {

-void initAutogradClosureBindings(PyObject* module);
-
PyMethodDef* python_functions();

}}

@@ -37,11 +37,7 @@ variable_list Function::traced_apply(variable_list inputs) {

// Insert a CppOp in the trace.
auto& graph = state->graph;
-std::vector<VariableFlags> var_flags;
-for(auto & input: inputs) {
-var_flags.push_back(VariableFlags::of(input));
-}
-auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
+auto* this_node = graph->createCppOp(get_shared_ptr());
#ifndef NO_PYTHON
this_node->setSourceLocation(std::make_shared<StringSourceLocation>(
jit::tracer::getPythonInterpreterStackTrace()

@@ -3,7 +3,6 @@
#include "basic_ops.h"
#include "tensor.h"
#include "special.h"
-#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/autograd/functions/pybind.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/generated/python_functions.h"
@@ -99,9 +98,6 @@ void THPAutograd_initFunctions()
static PyTypeObject EvalClass;
addClass<Eval, NoCtor>(module, EvalClass, "Eval");

-static PyTypeObject InterpreterAutogradClass;
-addClass<torch::jit::InterpreterAutogradFunction, NoCtor>(module, InterpreterAutogradClass, "InterpreterAutogradFunction");
-
static PyTypeObject CopyBackwardsClass;
addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");

@@ -118,18 +114,3 @@ void THPAutograd_initFunctions()
throw python_error();
}
}
-
-namespace torch { namespace autograd {
-
-void initAutogradClosureBindings(PyObject* module) {
-auto m = py::handle(module).cast<py::module>();
-py::class_<jit::InterpreterFunctionFactory,std::shared_ptr<jit::InterpreterFunctionFactory>>(m, "InterpreterFunctionFactory")
-.def("__call__", &jit::InterpreterFunctionFactory::construct_function)
-;
-
-m.def("_jit_createInterpreterFactory", [](jit::tracer::TracingState* tracing_state) {
-return std::make_shared<jit::InterpreterFunctionFactory>(tracing_state);
-});
-}
-
-}}

@@ -602,19 +602,7 @@ static void _trace_post_record(

auto state_lock = trace_info.state->lock();
trace_info.n->i_(attr::inplace, is_inplace);

-// See definition in function.cpp.
-THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
-if (!passes_py_bool) throw python_error();
-bool passes_state_transparently = passes_py_bool == Py_True;
-// NB: this path is executed only for forward of Python functions, so there's no need to check
-// tracing_state->in_eval_subgraph (it's always false, because they are never part of backward
-// subgraphs AND we don't even materialize the forward function).
-if (trace_info.state->creates_handles && !passes_state_transparently) {
-// TODO: sgross and ezyang don't know if this is right
-tracer::nontraceableBackwardSubgraph(input_vars, output_vars);
-Function::set_up_context_edge(trace_info.n, input_vars, output_vars);
-}

}

PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked,

@@ -336,6 +336,8 @@ struct PreprocessGraph {
// which are annoying to handle since 99% of values are at::Tensor anyway
// instead we create a fake subclass of TensorImpl that can be subclassed
// to hold arbitrary things
+// Note: this is currently unused but will probably be useful in the future,
+// so we keep it around
struct ContainerTensor : public at::TensorImpl {
public:
ContainerTensor()
@@ -365,75 +367,6 @@ public:
}
};

-
-// Dummy function is the last function that the autograd engine calls
-// when evaluating Eval nodes. Its input tensors are the outputs that the
-// Eval node needs to produce.
-// We interscept these values using an Autograd callback. So the function itself
-// never runs.
-struct DummyFunction : autograd::Function {
-virtual autograd::variable_list apply(const autograd::variable_list& inputs) override {
-throw std::logic_error("DummyFunction::apply() called, but it should be blocked by a callback returning false");
-}
-};
-
-// An AutogradHandle holds the information needed to run an Autograd backward pass
-// after running a forward operator (such as PythonOp, CppOp, or for double-backwards another Eval Op)
-// The EvalOperation uses AutogradHandle to perform this operation.
-struct AutogradHandle : public ContainerTensor {
-
-// The inputs of DummyFunction are the gradients of the forward passes
-// inputs, and the _outputs_ of the run of the Autograd engine computing backward.
-// there is one entry in this list for each forward input that requires
-// gradients
-std::shared_ptr<DummyFunction> forward_inputs;
-
-// there is one entry in this list for each output of the forward pass
-// that represents the location in the backwaard pass where the gradient
-// of this output should be inserted at the beginning of the backward pass
-autograd::edge_list forward_outputs;
-};
-
-// HandleBuilder is used to construct the correct Autograd Handle objects
-// for use in a future stage.
-// It is used even when the future stage does not require a handle since
-// it also performs the conversions between Tensor and Variable, which
-// behave differently depending on whether a future handle needs to be
-// created.
-struct HandleBuilder {
-HandleBuilder(bool requires_handle) {
-if(requires_handle) {
-handle = new AutogradHandle();
-handle->forward_inputs = std::make_shared<DummyFunction>();
-}
-}
-autograd::Variable addInput(at::Tensor && input_, const VariableFlags & flags_) {
-autograd::Variable& input = static_cast<autograd::Variable&>(input_);
-if(handle && flags_.requires_grad) {
-auto variable = autograd::make_variable(input.data(), /*requires_grad=*/false);
-autograd::create_gradient_edge(variable, handle->forward_inputs);
-return variable;
-} else {
-return autograd::make_variable(input.data(), /*requires_grad=*/false);
-}
-}
-at::Tensor addOutput(const autograd::Variable & output) {
-if(handle) {
-handle->forward_outputs.push_back(output.gradient_edge());
-}
-return output.detach();
-}
-void writeTo(Stack & outputs) {
-// outputs takes ownership of handle
-if(handle) {
-outputs.push_back(at::Tensor(handle, /*retain=*/false));
-handle = nullptr;
-}
-}
-private:
-AutogradHandle* handle = nullptr;
-};
-
bool hasHandleOutput(Node * n) {
if(n->outputs().size() == 0)
return false;
@@ -444,8 +377,7 @@ bool hasHandleOutput(Node * n) {
#ifndef NO_PYTHON
Operation createPythonOperation(PythonOp* op) {
py::function func = py::reinterpret_borrow<py::function>(py::handle(op->pyobj.get()));
-bool tracing_autograd_python_function = op->tracing_autograd_python_function;
-bool has_handle = hasHandleOutput(op);
+JIT_ASSERT(!hasHandleOutput(op));
size_t num_inputs = 0;
for(auto arg_type : op->cconv) {
if(arg_type == 't')
@@ -457,100 +389,49 @@ Operation createPythonOperation(PythonOp* op) {
size_t i = 0;
size_t next_scalar = 0;
size_t next_tensor = 0;
-HandleBuilder builder(has_handle);
-// Note: The first branch here should be considered deprecated and will
-// probably be removed in the future.
-//
-// tracing_autograd_python_function indicates that we need to hook this
-// PythonOp up to autograd with the HandleBuilder
-if (tracing_autograd_python_function) {
-for (auto arg_type : op->cconv) {
-if (arg_type == 's') {
-py_inputs[i] = py::reinterpret_borrow<py::object>(
-op->scalar_args[next_scalar++].get());
-} else if (arg_type == 't') {
-py_inputs[i] = py::reinterpret_steal<py::object>(
-THPVariable_Wrap(builder.addInput(
-std::move(peek(stack, next_tensor, num_inputs)),
-op->var_flags.at(next_tensor))));
-next_tensor++;
-}
-i++;
+for (auto arg_type : op->cconv) {
+if (arg_type == 's') {
+py_inputs[i] = py::reinterpret_borrow<py::object>(
+op->scalar_args[next_scalar++].get());
+} else if (arg_type == 't') {
+auto var = peek(stack, next_tensor, num_inputs);
+py_inputs[i] =
+py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
+next_tensor++;
+}
-drop(stack, num_inputs);
-py::object py_outputs(func(*py_inputs));
-auto num_outputs = op->outputs().size();
-auto addOutput = [&](py::handle entry) {
-if (!THPVariable_Check(entry.ptr())) {
-throw std::runtime_error(
-"Function.apply returned a non-Variable output");
-}
-THPVariable* var = (THPVariable*)entry.ptr();
-stack.push_back(builder.addOutput(var->cdata));
-};
-if (!PyTuple_Check(py_outputs.ptr())) {
-if (num_outputs != 1) {
-throw std::runtime_error(
-"Function.apply returned the wrong number of outputs.");
-}
-addOutput(py_outputs);
-} else {
-auto output_tuple = py::tuple(py_outputs);
-if (output_tuple.size() != num_outputs) {
-throw std::runtime_error(
-"Function.apply returned the wrong number of outputs.");
-}
-for (py::handle entry : output_tuple) {
-addOutput(entry);
-}
-}
-builder.writeTo(stack);
-return 0;
-} else {
-for (auto arg_type : op->cconv) {
-if (arg_type == 's') {
-py_inputs[i] = py::reinterpret_borrow<py::object>(
-op->scalar_args[next_scalar++].get());
-} else if (arg_type == 't') {
-auto var = peek(stack, next_tensor, num_inputs);
-py_inputs[i] =
-py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
-next_tensor++;
-}
-i++;
-}
-drop(stack, num_inputs);
-py::object py_outputs(func(*py_inputs));
-
-auto num_outputs = op->outputs().size();
-auto addOutput = [&](py::handle entry) {
-if (!THPVariable_Check(entry.ptr())) {
-throw std::runtime_error(
-"Function application returned a non-Variable output");
-}
-THPVariable* var = (THPVariable*)entry.ptr();
-auto cdata = var->cdata;
-stack.push_back(std::move(cdata));
-};
-
-if (!PyTuple_Check(py_outputs.ptr())) {
-if (num_outputs != 1) {
-throw std::runtime_error(
-"Function.apply returned the wrong number of outputs.");
-}
-addOutput(py_outputs);
-} else {
-auto output_tuple = py::tuple(py_outputs);
-if (output_tuple.size() != num_outputs) {
-throw std::runtime_error(
-"Function application returned the wrong number of outputs.");
-}
-for (py::handle entry : py::tuple(py_outputs)) {
-addOutput(entry);
-}
-}
-return 0;
+i++;
+}
+drop(stack, num_inputs);
+py::object py_outputs(func(*py_inputs));
+
+auto num_outputs = op->outputs().size();
+auto addOutput = [&](py::handle entry) {
+if (!THPVariable_Check(entry.ptr())) {
+throw std::runtime_error(
+"Function application returned a non-Variable output");
+}
+THPVariable* var = (THPVariable*)entry.ptr();
+auto cdata = var->cdata;
+stack.push_back(std::move(cdata));
+};
+
+if (!PyTuple_Check(py_outputs.ptr())) {
+if (num_outputs != 1) {
+throw std::runtime_error(
+"Function.apply returned the wrong number of outputs.");
+}
+addOutput(py_outputs);
+} else {
+auto output_tuple = py::tuple(py_outputs);
+if (output_tuple.size() != num_outputs) {
+throw std::runtime_error(
+"Function application returned the wrong number of outputs.");
+}
+for (py::handle entry : py::tuple(py_outputs)) {
+addOutput(entry);
+}
+}
+return 0;
};
}
#else
@@ -564,55 +445,18 @@ Operation createPythonOperation(PythonOp* op) {

Operation createCppOperation(CppOp* op) {
std::shared_ptr<autograd::Function> func = op->fn;
-bool has_handle = hasHandleOutput(op);
+JIT_ASSERT(!hasHandleOutput(op));
auto num_inputs = op->inputs().size();
return [=](Stack & stack) {
-HandleBuilder builder(has_handle);
autograd::variable_list v_inputs;
for(size_t i = 0; i < num_inputs; i++) {
-v_inputs.push_back(builder.addInput(std::move(peek(stack, i, num_inputs)), op->var_flags[i]));
+v_inputs.push_back(std::move(peek(stack, i, num_inputs)));
}
drop(stack, num_inputs);
autograd::variable_list v_outputs = (*func)(v_inputs);
for(auto & output : v_outputs) {
-stack.push_back(builder.addOutput(output));
+stack.push_back(output);
}
-builder.writeTo(stack);
return 0;
};
}
-
-Operation createEvalOperation(CppOp * op) {
-bool has_handle_output = hasHandleOutput(op);
-auto num_inputs = op->inputs().size();
-return [=](Stack & stack) {
-at::Tensor handle_t = std::move(stack.back());
-AutogradHandle * handle_in = dynamic_cast<AutogradHandle*>(handle_t.get());
-JIT_ASSERT(handle_in);
-HandleBuilder builder(has_handle_output);
-auto& engine = torch::autograd::Engine::getDefaultEngine();
-autograd::variable_list v_inputs;
-for(size_t i = 0; i < num_inputs - 1; i++) {
-v_inputs.push_back(builder.addInput(std::move(peek(stack, i, num_inputs)), op->var_flags[i]));
-}
-drop(stack, num_inputs);
-// TODO: handle create_graph appropriately
-bool create_graph = true;
-// note: node handle_in->use_count() == 1 means that we are guarenteed that we have the only
-// only copy of the handle. This might make it seem it is ok to pass keep_graph=False.
-// However, it is possible for 'copied_next_fns' to grab functions used by _other_ handles,
-// and these functions will be executed in this run. Since these other handles
-// may still be alive, it is not safe to release the graph
-// TODO: we could cache this list in AutogradHandle (it's read only)
-autograd::edge_list output_edges;
-const auto num_inputs = handle_in->forward_inputs->num_inputs();
-output_edges.reserve(num_inputs);
-for (uint32_t i = 0; i < num_inputs; ++i)
-output_edges.emplace_back(handle_in->forward_inputs, i);
-auto values = engine.execute(handle_in->forward_outputs, v_inputs, true, create_graph, output_edges);
-for(auto & v : values)
-stack.push_back(builder.addOutput(v));
-builder.writeTo(stack);
-return 0;
-};
-}
@@ -623,11 +467,8 @@ Operation getOperation(jit::Node* node) {
IR_IFM(node, PythonOp)
return createPythonOperation(value);
IR_ELSEIFM(CppOp)
-if(dynamic_cast<autograd::Eval*>(value->fn.get())) {
-return createEvalOperation(value);
-} else {
-return createCppOperation(value);
-}
+JIT_ASSERT(!dynamic_cast<autograd::Eval*>(value->fn.get()));
+return createCppOperation(value);
IR_ELSEIF(FusionGroup)
auto fusion_fn = sharedFusionCompiler().getOrCompile(value);
auto num_inputs = value->inputs().size();

@@ -1,187 +0,0 @@
-#include "torch/csrc/autograd/edge.h"
-#include "torch/csrc/autograd/function.h"
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/interpreter.h"
-#include "torch/csrc/jit/interpreter_autograd_function.h"
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer.h"
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/jit/variable_flags.h"
-
-#include <ATen/ATen.h>
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-namespace torch { namespace jit {
-
-using namespace torch::jit::tracer;
-
-static at::Tensor zeroTensorWithType(const TensorType & type) {
-auto device = (type.device() < 0)? at::kCPU : at::kCUDA;
-auto & at_type = at::getType(device, type.scalarType());
-// note: this has to be a contiguous tensor of zeros, because the fusion engine
-// specialized to what is normally here which might be fully dense
-return autograd::make_variable(at::zeros(at_type, type.sizes()));
-}
-
-autograd::variable_list InterpreterAutogradFunction::apply(
-const autograd::variable_list& inputs) {
-// Initial correctness checks.
-if (stage_ == stage_details_.size()) {
-throw std::runtime_error(std::string("Function compiled only for ") +
-std::to_string(stage_details_.size() - 1) + " derivatives. Use nderivs argument " +
-"to request more.");
-}
-if (used_) throw std::runtime_error(autograd::ERR_BACKWARD_TWICE);
-used_ |= !keep_graph_;
-
-const auto & details = stage_details_[stage_];
-
-// Validate inputs
-std::vector<at::Tensor> stack;
-stack.reserve(inputs.size());
-TORCH_ASSERT(inputs.size() == num_inputs_);
-TORCH_ASSERT(inputs.size() == details.input_flags.size());
-for (std::size_t i = 0; i < (std::size_t)inputs.size(); ++i) {
-auto actual_flags = VariableFlags::of(inputs[i]);
-auto traced_flags = details.input_flags[i];
-
-// check that this trace is general enough to handle the input
-// flags of the actual tensor. We can't handle the following two cases
-// because we won't have a trace containing computation of either
-// the tensor itself (not defined) or the stage for its gradient (requires_grad=False)
-if(!traced_flags.defined && actual_flags.defined) {
-throw std::runtime_error("JIT interpreter received a defined input, but the"
-" trace was compiled with the input being undefined.");
-}
-if(!traced_flags.requires_grad && actual_flags.requires_grad) {
-throw std::runtime_error("JIT inteperpreter recieved an input with "
-" requires_grad=True, but was compiled with requires_grad=False");
-}
-
-// The remaining cases we can handle. If the gradient was not
-// required but the trace will compute it, then we just compute it and
-// ignore the result.
-// However, if we are passed an undefined tensor, but the trace
-// expects a defined tensor, then we have to give it one.
-// Undefined tensors are used as stand-ins for zero tensors, so
-// we create a zero-filled tensor of the right size
-if(!actual_flags.defined) {
-// [Temporary workaround for variants] until tracer produces all variants:
-// This case appears commonly when you have a function
-// x, y = fn(z)
-// and only use x then gradient for y
-// will be undefined. If you reuse the same trace with and _sometimes_ use y
-// then in the cases where you don't use it, the grad_y input in stage 1
-// will be undefined. To ensure we can continue, we create a 0 gradient,
-// using trace information to figure out what shape it should be
-if(traced_flags.defined) {
-stack.push_back(zeroTensorWithType(interp_.tensorTypeForInput(i)));
-} else {
-stack.push_back(at::Tensor());
-}
-} else {
-stack.push_back(inputs[i].detach());
-}
-}
-
-// Run the interpreter
-InterpreterState interp = (keep_graph_) ? interp_.clone() : interp_;
-interp.runOneStage(stack);
-
-// Lazily create grad_fn
-std::shared_ptr<Function> grad_fn;
-auto make_grad_fn = [&]() {
-grad_fn = std::make_shared<InterpreterAutogradFunction>(
-std::move(interp), stage_details_, stage_ + 1);
-
-// Running this next stage is actually not valid (nderiv is too low)
-// but we don't know if the user will ever ask for it so we don't error out here.
-// Instead we have to return early because we rely on stage_details_[stage+1] in the
-// remaining code
-if(stage_ + 1 == stage_details_.size())
-return;
-
-// Patch next_edges to include prevous stage next_edges
-// This is needed because stage N is really a derivative of
-// all stages from 1 to N-1. If a part of stage x graph is
-// reused in stage y (y > x), it is inlined by the tracer,
-// and so we need to copy next_fns because those Variables
-// aren't real inputs to that stage, so that's the only place
-// where we can get them.
-for (auto copied_idx : stage_details_[stage_ + 1].copied_next_fns) {
-grad_fn->add_next_edge(next_edges_[copied_idx]);
-}
-// Add grad_fns corresponding to inputs
-for(size_t i = 0; i < inputs.size(); ++i) {
-// If an input isn't used, there's no gradient for it, and next stage
-// won't even have its grad in the trace. Don't create an entry for it.
-if (!details.used_inputs[i]) continue;
-auto & input = inputs[i];
-if (!details.input_flags[i].requires_grad) {
-continue; // See Note [Null-edge pruning]
-} else if (!input.defined() || !input.requires_grad()) {
-// See Note [Temporary workaround for variants]
-grad_fn->add_next_edge({});
-continue;
-}
-grad_fn->add_next_edge(input.gradient_edge());
-}
-};
-
-// Wrap the outputs
-// TODO: handle views
-autograd::variable_list result;
-JIT_ASSERT(stack.size() == details.output_flags.size());
-auto num_outputs = stack.size();
-for (std::size_t i = 0; i < num_outputs; ++i) {
-auto & flags = details.output_flags[i];
-if (flags.requires_grad) { // See Note [Null-edge pruning]
-if (!grad_fn) make_grad_fn();
-auto variable = static_cast<const Variable&>(stack[i]);
-autograd::create_gradient_edge(variable, grad_fn);
-result.push_back(std::move(variable));
-} else {
-result.push_back(static_cast<const Variable&>(stack[i]));
-}
-}
-
-return result;
-}
-
-InterpreterFunctionFactory::InterpreterFunctionFactory(TracingState *state) {
-code_ = jit::Code(state->graph);
-stage_details_.resize(state->graph->stage() + 1);
-auto graph_inputs = state->graph->inputs();
-auto inputs_it = graph_inputs.begin();
-for (std::size_t stage = 0; stage < state->graph->stage() + 1; ++stage) {
-auto & details = stage_details_[stage];
-std::tie(details.input_flags, details.output_flags) = std::move(state->var_flags[stage]);
-for (std::size_t i = 0; inputs_it != graph_inputs.end() && (*inputs_it)->stage() == stage; ++i, ++inputs_it) {
-details.used_inputs.push_back((*inputs_it)->uses().size() > 0);
-}
-if (stage >= 1) {
-auto & current_outputs = state->output_edges[stage];
-auto & prev_outputs = state->output_edges[stage - 1];
-for (auto & output : current_outputs) {
-// Check if output appears in outputs of previous stage
-auto prev_it = std::find(prev_outputs.begin(), prev_outputs.end(), output);
-if (prev_it == prev_outputs.end()) continue;
-// If yes, find its index and append that to the list of edges that will need
-// to be copied in InterpreterAutogradFunction.
-details.copied_next_fns.push_back(std::distance(prev_outputs.begin(), prev_it));
-}
-}
-}
-}
-
-std::shared_ptr<InterpreterAutogradFunction> InterpreterFunctionFactory::construct() {
-return std::make_shared<InterpreterAutogradFunction>(code_, stage_details_);
-}
-
-}}

@@ -1,78 +0,0 @@
-#pragma once
-
-#include "torch/csrc/autograd/function.h"
-#include "torch/csrc/jit/interpreter.h"
-#include "torch/csrc/jit/variable_flags.h"
-
-#include <cstdint>
-#include <memory>
-#include <utility>
-#include <vector>
-
-namespace torch { namespace jit {
-namespace tracer {
-struct TracingState;
-} // namespace tracer
-
-struct StageDetails {
-std::vector<VariableFlags> input_flags;
-std::vector<VariableFlags> output_flags;
-std::vector<int> copied_next_fns;
-std::vector<bool> used_inputs;
-};
-
-struct InterpreterAutogradFunction : autograd::Function {
-InterpreterAutogradFunction(
-const jit::Code& code,
-const std::vector<StageDetails>& stage_details)
-// Stage 0 isn't run through the autograd, so we set this
-// here just in case it is used.
-: Function(/*num_inputs=*/stage_details.at(0).input_flags.size()),
-interp_(code),
-stage_details_(stage_details),
-stage_(0) {}
-
-InterpreterAutogradFunction(InterpreterState interp,
-const std::vector<StageDetails>& stage_details,
-std::size_t stage)
-: interp_(std::move(interp))
-, stage_details_(stage_details)
-, stage_(stage) {}
-
-// apply() is a protected method in `autograd::Function` since users should
-// usually use the call operator, which invokes either `apply()` or
-// `traced_apply()` depending on whether the function is traced. For
-// InterpreterAutogradFunctions, however, we don't need this extra tracing
-// logic. So we make it public here.
-using autograd::Function::apply;
-
-virtual void will_release_variables() override {
-keep_graph_ = false;
-}
-
-virtual autograd::variable_list apply(const autograd::variable_list& inputs) override;
-
-private:
-InterpreterState interp_;
-const std::vector<StageDetails>& stage_details_;
-size_t stage_;
-bool keep_graph_ = true;
-bool used_ = false;
-};
-
-struct InterpreterFunctionFactory {
-explicit InterpreterFunctionFactory(tracer::TracingState *state);
-// Return `InterpreterAutogradFunction` because it has its apply() public.
-std::shared_ptr<InterpreterAutogradFunction> construct();
-// For when we need to pass a function with this signature.
-std::shared_ptr<autograd::Function> construct_function() {
-return construct();
-}
-
-private:
-jit::Code code_;
-std::vector<StageDetails> stage_details_;
-};
-
-
-}}

@@ -106,13 +106,10 @@ void PythonOp::cloneFrom(Node * other_) {
this->cconv = other->cconv;
Py_INCREF(other->pyobj.get());
this->pyobj = THPObjectPtr(other->pyobj.get());
-this->var_flags = other->var_flags;
for(auto & sa : other->scalar_args) {
Py_INCREF(sa.get());
this->scalar_args.emplace_back(sa.get());
}
-this->tracing_autograd_python_function =
-other->tracing_autograd_python_function;
}

}} // namespace torch::jit

@@ -979,10 +979,8 @@ public:
Node* createPythonOp(
THPObjectPtr&& pyobj,
const std::string& cconv,
-std::vector<VariableFlags>&& var_flags,
-pyobj_list&& scalar_args,
-bool tracing_autograd_python_function = true);
-Node * createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags);
+pyobj_list&& scalar_args);
+Node * createCppOp(const std::shared_ptr<torch::autograd::Function> & fn);
// clone n, making a new node in _this_ graph.
// use node_map to translate inputs of n to inputs of the cloned node
// if copy_blocks is false, it will not recursively clone the nested blocks
@@ -1271,14 +1269,10 @@ struct PythonOp : public Node {
PythonOp* init(
THPObjectPtr&& pyobj,
const std::string& cconv,
-std::vector<VariableFlags>&& var_flags,
-pyobj_list&& scalar_args,
-bool tracing_autograd_python_function = true) {
+pyobj_list&& scalar_args) {
this->pyobj = std::move(pyobj);
this->scalar_args = std::move(scalar_args);
this->cconv = cconv;
-this->var_flags = std::move(var_flags);
-this->tracing_autograd_python_function = tracing_autograd_python_function;
return this;
}
virtual Node * allocNewInstance(Graph * g) override {
@@ -1300,27 +1294,21 @@ struct PythonOp : public Node {
// 's' -- python scalar argument
// 't' -- tensor argument
std::string cconv;
-bool tracing_autograd_python_function;
// Scalar arguments to the Python function. Not necessarily passed to
// the function in this order; see cconv for the correct order.
std::vector<THPObjectPtr> scalar_args;
-std::vector<VariableFlags> var_flags;
std::string name() const;
virtual void cloneFrom(Node * other_) override;
};
inline Node* Graph::createPythonOp(
THPObjectPtr&& pyobj,
const std::string& cconv,
-std::vector<VariableFlags>&& var_flags,
-pyobj_list&& scalar_args,
-bool tracing_autograd_python_function) {
+pyobj_list&& scalar_args) {
auto op = new PythonOp(this);
return op->init(
std::move(pyobj),
cconv,
-std::move(var_flags),
-std::move(scalar_args),
-tracing_autograd_python_function);
+std::move(scalar_args));
}

// A Cpp operator is an operator which dispatches directly to an autograd function.
@@ -1330,12 +1318,10 @@ struct CppOp : public Node {
CppOp(Graph * g)
: Node(g,prim::CppOp) {}
std::shared_ptr<torch::autograd::Function> fn;
-std::vector<VariableFlags> var_flags;
std::string name() const;
-CppOp* init(std::shared_ptr<torch::autograd::Function> fn, std::vector<VariableFlags> && var_flags) {
+CppOp* init(std::shared_ptr<torch::autograd::Function> fn) {
JIT_ASSERT(fn);
this->fn = std::move(fn);
-this->var_flags = std::move(var_flags);
return this;
}
virtual Node * allocNewInstance(Graph * g) override {
@@ -1345,12 +1331,11 @@ struct CppOp : public Node {
Node::cloneFrom(other_);
auto other = other_->cast<CppOp>();
this->fn = other->fn;
-this->var_flags = other->var_flags;
}
};
-inline Node * Graph::createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags) {
+inline Node * Graph::createCppOp(const std::shared_ptr<torch::autograd::Function> & fn) {
auto op = new CppOp(this);
-return op->init(fn, std::move(var_flags));
+return op->init(fn);
}

inline graph_node_list_iterator Node::iterator() {

@@ -70,7 +70,7 @@ void initPythonIRBindings(PyObject * module_) {
std::string cconv(inputs.size(), 't');
func.attr("symbolic") = symbolic;
Node* new_node = g.insertNode(g.createPythonOp(
-THPObjectPtr(func.release().ptr()), cconv, {}, {}, false));
+THPObjectPtr(func.release().ptr()), cconv, {}));
for (auto i : inputs)
new_node->addInput(i);
std::vector<Value*> outputs;

@@ -54,7 +54,7 @@ void initPythonTracerBindings(PyObject* module_) {
});

m.def("_tracer_enter", [](variable_list trace_inputs, std::size_t num_backwards) {
-return tracer::enter(std::move(trace_inputs), num_backwards + 1, true);
+return tracer::enter(std::move(trace_inputs), num_backwards + 1);
});
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);

@@ -65,7 +65,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
py::object func = self;
std::string cconv(inputs.size(), 't');
Node* new_node = g.insertNode(g.createPythonOp(
-THPObjectPtr(func.release().ptr()), cconv, {}, {}, false));
+THPObjectPtr(func.release().ptr()), cconv, {}));
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
for(auto i : inputs)
new_node->addInput(i);

@@ -523,7 +523,7 @@ variable_list get_grad_outputs(const variable_list& vars) {
std::shared_ptr<Graph> trace(const ADTestSpec& test, const variable_list& vars_in) {
std::shared_ptr<tracer::TracingState> state;
variable_list trace_vars_in;
-std::tie(state, trace_vars_in) = tracer::enter(vars_in, 1, true);
+std::tie(state, trace_vars_in) = tracer::enter(vars_in, 1);
auto trace_vars_out = test(trace_vars_in);
tracer::exit(trace_vars_out);
return state->graph;

@@ -49,7 +49,7 @@ std::shared_ptr<torch::jit::Graph> torch::jit::tracer::createGraphByTracing(
py::function func,
tracer::variable_list trace_inputs,
size_t num_func_inputs) {
-auto enter_info = tracer::enter(std::move(trace_inputs), 1, false);
+auto enter_info = tracer::enter(std::move(trace_inputs), 1);
py::tuple py_inputs(num_func_inputs);
for(size_t i = 0; i < num_func_inputs; ++i) {
py_inputs[i] = py::cast(enter_info.second[i]);
@@ -206,7 +206,6 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj,
std::string arg_types,
at::ArrayRef<Variable> inputs,
pyobj_list scalar_args) {
-std::vector<VariableFlags> var_flags = fmap(inputs, &VariableFlags::of);
THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
if(!apply) {
throw python_error();
@@ -215,9 +214,7 @@
return graph.createPythonOp(
std::move(apply),
arg_types,
-std::move(var_flags),
-std::move(scalar_args),
-state->creates_handles);
+std::move(scalar_args));
});
}
#endif

@@ -221,9 +221,8 @@ inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const V
// reference to at::Tensor buffer to call unsafeGetTH, but you can't get this
// out of a const vector (silly std::vector...)
inline std::pair<std::shared_ptr<TracingState>, variable_list> enter(
-variable_list inputs, std::size_t num_stages, bool creates_handles) {
+variable_list inputs, std::size_t num_stages) {
auto state = std::make_shared<TracingState>(num_stages);
-state->creates_handles = creates_handles;
for (auto& input : inputs) {
auto * value_state = detail::getValueState(state, input, false);
if (value_state) {

@@ -20,8 +20,7 @@ TracingState::TracingState(size_t num_stages)
num_stages(num_stages),
eval_count(0),
var_flags(num_stages),
-output_edges(num_stages),
-creates_handles(true) {}
+output_edges(num_stages) {}

TracingState::~TracingState() = default;

@@ -54,10 +54,7 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {

std::mutex mutex;
variable_list inputs; // Used only for the duration of first stage

-bool creates_handles; // should python ops ever get handles? Or should
-// we just record them as is.
-
std::unique_lock<std::mutex> lock() {
return std::unique_lock<std::mutex>(mutex);
}