JIT cleanup (#7631)

Cleans up dead code in the JIT:

* Remove interpreter_autograd_function
* Remove Handles
* Remove HandleBuilder
* Remove the creates_handles and tracing_autograd_python_function flags
* Remove unused var_flags (the simplified signatures are sketched below)
* Fix submodules
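
For reference, here is a minimal sketch, assembled from the ir.h, tracer.h, and call-site hunks in the diff below, of what the affected entry points look like after this cleanup (illustrative only, not a complete listing):

```cpp
// Graph member functions (torch/csrc/jit/ir.h) after this change:
// no VariableFlags vector and no tracing_autograd_python_function flag.
Node* createPythonOp(
    THPObjectPtr&& pyobj,
    const std::string& cconv,
    pyobj_list&& scalar_args);
Node* createCppOp(const std::shared_ptr<torch::autograd::Function>& fn);

// tracer::enter (torch/csrc/jit/tracer.h) after this change:
// the creates_handles argument is gone.
std::pair<std::shared_ptr<TracingState>, variable_list> enter(
    variable_list inputs, std::size_t num_stages);

// Typical call sites as updated in this diff:
Node* new_node = g.insertNode(g.createPythonOp(
    THPObjectPtr(func.release().ptr()), cconv, {}));
auto* this_node = graph->createCppOp(get_shared_ptr());
auto enter_info = tracer::enter(std::move(trace_inputs), 1);
```
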
Zachary DeVito 2018-05-21 10:06:29 -07:00 committed by GitHub
parent e6f7e1807d
commit 286cd04a20
21 changed files with 69 additions and 559 deletions

View File

@ -648,7 +648,6 @@ main_sources = [
"torch/csrc/jit/export.cpp",
"torch/csrc/jit/import.cpp",
"torch/csrc/jit/autodiff.cpp",
"torch/csrc/jit/interpreter_autograd_function.cpp",
"torch/csrc/jit/python_arg_flatten.cpp",
"torch/csrc/jit/variable_flags.cpp",
"torch/csrc/jit/passes/create_autodiff_subgraphs.cpp",

View File

@ -1,4 +1,4 @@
graph(%0 : Double(2, 2)) {
%1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
%1 : Double(2, 2) = ^Dropout(0.6, True, False)(%0), scope: Dropout
return (%1);
}

View File

@ -206,7 +206,6 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/tracer_state.cpp
${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
${TORCH_SRC_DIR}/csrc/jit/type.cpp
${TORCH_SRC_DIR}/csrc/jit/interpreter_autograd_function.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_SRC_DIR}/csrc/jit/import.cpp
${TORCH_SRC_DIR}/csrc/onnx/onnx.cpp

View File

@ -487,7 +487,6 @@ static PyObject* initModule() {
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));
torch::autograd::initAutogradClosureBindings(module);
torch::jit::initJITBindings(module);
torch::onnx::initONNXBindings(module);
torch::autograd::initNNFunctions(module);

View File

@ -6,8 +6,6 @@ void THPAutograd_initFunctions();
namespace torch { namespace autograd {
void initAutogradClosureBindings(PyObject* module);
PyMethodDef* python_functions();
}}

View File

@ -37,11 +37,7 @@ variable_list Function::traced_apply(variable_list inputs) {
// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
auto* this_node = graph->createCppOp(get_shared_ptr());
#ifndef NO_PYTHON
this_node->setSourceLocation(std::make_shared<StringSourceLocation>(
jit::tracer::getPythonInterpreterStackTrace()

View File

@ -3,7 +3,6 @@
#include "basic_ops.h"
#include "tensor.h"
#include "special.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/autograd/functions/pybind.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include "torch/csrc/autograd/generated/python_functions.h"
@ -99,9 +98,6 @@ void THPAutograd_initFunctions()
static PyTypeObject EvalClass;
addClass<Eval, NoCtor>(module, EvalClass, "Eval");
static PyTypeObject InterpreterAutogradClass;
addClass<torch::jit::InterpreterAutogradFunction, NoCtor>(module, InterpreterAutogradClass, "InterpreterAutogradFunction");
static PyTypeObject CopyBackwardsClass;
addClass<CopyBackwards, NoCtor>(module, CopyBackwardsClass, "CopyBackwards");
@ -118,18 +114,3 @@ void THPAutograd_initFunctions()
throw python_error();
}
}
namespace torch { namespace autograd {
void initAutogradClosureBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<jit::InterpreterFunctionFactory,std::shared_ptr<jit::InterpreterFunctionFactory>>(m, "InterpreterFunctionFactory")
.def("__call__", &jit::InterpreterFunctionFactory::construct_function)
;
m.def("_jit_createInterpreterFactory", [](jit::tracer::TracingState* tracing_state) {
return std::make_shared<jit::InterpreterFunctionFactory>(tracing_state);
});
}
}}

View File

@ -602,19 +602,7 @@ static void _trace_post_record(
auto state_lock = trace_info.state->lock();
trace_info.n->i_(attr::inplace, is_inplace);
// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
if (!passes_py_bool) throw python_error();
bool passes_state_transparently = passes_py_bool == Py_True;
// NB: this path is executed only for forward of Python functions, so there's no need to check
// tracing_state->in_eval_subgraph (it's always false, because they are never part of backward
// subgraphs AND we don't even materialize the forward function).
if (trace_info.state->creates_handles && !passes_state_transparently) {
// TODO: sgross and ezyang don't know if this is right
tracer::nontraceableBackwardSubgraph(input_vars, output_vars);
Function::set_up_context_edge(trace_info.n, input_vars, output_vars);
}
}
PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked,

View File

@ -336,6 +336,8 @@ struct PreprocessGraph {
// which are annoying to handle since 99% of values are at::Tensor anyway
// instead we create a fake subclass of TensorImpl that can be subclassed
// to hold arbitrary things
// Note: this is currently unused but will probably be useful in the future,
// so we keep it around
struct ContainerTensor : public at::TensorImpl {
public:
ContainerTensor()
@ -365,75 +367,6 @@ public:
}
};
// Dummy function is the last function that the autograd engine calls
// when evaluating Eval nodes. Its input tensors are the outputs that the
// Eval node needs to produce.
// We intercept these values using an Autograd callback. So the function itself
// never runs.
struct DummyFunction : autograd::Function {
virtual autograd::variable_list apply(const autograd::variable_list& inputs) override {
throw std::logic_error("DummyFunction::apply() called, but it should be blocked by a callback returning false");
}
};
// An AutogradHandle holds the information needed to run an Autograd backward pass
// after running a forward operator (such as PythonOp, CppOp, or for double-backwards another Eval Op)
// The EvalOperation uses AutogradHandle to perform this operation.
struct AutogradHandle : public ContainerTensor {
// The inputs of DummyFunction are the gradients of the forward pass's
// inputs, and the _outputs_ of the run of the Autograd engine computing backward.
// there is one entry in this list for each forward input that requires
// gradients
std::shared_ptr<DummyFunction> forward_inputs;
// there is one entry in this list for each output of the forward pass
// that represents the location in the backward pass where the gradient
// of this output should be inserted at the beginning of the backward pass
autograd::edge_list forward_outputs;
};
// HandleBuilder is used to construct the correct Autograd Handle objects
// for use in a future stage.
// It is used even when the future stage does not require a handle since
// it also performs the conversions between Tensor and Variable, which
// behave differently depending on whether a future handle needs to be
// created.
struct HandleBuilder {
HandleBuilder(bool requires_handle) {
if(requires_handle) {
handle = new AutogradHandle();
handle->forward_inputs = std::make_shared<DummyFunction>();
}
}
autograd::Variable addInput(at::Tensor && input_, const VariableFlags & flags_) {
autograd::Variable& input = static_cast<autograd::Variable&>(input_);
if(handle && flags_.requires_grad) {
auto variable = autograd::make_variable(input.data(), /*requires_grad=*/false);
autograd::create_gradient_edge(variable, handle->forward_inputs);
return variable;
} else {
return autograd::make_variable(input.data(), /*requires_grad=*/false);
}
}
at::Tensor addOutput(const autograd::Variable & output) {
if(handle) {
handle->forward_outputs.push_back(output.gradient_edge());
}
return output.detach();
}
void writeTo(Stack & outputs) {
// outputs takes ownership of handle
if(handle) {
outputs.push_back(at::Tensor(handle, /*retain=*/false));
handle = nullptr;
}
}
private:
AutogradHandle* handle = nullptr;
};
bool hasHandleOutput(Node * n) {
if(n->outputs().size() == 0)
return false;
@ -444,8 +377,7 @@ bool hasHandleOutput(Node * n) {
#ifndef NO_PYTHON
Operation createPythonOperation(PythonOp* op) {
py::function func = py::reinterpret_borrow<py::function>(py::handle(op->pyobj.get()));
bool tracing_autograd_python_function = op->tracing_autograd_python_function;
bool has_handle = hasHandleOutput(op);
JIT_ASSERT(!hasHandleOutput(op));
size_t num_inputs = 0;
for(auto arg_type : op->cconv) {
if(arg_type == 't')
@ -457,100 +389,49 @@ Operation createPythonOperation(PythonOp* op) {
size_t i = 0;
size_t next_scalar = 0;
size_t next_tensor = 0;
HandleBuilder builder(has_handle);
// Note: The first branch here should be considered deprecated and will
// probably be removed in the future.
//
// tracing_autograd_python_function indicates that we need to hook this
// PythonOp up to autograd with the HandleBuilder
if (tracing_autograd_python_function) {
for (auto arg_type : op->cconv) {
if (arg_type == 's') {
py_inputs[i] = py::reinterpret_borrow<py::object>(
op->scalar_args[next_scalar++].get());
} else if (arg_type == 't') {
py_inputs[i] = py::reinterpret_steal<py::object>(
THPVariable_Wrap(builder.addInput(
std::move(peek(stack, next_tensor, num_inputs)),
op->var_flags.at(next_tensor))));
next_tensor++;
}
i++;
for (auto arg_type : op->cconv) {
if (arg_type == 's') {
py_inputs[i] = py::reinterpret_borrow<py::object>(
op->scalar_args[next_scalar++].get());
} else if (arg_type == 't') {
auto var = peek(stack, next_tensor, num_inputs);
py_inputs[i] =
py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
next_tensor++;
}
drop(stack, num_inputs);
py::object py_outputs(func(*py_inputs));
auto num_outputs = op->outputs().size();
auto addOutput = [&](py::handle entry) {
if (!THPVariable_Check(entry.ptr())) {
throw std::runtime_error(
"Function.apply returned a non-Variable output");
}
THPVariable* var = (THPVariable*)entry.ptr();
stack.push_back(builder.addOutput(var->cdata));
};
if (!PyTuple_Check(py_outputs.ptr())) {
if (num_outputs != 1) {
throw std::runtime_error(
"Function.apply returned the wrong number of outputs.");
}
addOutput(py_outputs);
} else {
auto output_tuple = py::tuple(py_outputs);
if (output_tuple.size() != num_outputs) {
throw std::runtime_error(
"Function.apply returned the wrong number of outputs.");
}
for (py::handle entry : output_tuple) {
addOutput(entry);
}
}
builder.writeTo(stack);
return 0;
} else {
for (auto arg_type : op->cconv) {
if (arg_type == 's') {
py_inputs[i] = py::reinterpret_borrow<py::object>(
op->scalar_args[next_scalar++].get());
} else if (arg_type == 't') {
auto var = peek(stack, next_tensor, num_inputs);
py_inputs[i] =
py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
next_tensor++;
}
i++;
}
drop(stack, num_inputs);
py::object py_outputs(func(*py_inputs));
auto num_outputs = op->outputs().size();
auto addOutput = [&](py::handle entry) {
if (!THPVariable_Check(entry.ptr())) {
throw std::runtime_error(
"Function application returned a non-Variable output");
}
THPVariable* var = (THPVariable*)entry.ptr();
auto cdata = var->cdata;
stack.push_back(std::move(cdata));
};
if (!PyTuple_Check(py_outputs.ptr())) {
if (num_outputs != 1) {
throw std::runtime_error(
"Function.apply returned the wrong number of outputs.");
}
addOutput(py_outputs);
} else {
auto output_tuple = py::tuple(py_outputs);
if (output_tuple.size() != num_outputs) {
throw std::runtime_error(
"Function application returned the wrong number of outputs.");
}
for (py::handle entry : py::tuple(py_outputs)) {
addOutput(entry);
}
}
return 0;
i++;
}
drop(stack, num_inputs);
py::object py_outputs(func(*py_inputs));
auto num_outputs = op->outputs().size();
auto addOutput = [&](py::handle entry) {
if (!THPVariable_Check(entry.ptr())) {
throw std::runtime_error(
"Function application returned a non-Variable output");
}
THPVariable* var = (THPVariable*)entry.ptr();
auto cdata = var->cdata;
stack.push_back(std::move(cdata));
};
if (!PyTuple_Check(py_outputs.ptr())) {
if (num_outputs != 1) {
throw std::runtime_error(
"Function.apply returned the wrong number of outputs.");
}
addOutput(py_outputs);
} else {
auto output_tuple = py::tuple(py_outputs);
if (output_tuple.size() != num_outputs) {
throw std::runtime_error(
"Function application returned the wrong number of outputs.");
}
for (py::handle entry : py::tuple(py_outputs)) {
addOutput(entry);
}
}
return 0;
};
}
#else
@ -564,55 +445,18 @@ Operation createPythonOperation(PythonOp* op) {
Operation createCppOperation(CppOp* op) {
std::shared_ptr<autograd::Function> func = op->fn;
bool has_handle = hasHandleOutput(op);
JIT_ASSERT(!hasHandleOutput(op));
auto num_inputs = op->inputs().size();
return [=](Stack & stack) {
HandleBuilder builder(has_handle);
autograd::variable_list v_inputs;
for(size_t i = 0; i < num_inputs; i++) {
v_inputs.push_back(builder.addInput(std::move(peek(stack, i, num_inputs)), op->var_flags[i]));
v_inputs.push_back(std::move(peek(stack, i, num_inputs)));
}
drop(stack, num_inputs);
autograd::variable_list v_outputs = (*func)(v_inputs);
for(auto & output : v_outputs) {
stack.push_back(builder.addOutput(output));
stack.push_back(output);
}
builder.writeTo(stack);
return 0;
};
}
Operation createEvalOperation(CppOp * op) {
bool has_handle_output = hasHandleOutput(op);
auto num_inputs = op->inputs().size();
return [=](Stack & stack) {
at::Tensor handle_t = std::move(stack.back());
AutogradHandle * handle_in = dynamic_cast<AutogradHandle*>(handle_t.get());
JIT_ASSERT(handle_in);
HandleBuilder builder(has_handle_output);
auto& engine = torch::autograd::Engine::getDefaultEngine();
autograd::variable_list v_inputs;
for(size_t i = 0; i < num_inputs - 1; i++) {
v_inputs.push_back(builder.addInput(std::move(peek(stack, i, num_inputs)), op->var_flags[i]));
}
drop(stack, num_inputs);
// TODO: handle create_graph appropriately
bool create_graph = true;
// note: handle_in->use_count() == 1 means that we are guaranteed that we have the
// only copy of the handle. This might make it seem it is ok to pass keep_graph=False.
// However, it is possible for 'copied_next_fns' to grab functions used by _other_ handles,
// and these functions will be executed in this run. Since these other handles
// may still be alive, it is not safe to release the graph
// TODO: we could cache this list in AutogradHandle (it's read only)
autograd::edge_list output_edges;
const auto num_inputs = handle_in->forward_inputs->num_inputs();
output_edges.reserve(num_inputs);
for (uint32_t i = 0; i < num_inputs; ++i)
output_edges.emplace_back(handle_in->forward_inputs, i);
auto values = engine.execute(handle_in->forward_outputs, v_inputs, true, create_graph, output_edges);
for(auto & v : values)
stack.push_back(builder.addOutput(v));
builder.writeTo(stack);
return 0;
};
}
@ -623,11 +467,8 @@ Operation getOperation(jit::Node* node) {
IR_IFM(node, PythonOp)
return createPythonOperation(value);
IR_ELSEIFM(CppOp)
if(dynamic_cast<autograd::Eval*>(value->fn.get())) {
return createEvalOperation(value);
} else {
return createCppOperation(value);
}
JIT_ASSERT(!dynamic_cast<autograd::Eval*>(value->fn.get()));
return createCppOperation(value);
IR_ELSEIF(FusionGroup)
auto fusion_fn = sharedFusionCompiler().getOrCompile(value);
auto num_inputs = value->inputs().size();

View File

@ -1,187 +0,0 @@
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/interpreter_autograd_function.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/jit/variable_flags.h"
#include <ATen/ATen.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
namespace torch { namespace jit {
using namespace torch::jit::tracer;
static at::Tensor zeroTensorWithType(const TensorType & type) {
auto device = (type.device() < 0)? at::kCPU : at::kCUDA;
auto & at_type = at::getType(device, type.scalarType());
// note: this has to be a contiguous tensor of zeros, because the fusion engine
// is specialized to what is normally here, which might be fully dense
return autograd::make_variable(at::zeros(at_type, type.sizes()));
}
autograd::variable_list InterpreterAutogradFunction::apply(
const autograd::variable_list& inputs) {
// Initial correctness checks.
if (stage_ == stage_details_.size()) {
throw std::runtime_error(std::string("Function compiled only for ") +
std::to_string(stage_details_.size() - 1) + " derivatives. Use nderivs argument " +
"to request more.");
}
if (used_) throw std::runtime_error(autograd::ERR_BACKWARD_TWICE);
used_ |= !keep_graph_;
const auto & details = stage_details_[stage_];
// Validate inputs
std::vector<at::Tensor> stack;
stack.reserve(inputs.size());
TORCH_ASSERT(inputs.size() == num_inputs_);
TORCH_ASSERT(inputs.size() == details.input_flags.size());
for (std::size_t i = 0; i < (std::size_t)inputs.size(); ++i) {
auto actual_flags = VariableFlags::of(inputs[i]);
auto traced_flags = details.input_flags[i];
// check that this trace is general enough to handle the input
// flags of the actual tensor. We can't handle the following two cases
// because we won't have a trace containing computation of either
// the tensor itself (not defined) or the stage for its gradient (requires_grad=False)
if(!traced_flags.defined && actual_flags.defined) {
throw std::runtime_error("JIT interpreter received a defined input, but the"
" trace was compiled with the input being undefined.");
}
if(!traced_flags.requires_grad && actual_flags.requires_grad) {
throw std::runtime_error("JIT inteperpreter recieved an input with "
" requires_grad=True, but was compiled with requires_grad=False");
}
// The remaining cases we can handle. If the gradient was not
// required but the trace will compute it, then we just compute it and
// ignore the result.
// However, if we are passed an undefined tensor, but the trace
// expects a defined tensor, then we have to give it one.
// Undefined tensors are used as stand-ins for zero tensors, so
// we create a zero-filled tensor of the right size
if(!actual_flags.defined) {
// [Temporary workaround for variants] until tracer produces all variants:
// This case appears commonly when you have a function
// x, y = fn(z)
// and only use x then gradient for y
// will be undefined. If you reuse the same trace with and _sometimes_ use y
// then in the cases where you don't use it, the grad_y input in stage 1
// will be undefined. To ensure we can continue, we create a 0 gradient,
// using trace information to figure out what shape it should be
if(traced_flags.defined) {
stack.push_back(zeroTensorWithType(interp_.tensorTypeForInput(i)));
} else {
stack.push_back(at::Tensor());
}
} else {
stack.push_back(inputs[i].detach());
}
}
// Run the interpreter
InterpreterState interp = (keep_graph_) ? interp_.clone() : interp_;
interp.runOneStage(stack);
// Lazily create grad_fn
std::shared_ptr<Function> grad_fn;
auto make_grad_fn = [&]() {
grad_fn = std::make_shared<InterpreterAutogradFunction>(
std::move(interp), stage_details_, stage_ + 1);
// Running this next stage is actually not valid (nderiv is too low)
// but we don't know if the user will ever ask for it so we don't error out here.
// Instead we have to return early because we rely on stage_details_[stage+1] in the
// remaining code
if(stage_ + 1 == stage_details_.size())
return;
// Patch next_edges to include previous stage next_edges
// This is needed because stage N is really a derivative of
// all stages from 1 to N-1. If a part of stage x graph is
// reused in stage y (y > x), it is inlined by the tracer,
// and so we need to copy next_fns because those Variables
// aren't real inputs to that stage, so that's the only place
// where we can get them.
for (auto copied_idx : stage_details_[stage_ + 1].copied_next_fns) {
grad_fn->add_next_edge(next_edges_[copied_idx]);
}
// Add grad_fns corresponding to inputs
for(size_t i = 0; i < inputs.size(); ++i) {
// If an input isn't used, there's no gradient for it, and next stage
// won't even have its grad in the trace. Don't create an entry for it.
if (!details.used_inputs[i]) continue;
auto & input = inputs[i];
if (!details.input_flags[i].requires_grad) {
continue; // See Note [Null-edge pruning]
} else if (!input.defined() || !input.requires_grad()) {
// See Note [Temporary workaround for variants]
grad_fn->add_next_edge({});
continue;
}
grad_fn->add_next_edge(input.gradient_edge());
}
};
// Wrap the outputs
// TODO: handle views
autograd::variable_list result;
JIT_ASSERT(stack.size() == details.output_flags.size());
auto num_outputs = stack.size();
for (std::size_t i = 0; i < num_outputs; ++i) {
auto & flags = details.output_flags[i];
if (flags.requires_grad) { // See Note [Null-edge pruning]
if (!grad_fn) make_grad_fn();
auto variable = static_cast<const Variable&>(stack[i]);
autograd::create_gradient_edge(variable, grad_fn);
result.push_back(std::move(variable));
} else {
result.push_back(static_cast<const Variable&>(stack[i]));
}
}
return result;
}
InterpreterFunctionFactory::InterpreterFunctionFactory(TracingState *state) {
code_ = jit::Code(state->graph);
stage_details_.resize(state->graph->stage() + 1);
auto graph_inputs = state->graph->inputs();
auto inputs_it = graph_inputs.begin();
for (std::size_t stage = 0; stage < state->graph->stage() + 1; ++stage) {
auto & details = stage_details_[stage];
std::tie(details.input_flags, details.output_flags) = std::move(state->var_flags[stage]);
for (std::size_t i = 0; inputs_it != graph_inputs.end() && (*inputs_it)->stage() == stage; ++i, ++inputs_it) {
details.used_inputs.push_back((*inputs_it)->uses().size() > 0);
}
if (stage >= 1) {
auto & current_outputs = state->output_edges[stage];
auto & prev_outputs = state->output_edges[stage - 1];
for (auto & output : current_outputs) {
// Check if output appears in outputs of previous stage
auto prev_it = std::find(prev_outputs.begin(), prev_outputs.end(), output);
if (prev_it == prev_outputs.end()) continue;
// If yes, find its index and append that to the list of edges that will need
// to be copied in InterpreterAutogradFunction.
details.copied_next_fns.push_back(std::distance(prev_outputs.begin(), prev_it));
}
}
}
}
std::shared_ptr<InterpreterAutogradFunction> InterpreterFunctionFactory::construct() {
return std::make_shared<InterpreterAutogradFunction>(code_, stage_details_);
}
}}

View File

@ -1,78 +0,0 @@
#pragma once
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/variable_flags.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
namespace torch { namespace jit {
namespace tracer {
struct TracingState;
} // namespace tracer
struct StageDetails {
std::vector<VariableFlags> input_flags;
std::vector<VariableFlags> output_flags;
std::vector<int> copied_next_fns;
std::vector<bool> used_inputs;
};
struct InterpreterAutogradFunction : autograd::Function {
InterpreterAutogradFunction(
const jit::Code& code,
const std::vector<StageDetails>& stage_details)
// Stage 0 isn't run through the autograd, so we set this
// here just in case it is used.
: Function(/*num_inputs=*/stage_details.at(0).input_flags.size()),
interp_(code),
stage_details_(stage_details),
stage_(0) {}
InterpreterAutogradFunction(InterpreterState interp,
const std::vector<StageDetails>& stage_details,
std::size_t stage)
: interp_(std::move(interp))
, stage_details_(stage_details)
, stage_(stage) {}
// apply() is a protected method in `autograd::Function` since users should
// usually use the call operator, which invokes either `apply()` or
// `traced_apply()` depending on whether the function is traced. For
// InterpreterAutogradFunctions, however, we don't need this extra tracing
// logic. So we make it public here.
using autograd::Function::apply;
virtual void will_release_variables() override {
keep_graph_ = false;
}
virtual autograd::variable_list apply(const autograd::variable_list& inputs) override;
private:
InterpreterState interp_;
const std::vector<StageDetails>& stage_details_;
size_t stage_;
bool keep_graph_ = true;
bool used_ = false;
};
struct InterpreterFunctionFactory {
explicit InterpreterFunctionFactory(tracer::TracingState *state);
// Return `InterpreterAutogradFunction` because it has its apply() public.
std::shared_ptr<InterpreterAutogradFunction> construct();
// For when we need to pass a function with this signature.
std::shared_ptr<autograd::Function> construct_function() {
return construct();
}
private:
jit::Code code_;
std::vector<StageDetails> stage_details_;
};
}}

View File

@ -106,13 +106,10 @@ void PythonOp::cloneFrom(Node * other_) {
this->cconv = other->cconv;
Py_INCREF(other->pyobj.get());
this->pyobj = THPObjectPtr(other->pyobj.get());
this->var_flags = other->var_flags;
for(auto & sa : other->scalar_args) {
Py_INCREF(sa.get());
this->scalar_args.emplace_back(sa.get());
}
this->tracing_autograd_python_function =
other->tracing_autograd_python_function;
}
}} // namespace torch::jit

View File

@ -979,10 +979,8 @@ public:
Node* createPythonOp(
THPObjectPtr&& pyobj,
const std::string& cconv,
std::vector<VariableFlags>&& var_flags,
pyobj_list&& scalar_args,
bool tracing_autograd_python_function = true);
Node * createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags);
pyobj_list&& scalar_args);
Node * createCppOp(const std::shared_ptr<torch::autograd::Function> & fn);
// clone n, making a new node in _this_ graph.
// use node_map to translate inputs of n to inputs of the cloned node
// if copy_blocks is false, it will not recursively clone the nested blocks
@ -1271,14 +1269,10 @@ struct PythonOp : public Node {
PythonOp* init(
THPObjectPtr&& pyobj,
const std::string& cconv,
std::vector<VariableFlags>&& var_flags,
pyobj_list&& scalar_args,
bool tracing_autograd_python_function = true) {
pyobj_list&& scalar_args) {
this->pyobj = std::move(pyobj);
this->scalar_args = std::move(scalar_args);
this->cconv = cconv;
this->var_flags = std::move(var_flags);
this->tracing_autograd_python_function = tracing_autograd_python_function;
return this;
}
virtual Node * allocNewInstance(Graph * g) override {
@ -1300,27 +1294,21 @@ struct PythonOp : public Node {
// 's' -- python scalar argument
// 't' -- tensor argument
std::string cconv;
bool tracing_autograd_python_function;
// Scalar arguments to the Python function. Not necessarily passed to
// the function in this order; see cconv for the correct order.
std::vector<THPObjectPtr> scalar_args;
std::vector<VariableFlags> var_flags;
std::string name() const;
virtual void cloneFrom(Node * other_) override;
};
inline Node* Graph::createPythonOp(
THPObjectPtr&& pyobj,
const std::string& cconv,
std::vector<VariableFlags>&& var_flags,
pyobj_list&& scalar_args,
bool tracing_autograd_python_function) {
pyobj_list&& scalar_args) {
auto op = new PythonOp(this);
return op->init(
std::move(pyobj),
cconv,
std::move(var_flags),
std::move(scalar_args),
tracing_autograd_python_function);
std::move(scalar_args));
}
// A Cpp operator is an operator which dispatches directly to an autograd function.
@ -1330,12 +1318,10 @@ struct CppOp : public Node {
CppOp(Graph * g)
: Node(g,prim::CppOp) {}
std::shared_ptr<torch::autograd::Function> fn;
std::vector<VariableFlags> var_flags;
std::string name() const;
CppOp* init(std::shared_ptr<torch::autograd::Function> fn, std::vector<VariableFlags> && var_flags) {
CppOp* init(std::shared_ptr<torch::autograd::Function> fn) {
JIT_ASSERT(fn);
this->fn = std::move(fn);
this->var_flags = std::move(var_flags);
return this;
}
virtual Node * allocNewInstance(Graph * g) override {
@ -1345,12 +1331,11 @@ struct CppOp : public Node {
Node::cloneFrom(other_);
auto other = other_->cast<CppOp>();
this->fn = other->fn;
this->var_flags = other->var_flags;
}
};
inline Node * Graph::createCppOp(const std::shared_ptr<torch::autograd::Function> & fn, std::vector<VariableFlags> && var_flags) {
inline Node * Graph::createCppOp(const std::shared_ptr<torch::autograd::Function> & fn) {
auto op = new CppOp(this);
return op->init(fn, std::move(var_flags));
return op->init(fn);
}
inline graph_node_list_iterator Node::iterator() {

View File

@ -70,7 +70,7 @@ void initPythonIRBindings(PyObject * module_) {
std::string cconv(inputs.size(), 't');
func.attr("symbolic") = symbolic;
Node* new_node = g.insertNode(g.createPythonOp(
THPObjectPtr(func.release().ptr()), cconv, {}, {}, false));
THPObjectPtr(func.release().ptr()), cconv, {}));
for (auto i : inputs)
new_node->addInput(i);
std::vector<Value*> outputs;

View File

@ -54,7 +54,7 @@ void initPythonTracerBindings(PyObject* module_) {
});
m.def("_tracer_enter", [](variable_list trace_inputs, std::size_t num_backwards) {
return tracer::enter(std::move(trace_inputs), num_backwards + 1, true);
return tracer::enter(std::move(trace_inputs), num_backwards + 1);
});
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);

View File

@ -65,7 +65,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
py::object func = self;
std::string cconv(inputs.size(), 't');
Node* new_node = g.insertNode(g.createPythonOp(
THPObjectPtr(func.release().ptr()), cconv, {}, {}, false));
THPObjectPtr(func.release().ptr()), cconv, {}));
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
for(auto i : inputs)
new_node->addInput(i);

View File

@ -523,7 +523,7 @@ variable_list get_grad_outputs(const variable_list& vars) {
std::shared_ptr<Graph> trace(const ADTestSpec& test, const variable_list& vars_in) {
std::shared_ptr<tracer::TracingState> state;
variable_list trace_vars_in;
std::tie(state, trace_vars_in) = tracer::enter(vars_in, 1, true);
std::tie(state, trace_vars_in) = tracer::enter(vars_in, 1);
auto trace_vars_out = test(trace_vars_in);
tracer::exit(trace_vars_out);
return state->graph;

View File

@ -49,7 +49,7 @@ std::shared_ptr<torch::jit::Graph> torch::jit::tracer::createGraphByTracing(
py::function func,
tracer::variable_list trace_inputs,
size_t num_func_inputs) {
auto enter_info = tracer::enter(std::move(trace_inputs), 1, false);
auto enter_info = tracer::enter(std::move(trace_inputs), 1);
py::tuple py_inputs(num_func_inputs);
for(size_t i = 0; i < num_func_inputs; ++i) {
py_inputs[i] = py::cast(enter_info.second[i]);
@ -206,7 +206,6 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj,
std::string arg_types,
at::ArrayRef<Variable> inputs,
pyobj_list scalar_args) {
std::vector<VariableFlags> var_flags = fmap(inputs, &VariableFlags::of);
THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
if(!apply) {
throw python_error();
@ -215,9 +214,7 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj,
return graph.createPythonOp(
std::move(apply),
arg_types,
std::move(var_flags),
std::move(scalar_args),
state->creates_handles);
std::move(scalar_args));
});
}
#endif

View File

@ -221,9 +221,8 @@ inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const V
// reference to at::Tensor buffer to call unsafeGetTH, but you can't get this
// out of a const vector (silly std::vector...)
inline std::pair<std::shared_ptr<TracingState>, variable_list> enter(
variable_list inputs, std::size_t num_stages, bool creates_handles) {
variable_list inputs, std::size_t num_stages) {
auto state = std::make_shared<TracingState>(num_stages);
state->creates_handles = creates_handles;
for (auto& input : inputs) {
auto * value_state = detail::getValueState(state, input, false);
if (value_state) {

View File

@ -20,8 +20,7 @@ TracingState::TracingState(size_t num_stages)
num_stages(num_stages),
eval_count(0),
var_flags(num_stages),
output_edges(num_stages),
creates_handles(true) {}
output_edges(num_stages) {}
TracingState::~TracingState() = default;

View File

@ -54,10 +54,7 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
std::mutex mutex;
variable_list inputs; // Used only for the duration of first stage
bool creates_handles; // should python ops ever get handles? Or should
// we just record them as is.
std::unique_lock<std::mutex> lock() {
return std::unique_lock<std::mutex>(mutex);
}