diff --git a/setup.py b/setup.py
index c3d32aee415..96240b2d4d0 100644
--- a/setup.py
+++ b/setup.py
@@ -435,6 +435,7 @@ main_sources = [
     "torch/csrc/allocators.cpp",
     "torch/csrc/serialization.cpp",
     "torch/csrc/jit/init.cpp",
+    "torch/csrc/jit/interpreter.cpp",
     "torch/csrc/jit/ir.cpp",
     "torch/csrc/jit/python_ir.cpp",
     "torch/csrc/jit/test_jit.cpp",
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index e7388d3273d..5dfac6c7550 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -33,13 +33,13 @@ auto ${name} = ${type_cast}(node->${method}(stringToSymbol("${name}")));\
 """)
 
 CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
-CALL_METHOD = CodeTemplate("vars[0].${name}(${args})")
+CALL_METHOD = CodeTemplate("inputs[0].${name}(${args})")
 
 CONSTRUCTOR = CodeTemplate("""\
 {"${descriptor}", [](Node *node) {
   ${assignments}
-  return TensorOp([=](const variable_list& vars) -> variable_list {
-    return pack_list(${call});
+  return TensorOp([=](const std::vector<at::Tensor> & inputs, std::vector<at::Tensor> & outputs) {
+    pack_list(outputs, ${call});
   }, "${name}", ${num_inputs});
 }},
 """)
@@ -85,16 +85,16 @@ def gen_jit_dispatch(declarations, out):
         if 'namespace' in decl['method_of']:
             if any(arg['simple_type'] == 'TensorList' for arg in arguments):
                 assert sum(map(is_tensor_arg, arguments)) == 1
-                args = ['as_tensor_list(vars)' if is_tensor_arg(arg) else arg['name']
+                args = ['inputs' if is_tensor_arg(arg) else arg['name']
                         for arg in arguments]
             else:
                 tensor_id = iter(count(start=0))
-                args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
+                args = ['inputs[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                         for arg in arguments]
             call = CALL_NAMESPACE.substitute(name=name, args=args)
         else:
             tensor_id = iter(count(start=1))
-            args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
+            args = ['inputs[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                     for arg in arguments[1:]]
             call = CALL_METHOD.substitute(name=name, args=args)
diff --git a/tools/jit/templates/aten_dispatch.cpp b/tools/jit/templates/aten_dispatch.cpp
index 08d383e843a..0c203b18723 100644
--- a/tools/jit/templates/aten_dispatch.cpp
+++ b/tools/jit/templates/aten_dispatch.cpp
@@ -19,18 +19,19 @@ using operator_constructor = std::function;
 
 namespace {
 
-variable_list pack_list(Tensor v) { return { std::move(v) }; }
-variable_list pack_list(Scalar v) { return { v.toTensor() }; }
-variable_list pack_list(std::vector<Tensor> t) { return fmap(t); }
-variable_list pack_list(std::tuple<Tensor, Tensor> v) {
-  return { std::move(std::get<0>(v)), std::move(std::get<1>(v)) };
+void pack_list(std::vector<at::Tensor> & outputs, Tensor v) { outputs.push_back(v); }
+void pack_list(std::vector<at::Tensor> & outputs, Scalar v) { outputs.push_back(v.toTensor()); }
+void pack_list(std::vector<at::Tensor> & outputs, const std::vector<at::Tensor> & t) {
+  outputs.insert(outputs.end(), t.begin(), t.end());
 }
-variable_list pack_list(std::tuple<Tensor, Tensor, Tensor> v) {
-  return { std::get<0>(v), std::get<1>(v), std::get<2>(v) };
+void pack_list(std::vector<at::Tensor> & outputs, std::tuple<Tensor, Tensor> v) {
+  outputs.push_back(std::get<0>(v));
+  outputs.push_back(std::get<1>(v));
 }
-
-std::vector<Tensor> as_tensor_list(const variable_list& vars) {
-  return fmap(vars, [](Variable v) { return static_cast<Tensor>(v); });
+void pack_list(std::vector<at::Tensor> & outputs, std::tuple<Tensor, Tensor, Tensor> v) {
+  outputs.push_back(std::get<0>(v));
+  outputs.push_back(std::get<1>(v));
+  outputs.push_back(std::get<2>(v));
 }
 
 template
diff --git a/tools/jit/templates/aten_dispatch.h b/tools/jit/templates/aten_dispatch.h
index da1e197be82..91180a9fa38 100644
--- a/tools/jit/templates/aten_dispatch.h
+++ b/tools/jit/templates/aten_dispatch.h
@@ -8,7 +8,7 @@ namespace torch { namespace jit {
 
 struct TensorOp {
-  using op_type = std::function<variable_list(const variable_list&)>;
+  using op_type = std::function<void(const std::vector<at::Tensor> &, std::vector<at::Tensor> &)>;
 
   TensorOp(op_type op, std::string name, size_t num_inputs)
     : op(op)
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 756f56e40c3..602653d54af 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -204,7 +204,9 @@ static variable_list call_function(FunctionTask& task) {
     auto& callback = it_p.first->second;
     if (!callback(&fn, inputs)) return variable_list(fn.next_functions.size());
   }
-
+  if(!task.base->keep_graph) {
+    fn.willReleaseVariables();
+  }
   auto outputs = fn(inputs);
 
   auto& post_callbacks = task.base->post_callbacks;
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index a7efd3eb864..57952b772ef 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -103,7 +103,10 @@ struct Function : std::enable_shared_from_this<Function> {
 
   // Releases saved variables if the operation won't be reused
   virtual inline void releaseVariables() {}
-
+  // called before an apply if releaseVariables() is going to be called afterwards;
+  // allows larger ops like InterpreterAutogradFunction
+  // to incrementally release variables as they run
+  virtual inline void willReleaseVariables() {}
   // Function name for debugging
   virtual std::string name();
 
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
index 33cff95af18..97d70416350 100644
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ b/torch/csrc/autograd/functions/jit_closure.cpp
@@ -118,7 +118,14 @@ struct EmitNull : public Function {
 
 struct LambdaFunction : public Function {
   LambdaFunction(const jit::TensorOp& op)
-    : LambdaFunction(op.num_inputs, op.op) {
+    : LambdaFunction(op.num_inputs, nullptr) {
+    auto & real_op = op.op;
+    this->fn_ = [real_op](const variable_list& inputs) -> variable_list {
+      std::vector<at::Tensor> tinputs(inputs.begin(), inputs.end());
+      std::vector<at::Tensor> toutputs;
+      real_op(tinputs, toutputs);
+      return variable_list(toutputs.begin(), toutputs.end());
+    };
     this->name_ = op.name;
   }
 
@@ -279,10 +286,6 @@ struct FusionGroupFunction : public Function {
       data.push_back(input.data());
     AutoGPU guard(data.back());
     std::vector<at::Tensor> outputs;
-    outputs.reserve(function->outputDescriptors().size());
-    for(auto & od : function->outputDescriptors()) {
-      outputs.push_back(at::CUDA(od.scalar_type).tensor());
-    }
     function->launch(data, outputs);
     return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
       return std::make_shared<Error>("FusionGroupFunction is not differentiable", std::move(f));
diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp
index 131ad24ab57..b98163408c7 100644
--- a/torch/csrc/jit/fusion_compiler.cpp
+++ b/torch/csrc/jit/fusion_compiler.cpp
@@ -419,7 +419,7 @@ void compressContiguous(
 
 } // anonymous namespace
 
-void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
+void CompiledFusionFunction::launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
   AutoGPU gpu_guard(inputs);
   JIT_ASSERT(inputs.size() == input_desc.size());
   JIT_ASSERT(outputs.size() == output_desc.size());
@@ -479,6 +479,16 @@ void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
   launch(numel, arguments.data());
 }
 
+void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs) {
+  AutoGPU guard(inputs.back());
+  outputs.clear();
+  outputs.reserve(outputDescriptors().size());
+  for(auto & od : outputDescriptors()) {
+    outputs.push_back(at::CUDA(od.scalar_type).tensor());
+  }
+  launch_with_tensors(inputs, outputs);
+}
+
 void CompiledFusionFunction::launch(uint32_t numel, void ** arguments) {
   int numBlocks = std::min(maxBlocks, ceilDiv(numel, blockSize));
   //std::cout << "maxBlocks = " << maxBlocks << " needed blocks: " << ceilDiv(numel,blockSize)
@@ -539,7 +549,7 @@ void FusionCompiler::debugLaunchGraph(Graph & graph, at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
     agraph.output_desc.emplace_back(i);
   }
   auto func = getOrCompile(agraph);
-  func->launch(inputs, outputs);
+  func->launch_with_tensors(inputs, outputs);
 }
 
 //TODO: thread safety
diff --git a/torch/csrc/jit/fusion_compiler.h b/torch/csrc/jit/fusion_compiler.h
index 1d0fbccddee..d0e241485d4 100644
--- a/torch/csrc/jit/fusion_compiler.h
+++ b/torch/csrc/jit/fusion_compiler.h
@@ -85,7 +85,11 @@ struct CompiledFusionFunction {
   CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph);
   ~CompiledFusionFunction();
 
-  void launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
+  // expects outputs to be pre-allocated
+  void launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
+
+  // creates new tensors for outputs
+  void launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs);
 
   const std::vector<TensorDesc> & outputDescriptors() const {
     return output_desc;
   }
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
new file mode 100644
index 00000000000..36db06b38b8
--- /dev/null
+++ b/torch/csrc/jit/interpreter.cpp
@@ -0,0 +1,293 @@
+#include "interpreter.h"
+#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/generated/aten_dispatch.h"
+#ifdef WITH_CUDA
+#include "torch/csrc/jit/fusion_compiler.h"
+#endif
+
+namespace torch { namespace jit {
+
+using tensor_list = std::vector<at::Tensor>;
+using Callback = std::function<void(const tensor_list &, tensor_list &)>;
+// Returns a function implementing functionality of a given node,
+// or nullptr if it's a no-op for autograd.
+Callback getCallback(Node *node) {
+  IR_IFM(node, PythonOp)
+    throw NotImplementedException();
+  IR_ELSEIFM(CppOp)
+    throw NotImplementedException();
+  IR_ELSEIF(Select)
+    barf("getCallback() on select?");
+  IR_ELSEIF(FusionGroup)
+#ifdef WITH_CUDA
+    auto fusion_fn = sharedFusionCompiler().getOrCompile(*value->g(kSubgraph));
+    return [fusion_fn](const tensor_list & inputs, tensor_list & outputs) {
+      fusion_fn->launch(inputs, outputs);
+    };
+#else
+    throw std::runtime_error("don't know how to execute FusionGroups without CUDA");
+#endif
+  IR_ELSEIF(Constant)
+    auto t = value->t(kvalue);
+    return [t](const tensor_list & inputs, tensor_list & outputs) {
+      outputs.push_back(t);
+    };
+  IR_ELSEIF(Undefined)
+    return [](const tensor_list & inputs, tensor_list & outputs) {
+      outputs.push_back(at::Tensor());
+    };
+  IR_ELSE()
+    return getTensorOp(node).op;
+  IR_END()
+}
+
+
+// We need some lists for inputs and outputs. To keep all the memory
+// contiguous we allocate a single vector and use offsets into the vector,
+// which are stored in the ListHandle struct.
+// start is an offset into int_data of Code for ListHandle<int>
+// and into bool_data of Code for ListHandle<bool>
+template<typename T>
+struct ListHandle {
+  int start;
+  int size;
+};
+
+struct UseList {
+  // values to be used
+  ListHandle<int> values;
+  // boolean flags indicating whether to free the Tensor after this use
+  ListHandle<bool> free_flags;
+};
+
+// one instruction plus meta-data
+struct Instruction {
+  Callback callback;
+  UseList inputs;
+  ListHandle<int> outputs;
+};
+
+struct Stage {
+  ListHandle<int> inputs; // inputs to define for the stage
+  UseList outputs; // values consumed by the return
+  std::vector<Instruction> instructions;
+};
+
+// pre-processing that happens once per graph
+struct CodeImpl {
+  CodeImpl(std::shared_ptr<Graph> & graph)
+    : graph(graph) {
+    int64_t cur_stage = -1;
+    size_t input_pos = 0;
+    size_t output_pos = 0;
+    // step 1: encode all operators and stages into registers and fill in
+    // input/output lists
+    for(auto node : graph->nodes()) {
+      if(node->kind() == kSelect)
+        continue;
+      insertStagesTo(cur_stage, node->stage(), input_pos, output_pos);
+      cur_stage = node->stage();
+      stages.back().instructions.emplace_back();
+      auto & inst = stages.back().instructions.back();
+      listBegin(inst.inputs.values);
+      for(auto input : node->inputs()) {
+        listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
+      }
+      listBegin(inst.outputs);
+      for(auto output : node->outputs()) {
+        listInsert(inst.outputs, getOrAllocateRegister(output));
+      }
+      inst.callback = getCallback(node);
+    }
+    // it is possible that the final stages have no instructions in them
+    // and are just identity functions. We call insertStagesTo here
+    // to force all these empty stages to be generated if they exist
+    insertStagesTo(cur_stage, graph->stage(), input_pos, output_pos);
+
+    // step 2: the last time we use a register we want to mark its free_flag
+    // so we clean it up. This is done with a backward scan, where we mark
+    // the first time we see it.
+    std::unordered_set<int> seen_registers;
+    auto scanUses = [&](UseList & u) {
+      listBegin(u.free_flags);
+      for(int i = 0; i < u.values.size; i++) {
+        int reg = get(u.values,i);
+        listInsert(u.free_flags, seen_registers.count(reg) == 0);
+        seen_registers.insert(reg);
+      }
+    };
+    for(auto sit = stages.rbegin(); sit != stages.rend(); sit++) {
+      scanUses(sit->outputs);
+      for(auto iit = sit->instructions.rbegin(); iit != sit->instructions.rend(); iit++) {
+        scanUses(iit->inputs);
+      }
+    }
+  }
+  void insertStagesTo(int64_t cur_stage, int64_t goal_stage, size_t & input_pos, size_t & output_pos) {
+    while(cur_stage < goal_stage) {
+      cur_stage++;
+      stages.emplace_back();
+      auto & stage = stages.back();
+      listBegin(stage.inputs);
+      for(; input_pos < graph->inputs().size(); input_pos++) {
+        auto input = graph->inputs()[input_pos];
+        if((int64_t)input->stage() > cur_stage)
+          break;
+        // unused inputs are given the fake register -1 so that we never hold a
+        // reference to the tensor data; otherwise we would fail to clean them
+        // up, since they do not have a last use at which to free them
+        int reg = input->uses().size() > 0 ? getOrAllocateRegister(input) : -1;
+        listInsert(stage.inputs, reg);
+      }
+      listBegin(stage.outputs.values);
+      for(; output_pos < graph->outputs().size(); output_pos++) {
+        auto output = graph->outputs()[output_pos];
+        if((int64_t)output->stage() > cur_stage)
+          break;
+        listInsert(stage.outputs.values, getOrAllocateRegister(output));
+      }
+    }
+  }
+  // helpers to build/access ListHandle lists
+  int get(ListHandle<int> & list, int i) {
+    return int_data[list.start + i];
+  }
+  void listBegin(ListHandle<int> & list) {
+    list.start = int_data.size();
+    list.size = 0;
+  }
+  void listInsert(ListHandle<int> & list, int value) {
+    JIT_ASSERTM(list.start + list.size == (int)int_data.size(), "another list already started");
+    int_data.push_back(value);
+    list.size++;
+  }
+  void listBegin(ListHandle<bool> & list) {
+    list.start = bool_data.size();
+    list.size = 0;
+  }
+  void listInsert(ListHandle<bool> & list, int value) {
+    JIT_ASSERTM(list.start + list.size == (int)bool_data.size(), "another list already started");
+    bool_data.push_back(value);
+    list.size++;
+  }
+
+  int getOrAllocateRegister(Node * n, bool required = false) {
+    size_t u = n->unique();
+    if(unique_to_reg.count(u) > 0)
+      return unique_to_reg[u];
+    JIT_ASSERT(!required);
+    int r = register_size++;
+    unique_to_reg[u] = r;
+    return r;
+  }
+  std::shared_ptr<Graph> graph;
+  std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in the register table
+
+  friend struct InterpreterState;
+  std::vector<Stage> stages;
+  int register_size = 0;
+
+  // all memory ArrayRefs are slices of this, to make sure
+  // the interpreter is mostly linearly scanning through memory
+  std::vector<int> int_data;
+  std::vector<bool> bool_data;
+};
+
+// InterpreterStateImpl: the state that is held across stages while running a Code
+struct InterpreterStateImpl {
+  InterpreterStateImpl(const Code & function_)
+    : function(function_.pImpl),
+      int_data(function->int_data.data()),
+      bool_data(function->bool_data),
+      registers(function->register_size) {
+  }
+  void runOneStage(
+      const std::vector<at::Tensor> & inputs,
+      std::vector<at::Tensor> & outputs) {
+    //std::cout << "running stage: " << current_stage << " of " << function->stages.size() << "\n";
+    JIT_ASSERT(current_stage < function->stages.size());
+    auto & stage = function->stages[current_stage++];
+    JIT_ASSERT((int)inputs.size() == stage.inputs.size);
+    for(int i = 0; i < stage.inputs.size; i++) {
+      int reg = get(stage.inputs,i);
+      if(reg >= 0) { // otherwise this input is dead, and we do not store it to avoid holding the reference
+        registers[reg] = inputs[i];
+      }
+      //std::cout << "registers[" << reg << "] = inputs[" << i << "](" << inputs[i].defined() << ")\n";
+    }
+    for(auto & inst : stage.instructions) {
+      loadTensorsFromRegisters(inst.inputs, input_buffer);
+      inst.callback(input_buffer, output_buffer);
+      for(int i = 0; i < inst.outputs.size; i++) {
+        int reg = get(inst.outputs,i);
+        registers[reg] = std::move(output_buffer[i]);
+        //std::cout << "registers[" << reg << "] = outputs[" << i << "](" << output_buffer[i].defined() << ")\n";
+      }
+      output_buffer.clear();
+      input_buffer.clear();
+    }
+    outputs.clear();
+    loadTensorsFromRegisters(stage.outputs, outputs);
+  }
+  int get(const ListHandle<int> & list, int i) {
+    return int_data[list.start + i];
+  };
+  bool get(const ListHandle<bool> & list, int i) {
+    return bool_data[list.start + i];
+  }
+  void loadTensorsFromRegisters(const UseList & uses, std::vector<at::Tensor> & outputs) {
+    for(int i = 0; i < uses.values.size; i++) {
+      int reg = get(uses.values,i);
+      auto & value = registers[reg];
+      //std::cout << "inputs[" << i << "] = registers[" << reg << "] (" << value.defined() << ")";
reg << "] (" << value.defined() << ")"; + if(get(uses.free_flags,i)) { + //std::cout << " and FREED"; + outputs.push_back(std::move(value)); + } else { + outputs.push_back(value); + } + //std::cout << "\n"; + } + } + size_t current_stage = 0; + std::shared_ptr function; // keep function alive + // these are just copies of function to prevent indirections in intepreter + int * int_data; + const std::vector & bool_data; + + + // this holds all the tensors for this interpreter run + // we don't bother minimizing the size of this vector, since the extra + // memory used by the pointers in this will be small + // instead we are very aggresive about releasing tensors when they become dead + // to make sure memory management happens efficiently. + + // We optimize for the case where derivatives are run with retain_graph=False + // in the case where it is true, then the interpreter and this array get copied + // if this every becomes a bottleneck then we _should_ consider minimizing the + // total number or register + std::vector registers; + + // single buffer for input calls to ATen functions, so that we do not reallocate + std::vector input_buffer; + // also to prevent allocations + std::vector output_buffer; +}; + +Code::Code(std::shared_ptr & graph) +: pImpl(new CodeImpl(graph)) {} +Code::~Code() {} +InterpreterState::InterpreterState(const Code & function) +: pImpl(new InterpreterStateImpl(function)) {} +InterpreterState::~InterpreterState() {} +void InterpreterState::runOneStage( + const std::vector & inputs, + std::vector & outputs) { + return pImpl->runOneStage(inputs, outputs); +} +InterpreterState InterpreterState::clone() const { + return InterpreterState(new InterpreterStateImpl(*pImpl)); +} +InterpreterState::InterpreterState(InterpreterStateImpl * pImpl) : pImpl(pImpl) {} + +}} diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h new file mode 100644 index 00000000000..eb1954bf842 --- /dev/null +++ b/torch/csrc/jit/interpreter.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include + +namespace at { + struct Tensor; +} +namespace torch { namespace jit { + +struct NotImplementedException : public std::logic_error { + NotImplementedException() + : std::logic_error("Function not yet implemented.") {} +}; + +// The interpreter run Graphs with Tensor inputs and Tensor outputs +// a separate component in the autograd handles unwrapping and wrapping +// variable objects for use in the interpreter. + +struct CodeImpl; +struct InterpreterStateImpl; +struct Graph; + +struct Code { + Code() + : pImpl(nullptr) {} + Code(std::shared_ptr & graph); + ~Code(); + operator bool() const { + return pImpl != nullptr; + } +private: + std::shared_ptr pImpl; + friend class InterpreterStateImpl; +}; + +struct InterpreterState { + InterpreterState(const Code & code); + // advance the interpreter state by running one stage. Returning the + // outputs for that stage, suspending the computation. + // Call this function again continues computation where it left off. 
+  void runOneStage(
+      const std::vector<at::Tensor> & inputs,
+      std::vector<at::Tensor> & outputs);
+  ~InterpreterState();
+  // create a copy of the InterpreterState with its current state;
+  // used when retain_graph=True so that stages can be re-run
+  InterpreterState clone() const;
+private:
+  InterpreterState(InterpreterStateImpl * pImpl);
+  std::shared_ptr<InterpreterStateImpl> pImpl;
+};
+
+}}
diff --git a/torch/csrc/jit/interpreter_autograd_function.h b/torch/csrc/jit/interpreter_autograd_function.h
new file mode 100644
index 00000000000..651c503a9af
--- /dev/null
+++ b/torch/csrc/jit/interpreter_autograd_function.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "torch/csrc/jit/interpreter.h"
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/functions/utils.h"
+#include "torch/csrc/autograd/functions/basic_ops.h"
+namespace torch { namespace jit {
+struct InterpreterAutogradFunction : public autograd::Function {
+  InterpreterAutogradFunction(const jit::Code & code)
+    : interp_(code) {}
+  InterpreterAutogradFunction(const InterpreterState & interp_, autograd::FunctionFlags && f)
+    : autograd::Function(std::move(f)), interp_(interp_) {}
+
+  virtual void willReleaseVariables() override {
+    keep_graph = false;
+  }
+  virtual autograd::variable_list apply(const autograd::variable_list& inputs) override {
+    std::vector<at::Tensor> tinputs;
+    std::vector<at::Tensor> toutputs;
+    for(auto & i : inputs) {
+      tinputs.push_back(i.data());
+    }
+    InterpreterState interp = (keep_graph) ? interp_.clone() : interp_;
+    keep_graph = true;
+    interp.runOneStage(tinputs, toutputs);
+    auto r = autograd::wrap_outputs(inputs, std::move(toutputs), [&](autograd::FunctionFlags f) {
+      return std::make_shared<InterpreterAutogradFunction>(interp, std::move(f));
+    });
+    return r;
+  }
+private:
+  bool keep_graph = true;
+  InterpreterState interp_;
+};
+
+}}
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 6940ed77206..6dd2b82fb52 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -854,7 +854,7 @@ struct PythonOp : public Node {
 
   // The Python object which contains the implementation of this function.
   // This is either a class (non-legacy) or an object (legacy). See
-  // TraceInterpreter for execution semantics.
+  // TraceInterpreterState for execution semantics.
   THPObjectPtr pyobj;
   // The calling convention for the Python function.
   // 's' -- python scalar argument
diff --git a/torch/csrc/jit/python_compiled_function.cpp b/torch/csrc/jit/python_compiled_function.cpp
index 53ad8d91cf3..a61c1ebc9d6 100644
--- a/torch/csrc/jit/python_compiled_function.cpp
+++ b/torch/csrc/jit/python_compiled_function.cpp
@@ -9,6 +9,8 @@
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/inplace_check.h"
 #include "torch/csrc/jit/python_arg_flatten.h"
+#include "torch/csrc/jit/interpreter.h"
+#include "torch/csrc/jit/interpreter_autograd_function.h"
 
 #include
 #include
@@ -54,7 +56,7 @@ struct CompiledFunction {
       , is_volatile_(is_volatile) {}
 
     bool ready() {
-      if (closure_) return true;
+      if (is_ready_) return true;
 
       // Remove expired traces
       traces_.erase(std::remove_if(traces_.begin(),
@@ -83,20 +85,30 @@ struct CompiledFunction {
         PeepholeOptimize(complete_trace->graph);
         FuseGraph(complete_trace->graph);
       }
-
-      closure_ = std::make_shared(complete_trace.get());
+      try {
+        code_ = jit::Code(complete_trace->graph);
+      } catch(const jit::NotImplementedException & ex) {
+        closure_ = std::make_shared(complete_trace.get());
+      }
+      is_ready_ = true;
       return true;
     }
 
    variable_list run(const variable_list& in_vars) {
-      JIT_ASSERT(closure_);
+      JIT_ASSERT(is_ready_);
       AutoNoGIL _gil_guard;
-      auto fn = closure_->construct();
-      return (*fn)(in_vars);
+      if(closure_) {
+        auto fn = closure_->construct();
+        return (*fn)(in_vars);
+      } else {
+        InterpreterAutogradFunction interp(code_);
+        interp.willReleaseVariables(); // forward pass is never reused, so it is safe to release anything it can
+        return interp.apply(in_vars);
+      }
     }
 
     PyObject* add_trace(PyObject *args, const variable_list& in_vars) {
-      JIT_ASSERT(!closure_);
+      JIT_ASSERT(!is_ready_);
 
       // Start tracing
       auto trace = tracer::enter(fmap(in_vars), is_volatile_ ? 1 : (fn_.nderivs_ + 1));
@@ -120,7 +132,9 @@ struct CompiledFunction {
 
     CompiledFunction& fn_;
     std::string out_desc_;
+    bool is_ready_ = false;
    std::shared_ptr closure_;
+    jit::Code code_;
    std::vector> traces_;
    bool is_volatile_;
  };
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 75db2c3a75f..52a058dce5b 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/attributes.h"
 #include "torch/csrc/jit/interned_strings.h"
 #include
+#include "torch/csrc/jit/interpreter.h"
 
 namespace torch { namespace jit {
 
@@ -246,7 +247,208 @@ void internedStringsTests () {
 
 }
 
+
+at::Tensor t_use(at::Tensor x) {
+  return x;
+}
+at::Tensor t_def(at::Tensor x) {
+  return x.t();
+}
+
+// given the difference of output vs expected tensor, check whether the
+// difference is within a relative tolerance range. This is a standard way of
This is a standard way of +// matching tensor values upto certain precision +bool checkRtol(const at::Tensor& diff, const std::vector inputs) { + double maxValue = 0.0; + for (auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().toCFloat(), maxValue); + } + return diff.abs().max().toCFloat() < 2e-6 * maxValue; +} +bool almostEqual(const at::Tensor & a, const at::Tensor & b) { + return checkRtol(a - b,{a, b}); +} + +bool exactlyEqual(const at::Tensor & a, const at::Tensor & b) { + return (a - b).abs().max().toCFloat() == 0.f; +} + +std::pair +lstm(at::Tensor input, + at::Tensor hx, + at::Tensor cx, + at::Tensor w_ih, + at::Tensor w_hh) { + auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh)); + + auto chunked_gates = gates.chunk(4, 1); + auto ingate = chunked_gates[0]; + auto forgetgate = chunked_gates[1]; + auto cellgate = chunked_gates[2]; + auto outgate = chunked_gates[3]; + + ingate = ingate.sigmoid(); + outgate = outgate.sigmoid(); + cellgate = cellgate.tanh(); + forgetgate = forgetgate.sigmoid(); + + auto cy = (forgetgate * cx) + (ingate * cellgate); + auto hy = outgate * cy.tanh(); + + return {hy, cy}; +} + +Symbol sym(const char * str) { + return stringToSymbol(str); +} + +Node * node(Graph& graph, const char * n, ArrayRef inputs) { + return graph.appendNode(graph.create(sym(n),inputs)); +} + +Node * add(Graph & g, Node * a, Node * b) { + auto r = node(g, "add", {a,b}); + r->t_(sym("alpha"), at::Scalar(1).toTensor()); + return r; +} + +std::tuple build_lstm_body( + Graph & g, + Node * input, + Node * hx, + Node * cx, + Node * w_ih, + Node * w_hh) { + auto gates = add(g, node(g,"mm",{ input, w_ih }), node(g, "mm", {hx, w_hh})); + auto chunked_gates = node(g, "chunk", { gates }); + chunked_gates->i_(sym("chunks"), 4); + chunked_gates->i_(sym("dim"), 1); + auto ingate = g.appendNode(g.createSelect(chunked_gates, 0)); + auto forgetgate = g.appendNode(g.createSelect(chunked_gates, 1)); + auto cellgate = g.appendNode(g.createSelect(chunked_gates, 2)); + auto outgate = g.appendNode(g.createSelect(chunked_gates, 3)); + ingate = node(g,"sigmoid",{ingate}); + outgate = node(g,"sigmoid",{outgate}); + cellgate = node(g,"tanh",{cellgate}); + forgetgate = node(g,"sigmoid",{forgetgate}); + + auto cy = add(g, node(g,"mul", {forgetgate, cx}) , node(g, "mul", {ingate, cellgate})); + auto hy = node(g, "mul", {outgate, node(g, "tanh", {cy})}); + + return std::make_tuple(hy,cy); +} + +std::shared_ptr build_lstm() { + auto r = std::make_shared(); + auto & g = *r; + Node * input = g.addInput(); + Node * hx = g.addInput(); + Node * cx = g.addInput(); + Node * w_ih = g.addInput(); + Node * w_hh = g.addInput(); + + Node * hy; + Node * cy; + std::tie(hy,cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh); + + g.registerOutput(hy); + g.registerOutput(cy); + g.lint(); + + return r; +} + +std::shared_ptr build_lstm_stages() { + auto r = std::make_shared(); + auto & g = *r; + Node * input = g.addInput(); + Node * hx = g.addInput(); + Node * cx = g.addInput(); + Node * w_ih = g.addInput(); + Node * w_hh = g.addInput(); + + Node * hy; + Node * cy; + std::tie(hy,cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh); + + // use some stuff from the previous stage as well + // as a new input + g.advanceStage(); + hx = hy; + g.registerOutput(cy); + cx = g.addInput(); + + std::tie(hy,cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh); + + g.registerOutput(hy); + g.registerOutput(cy); + g.lint(); + + return r; +} + + +void interpTest() { + constexpr int batch_size = 4; + constexpr int input_size = 256; + 
+  constexpr int seq_len = 32;
+
+  int hidden_size = 2*input_size;
+
+  auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
+  auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
+  auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
+
+  auto lstm_g = build_lstm();
+  Code lstm_function(lstm_g);
+  std::vector<at::Tensor> outputs;
+  InterpreterState lstm_interp(lstm_function);
+  lstm_interp.runOneStage({input[0], hx, cx, w_ih, w_hh}, outputs);
+  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
+
+  //std::cout << almostEqual(outputs[0],hx) << "\n";
+  JIT_ASSERT(exactlyEqual(outputs[0],hx));
+  JIT_ASSERT(exactlyEqual(outputs[1],cx));
+}
+
+void interpStageTest() {
+  constexpr int batch_size = 4;
+  constexpr int input_size = 256;
+  constexpr int seq_len = 32;
+
+  int hidden_size = 2*input_size;
+  auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
+  auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx1 = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
+  auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
+
+
+  auto lstm_g = build_lstm_stages();
+  Code lstm_function(lstm_g);
+  std::vector<at::Tensor> outputs;
+  InterpreterState lstm_interp(lstm_function);
+  lstm_interp.runOneStage({input[0], hx, cx, w_ih, w_hh}, outputs);
+  auto cy0 = outputs[0];
+  lstm_interp.runOneStage({cx1}, outputs);
+  at::Tensor ihx = outputs[0];
+  at::Tensor icx = outputs[1];
+
+
+  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
+  std::tie(hx, cx) = lstm(input[0], hx, cx1, w_ih, w_hh);
+
+  //std::cout << almostEqual(outputs[0],hx) << "\n";
+  JIT_ASSERT(exactlyEqual(outputs[0],hx));
+  JIT_ASSERT(exactlyEqual(outputs[1],cx));
+}
+
 void runJITCPPTests() {
+  interpTest();
+  interpStageTest();
   codeTemplateTest();
   fusionTests();
   attributesTest();
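
Usage note (not part of the patch): a minimal sketch, based on the interpTest() added above, of how the new Code/InterpreterState API in torch/csrc/jit/interpreter.h is driven. The helper name run_graph_once and its parameters are hypothetical; it assumes a Graph factory such as build_lstm() from test_jit.cpp and plain ATen tensors as inputs.

    // Sketch only: drive the interpreter introduced in this patch.
    #include "torch/csrc/jit/interpreter.h"
    #include "torch/csrc/jit/ir.h"
    #include <ATen/ATen.h>
    #include <memory>
    #include <vector>

    void run_graph_once(std::shared_ptr<torch::jit::Graph> graph,
                        const std::vector<at::Tensor> & inputs,
                        std::vector<at::Tensor> & outputs) {
      torch::jit::Code code(graph);               // one-time per-graph preprocessing (CodeImpl)
      torch::jit::InterpreterState interp(code);  // per-run registers and scratch buffers
      interp.runOneStage(inputs, outputs);        // runs stage 0 and then suspends
      // For multi-stage graphs, call runOneStage() again with the next stage's
      // inputs; clone() the state first if a stage may be re-run (retain_graph=True),
      // which is how InterpreterAutogradFunction uses it.
    }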