Add a JIT interpreter (#3634)
* Add a JIT interpreter

The separate interpreter is used to run graphs with a lower overhead than converting them to autograd graphs. Some notes:

* does not support Handles/PythonOp/CppOp; these will be in a future commit
* jit_closure.cpp still exists, and we fall back to it for now when we cannot handle something because of PythonOp/CppOp
* In order to support retain_graph=True, the interpreter can be cloned, creating a copy that can be run with different arguments. This is assumed to be the non-standard case, so cloning is not particularly optimized. No tensor _data_ is copied, but the at::Tensor list in the interpreter is. If we hit problems, there is a lot we could do (such as register allocation) to minimize the state that needs to be copied.
* Uses a pImpl pattern to keep implementation details out of its header file.
* Modifies the way getTensorOp works so that it reads/writes to already-existing vectors; this prevents needing to reallocate these buffers each time.
* Timings are here: https://gist.github.com/zdevito/5a20ac29fb1b9e449e693b67dc478127 This reduces overhead to about the same as running it in Python. It is about 10us faster to run the same thing using ATen directly.

* Code Mod

Interpreter -> InterpreterState
Function -> Code
Add other requested comments.

* RegList -> ListHandle<T>

Change the RegList functions to be safer by identifying the type of each argument list, and checking that list insert does not try to add to two different lists at once.

* Use exactly equal for interp tests
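For orientation, here is a minimal sketch of how the new Code/InterpreterState API introduced by this commit is driven. It mirrors the interpTest added to test_jit.cpp further down and assumes a Graph has already been built (for example with the build_lstm() helper from that test); it is an illustration, not part of the diff:

    // Sketch only: Code is built once per graph, InterpreterState once per run.
    using namespace torch::jit;

    std::shared_ptr<Graph> graph = build_lstm();      // any constructed/traced Graph
    Code code(graph);                                 // per-graph pre-processing (CodeImpl)
    InterpreterState interp(code);                    // per-run state (registers, buffers)

    std::vector<at::Tensor> inputs = {/* stage-0 inputs, e.g. x, hx, cx, w_ih, w_hh */};
    std::vector<at::Tensor> outputs;
    interp.runOneStage(inputs, outputs);              // runs the next stage, fills `outputs`

    // For retain_graph=True, clone the state so the same stage can be re-run later.
    InterpreterState saved = interp.clone();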
This commit is contained in:
  parent: b67acd2d39
  commit: e43ff32192
setup.py (1 addition)

@@ -435,6 +435,7 @@ main_sources = [
     "torch/csrc/allocators.cpp",
     "torch/csrc/serialization.cpp",
     "torch/csrc/jit/init.cpp",
+    "torch/csrc/jit/interpreter.cpp",
     "torch/csrc/jit/ir.cpp",
     "torch/csrc/jit/python_ir.cpp",
     "torch/csrc/jit/test_jit.cpp",
@@ -33,13 +33,13 @@ auto ${name} = ${type_cast}(node->${method}(stringToSymbol("${name}")));\
 """)
 
 CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
-CALL_METHOD = CodeTemplate("vars[0].${name}(${args})")
+CALL_METHOD = CodeTemplate("inputs[0].${name}(${args})")
 
 CONSTRUCTOR = CodeTemplate("""\
 {"${descriptor}", [](Node *node) {
   ${assignments}
-  return TensorOp([=](const variable_list& vars) -> variable_list {
-    return pack_list(${call});
+  return TensorOp([=](const std::vector<Tensor> & inputs, std::vector<Tensor> & outputs) {
+    pack_list(outputs, ${call});
   }, "${name}", ${num_inputs});
 }},
 """)

@@ -85,16 +85,16 @@ def gen_jit_dispatch(declarations, out):
         if 'namespace' in decl['method_of']:
             if any(arg['simple_type'] == 'TensorList' for arg in arguments):
                 assert sum(map(is_tensor_arg, arguments)) == 1
-                args = ['as_tensor_list(vars)' if is_tensor_arg(arg) else arg['name']
+                args = ['inputs' if is_tensor_arg(arg) else arg['name']
                         for arg in arguments]
             else:
                 tensor_id = iter(count(start=0))
-                args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
+                args = ['inputs[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                         for arg in arguments]
             call = CALL_NAMESPACE.substitute(name=name, args=args)
         else:
             tensor_id = iter(count(start=1))
-            args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
+            args = ['inputs[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                     for arg in arguments[1:]]
             call = CALL_METHOD.substitute(name=name, args=args)
@@ -19,18 +19,19 @@ using operator_constructor = std::function<TensorOp(jit::Node*)>;
 
 namespace {
 
-variable_list pack_list(Tensor v) { return { std::move(v) }; }
-variable_list pack_list(Scalar v) { return { v.toTensor() }; }
-variable_list pack_list(std::vector<Tensor> t) { return fmap<Variable>(t); }
-variable_list pack_list(std::tuple<Tensor, Tensor> v) {
-  return { std::move(std::get<0>(v)), std::move(std::get<1>(v)) };
-}
-variable_list pack_list(std::tuple<Tensor, Tensor, Tensor> v) {
-  return { std::get<0>(v), std::get<1>(v), std::get<2>(v) };
-}
-
-std::vector<Tensor> as_tensor_list(const variable_list& vars) {
-  return fmap(vars, [](Variable v) { return static_cast<Tensor>(v); });
-}
+void pack_list(std::vector<Tensor> & outputs, Tensor v) { outputs.push_back(v); }
+void pack_list(std::vector<Tensor> & outputs, Scalar v) { outputs.push_back(v.toTensor()); }
+void pack_list(std::vector<Tensor> & outputs, const std::vector<Tensor> & t) {
+  outputs.insert(outputs.end(), t.begin(), t.end());
+}
+void pack_list(std::vector<Tensor> & outputs, std::tuple<Tensor, Tensor> v) {
+  outputs.push_back(std::get<0>(v));
+  outputs.push_back(std::get<1>(v));
+}
+void pack_list(std::vector<Tensor> & outputs, std::tuple<Tensor, Tensor, Tensor> v) {
+  outputs.push_back(std::get<0>(v));
+  outputs.push_back(std::get<1>(v));
+  outputs.push_back(std::get<2>(v));
+}
 
 template<size_t N>
@@ -8,7 +8,7 @@
 namespace torch { namespace jit {
 
 struct TensorOp {
-  using op_type = std::function<autograd::variable_list(const autograd::variable_list&)>;
+  using op_type = std::function<void(const std::vector<at::Tensor> &, std::vector<at::Tensor> &)>;
 
   TensorOp(op_type op, std::string name, size_t num_inputs)
     : op(op)
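The new op_type signature is what lets callers reuse a single pair of buffers for every op call instead of allocating a fresh variable_list per op. A hedged sketch of how an execution loop might drive it (the interpreter's runOneStage below does essentially this with its input_buffer/output_buffer members; the helpers and instruction list here are hypothetical):

    // Sketch: invoking TensorOps with caller-owned, reusable buffers.
    std::vector<at::Tensor> input_buffer, output_buffer;   // allocated once, reused each instruction
    for (auto & inst : instructions) {                      // hypothetical instruction list
      gather_inputs(inst, input_buffer);                    // hypothetical: read inputs from registers
      inst.op.op(input_buffer, output_buffer);              // new op_type: writes into output_buffer
      scatter_outputs(inst, output_buffer);                 // hypothetical: store outputs to registers
      input_buffer.clear();
      output_buffer.clear();                                // capacity is retained, so no reallocation
    }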
@@ -204,7 +204,9 @@ static variable_list call_function(FunctionTask& task) {
     auto& callback = it_p.first->second;
     if (!callback(&fn, inputs)) return variable_list(fn.next_functions.size());
   }
+  if(!task.base->keep_graph) {
+    fn.willReleaseVariables();
+  }
   auto outputs = fn(inputs);
 
   auto& post_callbacks = task.base->post_callbacks;
@@ -103,7 +103,10 @@ struct Function : std::enable_shared_from_this<Function> {
 
   // Releases saved variables if the operation won't be reused
   virtual inline void releaseVariables() {}
+  // called before an apply if releaseVariables() is going to be called afterwards;
+  // allows larger ops like InterpreterAutogradFunction
+  // to incrementally release variables as they run
+  virtual inline void willReleaseVariables() {}
   // Function name for debugging
   virtual std::string name();
@@ -118,7 +118,14 @@ struct EmitNull : public Function {
 
 struct LambdaFunction : public Function {
   LambdaFunction(const jit::TensorOp& op)
-    : LambdaFunction(op.num_inputs, op.op) {
+    : LambdaFunction(op.num_inputs, nullptr) {
+    auto & real_op = op.op;
+    this->fn_ = [real_op](const variable_list& inputs) -> variable_list {
+      std::vector<at::Tensor> tinputs(inputs.begin(), inputs.end());
+      std::vector<at::Tensor> toutputs;
+      real_op(tinputs, toutputs);
+      return variable_list(toutputs.begin(), toutputs.end());
+    };
     this->name_ = op.name;
   }
@@ -279,10 +286,6 @@ struct FusionGroupFunction : public Function {
       data.push_back(input.data());
     AutoGPU guard(data.back());
     std::vector<at::Tensor> outputs;
-    outputs.reserve(function->outputDescriptors().size());
-    for(auto & od : function->outputDescriptors()) {
-      outputs.push_back(at::CUDA(od.scalar_type).tensor());
-    }
     function->launch(data, outputs);
     return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
       return std::make_shared<torch::autograd::Error>("FusionGroupFunction is not differentiable", std::move(f));
@@ -419,7 +419,7 @@ void compressContiguous(
 
 } // anonymous namespace
 
-void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
+void CompiledFusionFunction::launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
   AutoGPU gpu_guard(inputs);
   JIT_ASSERT(inputs.size() == input_desc.size());
   JIT_ASSERT(outputs.size() == output_desc.size());

@@ -479,6 +479,16 @@ void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRe
   launch(numel, arguments.data());
 }
 
+void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs) {
+  AutoGPU guard(inputs.back());
+  outputs.clear();
+  outputs.reserve(outputDescriptors().size());
+  for(auto & od : outputDescriptors()) {
+    outputs.push_back(at::CUDA(od.scalar_type).tensor());
+  }
+  launch_with_tensors(inputs, outputs);
+}
+
 void CompiledFusionFunction::launch(uint32_t numel, void ** arguments) {
   int numBlocks = std::min(maxBlocks, ceilDiv(numel, blockSize));
   //std::cout << "maxBlocks = " << maxBlocks << " needed blocks: " << ceilDiv(numel,blockSize)
@@ -539,7 +549,7 @@ void FusionCompiler::debugLaunchGraph(Graph & graph, at::ArrayRef<at::Tensor> in
     agraph.output_desc.emplace_back(i);
   }
   auto func = getOrCompile(agraph);
-  func->launch(inputs, outputs);
+  func->launch_with_tensors(inputs, outputs);
 }
 
 //TODO: thread safety
@@ -85,7 +85,11 @@ struct CompiledFusionFunction {
   CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph);
   ~CompiledFusionFunction();
 
-  void launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
+  // expects outputs to be pre-allocated
+  void launch_with_tensors(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs);
+
+  // creates new tensors for outputs
+  void launch(at::ArrayRef<at::Tensor> inputs, std::vector<at::Tensor> & outputs);
   const std::vector<TensorDesc> & outputDescriptors() const {
     return output_desc;
   }
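A short sketch of the difference between the two entry points, assuming a compiled CompiledFusionFunction `fn` and input tensors `ins` (the pre-allocation helper is hypothetical):

    // launch_with_tensors: the caller allocates outputs of the right type/size up front.
    std::vector<at::Tensor> outs_prealloc = make_outputs_for(*fn);  // hypothetical helper
    fn->launch_with_tensors(ins, outs_prealloc);

    // launch: the function allocates CUDA output tensors itself from its output descriptors.
    std::vector<at::Tensor> outs;
    fn->launch(ins, outs);   // outs is cleared and then filled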
torch/csrc/jit/interpreter.cpp (new file, 293 lines)
@@ -0,0 +1,293 @@

#include "interpreter.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/generated/aten_dispatch.h"
#ifdef WITH_CUDA
#include "torch/csrc/jit/fusion_compiler.h"
#endif

namespace torch { namespace jit {

using tensor_list = std::vector<at::Tensor>;
using Callback = std::function<void(const tensor_list &, tensor_list &)>;
// Returns a function implementing functionality of a given node,
// or nullptr if it's a no-op for autograd.
Callback getCallback(Node *node) {
  IR_IFM(node, PythonOp)
    throw NotImplementedException();
  IR_ELSEIFM(CppOp)
    throw NotImplementedException();
  IR_ELSEIF(Select)
    barf("getCallback() on select?");
  IR_ELSEIF(FusionGroup)
#ifdef WITH_CUDA
    auto fusion_fn = sharedFusionCompiler().getOrCompile(*value->g(kSubgraph));
    return [fusion_fn](const tensor_list & inputs, tensor_list & outputs) {
      fusion_fn->launch(inputs, outputs);
    };
#else
    throw std::runtime_error("don't know how to execute FusionGroups without CUDA");
#endif
  IR_ELSEIF(Constant)
    auto t = value->t(kvalue);
    return [t](const tensor_list & inputs, tensor_list & outputs) {
      outputs.push_back(t);
    };
  IR_ELSEIF(Undefined)
    return [](const tensor_list & inputs, tensor_list & outputs) {
      outputs.push_back(at::Tensor());
    };
  IR_ELSE()
    return getTensorOp(node).op;
  IR_END()
}


// We need some lists for inputs and outputs. To keep all the memory
// contiguous we allocate a single vector and use offsets into the vector
// which are stored in the ListHandle struct.
// start is an offset into int_data of Code for ListHandle<int>
// and bool_data of Code for ListHandle<bool>
template<typename T>
struct ListHandle {
  int start;
  int size;
};

struct UseList {
  // values to be used
  ListHandle<int> values;
  // boolean flags indicating whether to free the Tensor after this use
  ListHandle<bool> free_flags;
};

// one instruction plus meta-data
struct Instruction {
  Callback callback;
  UseList inputs;
  ListHandle<int> outputs;
};

struct Stage {
  ListHandle<int> inputs; // inputs to define for the stage
  UseList outputs; // values consumed by the return
  std::vector<Instruction> instructions;
};

// pre-processing that happens once per graph
struct CodeImpl {
  CodeImpl(std::shared_ptr<Graph> & graph)
    : graph(graph) {
    int64_t cur_stage = -1;
    size_t input_pos = 0;
    size_t output_pos = 0;
    // step 1: encode all operators and stages into registers and fill in
    // input/output lists
    for(auto node : graph->nodes()) {
      if(node->kind() == kSelect)
        continue;
      insertStagesTo(cur_stage, node->stage(), input_pos, output_pos);
      cur_stage = node->stage();
      stages.back().instructions.emplace_back();
      auto & inst = stages.back().instructions.back();
      listBegin(inst.inputs.values);
      for(auto input : node->inputs()) {
        listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
      }
      listBegin(inst.outputs);
      for(auto output : node->outputs()) {
        listInsert(inst.outputs, getOrAllocateRegister(output));
      }
      inst.callback = getCallback(node);
    }
    // it is possible that the final stages have no instructions in them
    // and are just identity functions. We call insertStagesTo here
    // to force all these empty stages to be generated if they exist
    insertStagesTo(cur_stage, graph->stage(), input_pos, output_pos);

    // step 2: the last time we use a register we want to mark its free_flag
    // so we clean it up.
    // this is done with a backward scan where we mark the first time we see it
    std::unordered_set<int> seen_registers;
    auto scanUses = [&](UseList & u) {
      listBegin(u.free_flags);
      for(int i = 0; i < u.values.size; i++) {
        int reg = get(u.values, i);
        listInsert(u.free_flags, seen_registers.count(reg) == 0);
        seen_registers.insert(reg);
      }
    };
    for(auto sit = stages.rbegin(); sit != stages.rend(); sit++) {
      scanUses(sit->outputs);
      for(auto iit = sit->instructions.rbegin(); iit != sit->instructions.rend(); iit++) {
        scanUses(iit->inputs);
      }
    }
  }
  void insertStagesTo(int64_t cur_stage, int64_t goal_stage, size_t & input_pos, size_t & output_pos) {
    while(cur_stage < goal_stage) {
      cur_stage++;
      stages.emplace_back();
      auto & stage = stages.back();
      listBegin(stage.inputs);
      for(; input_pos < graph->inputs().size(); input_pos++) {
        auto input = graph->inputs()[input_pos];
        if((int64_t)input->stage() > cur_stage)
          break;
        // unused inputs are given a false register -1 so that we never hold a
        // reference to the tensor data, otherwise we would fail to clean them
        // up since they do not have a last use at which to free them
        int reg = input->uses().size() > 0 ? getOrAllocateRegister(input) : -1;
        listInsert(stage.inputs, reg);
      }
      listBegin(stage.outputs.values);
      for(; output_pos < graph->outputs().size(); output_pos++) {
        auto output = graph->outputs()[output_pos];
        if((int64_t)output->stage() > cur_stage)
          break;
        listInsert(stage.outputs.values, getOrAllocateRegister(output));
      }
    }
  }
  // helpers to build/access ListHandle objects
  int get(ListHandle<int> & list, int i) {
    return int_data[list.start + i];
  }
  void listBegin(ListHandle<int> & list) {
    list.start = int_data.size();
    list.size = 0;
  }
  void listInsert(ListHandle<int> & list, int value) {
    JIT_ASSERTM(list.start + list.size == (int)int_data.size(), "another list already started");
    int_data.push_back(value);
    list.size++;
  }
  void listBegin(ListHandle<bool> & list) {
    list.start = bool_data.size();
    list.size = 0;
  }
  void listInsert(ListHandle<bool> & list, int value) {
    JIT_ASSERTM(list.start + list.size == (int)bool_data.size(), "another list already started");
    bool_data.push_back(value);
    list.size++;
  }

  int getOrAllocateRegister(Node * n, bool required = false) {
    size_t u = n->unique();
    if(unique_to_reg.count(u) > 0)
      return unique_to_reg[u];
    JIT_ASSERT(!required);
    int r = register_size++;
    unique_to_reg[u] = r;
    return r;
  }
  std::shared_ptr<Graph> graph;
  std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table

  friend struct InterpreterState;
  std::vector<Stage> stages;
  int register_size = 0;

  // all memory ArrayRef<int> are slices of this, to make sure
  // the interpreter is mostly linearly scanning through memory
  std::vector<int> int_data;
  std::vector<bool> bool_data;
};

// InterpreterState: state that is held across stages while running a Code
struct InterpreterStateImpl {
  InterpreterStateImpl(const Code & function_)
    : function(function_.pImpl),
      int_data(function->int_data.data()),
      bool_data(function->bool_data),
      registers(function->register_size) {
  }
  void runOneStage(
      const std::vector<at::Tensor> & inputs,
      std::vector<at::Tensor> & outputs) {
    //std::cout << "running stage: " << current_stage << " of " << function->stages.size() << "\n";
    JIT_ASSERT(current_stage < function->stages.size());
    auto & stage = function->stages[current_stage++];
    JIT_ASSERT((int)inputs.size() == stage.inputs.size);
    for(int i = 0; i < stage.inputs.size; i++) {
      int reg = get(stage.inputs, i);
      if(reg >= 0) { // otherwise this input is dead, and we do not store it to avoid holding the reference
        registers[reg] = inputs[i];
      }
      //std::cout << "registers[" << reg << "] = inputs[" << i << "](" << inputs[i].defined() << ")\n";
    }
    for(auto & inst : stage.instructions) {
      loadTensorsFromRegisters(inst.inputs, input_buffer);
      inst.callback(input_buffer, output_buffer);
      for(int i = 0; i < inst.outputs.size; i++) {
        int reg = get(inst.outputs, i);
        registers[reg] = std::move(output_buffer[i]);
        //std::cout << "registers[" << reg << "] = outputs[" << i << "](" << output_buffer[i].defined() << ")\n";
      }
      output_buffer.clear();
      input_buffer.clear();
    }
    outputs.clear();
    loadTensorsFromRegisters(stage.outputs, outputs);
  }
  int get(const ListHandle<int> & list, int i) {
    return int_data[list.start + i];
  };
  bool get(const ListHandle<bool> & list, int i) {
    return bool_data[list.start + i];
  }
  void loadTensorsFromRegisters(const UseList & uses, std::vector<at::Tensor> & outputs) {
    for(int i = 0; i < uses.values.size; i++) {
      int reg = get(uses.values, i);
      auto & value = registers[reg];
      //std::cout << "inputs[" << i << "] = registers[" << reg << "] (" << value.defined() << ")";
      if(get(uses.free_flags, i)) {
        //std::cout << " and FREED";
        outputs.push_back(std::move(value));
      } else {
        outputs.push_back(value);
      }
      //std::cout << "\n";
    }
  }
  size_t current_stage = 0;
  std::shared_ptr<CodeImpl> function; // keep function alive
  // these are just copies of function to prevent indirections in the interpreter
  int * int_data;
  const std::vector<bool> & bool_data;


  // this holds all the tensors for this interpreter run.
  // we don't bother minimizing the size of this vector, since the extra
  // memory used by the pointers in this will be small;
  // instead we are very aggressive about releasing tensors when they become dead
  // to make sure memory management happens efficiently.

  // We optimize for the case where derivatives are run with retain_graph=False;
  // in the case where it is true, the interpreter and this array get copied.
  // if this ever becomes a bottleneck then we _should_ consider minimizing the
  // total number of registers
  std::vector<at::Tensor> registers;

  // single buffer for input calls to ATen functions, so that we do not reallocate
  std::vector<at::Tensor> input_buffer;
  // also to prevent allocations
  std::vector<at::Tensor> output_buffer;
};

Code::Code(std::shared_ptr<Graph> & graph)
  : pImpl(new CodeImpl(graph)) {}
Code::~Code() {}
InterpreterState::InterpreterState(const Code & function)
  : pImpl(new InterpreterStateImpl(function)) {}
InterpreterState::~InterpreterState() {}
void InterpreterState::runOneStage(
    const std::vector<at::Tensor> & inputs,
    std::vector<at::Tensor> & outputs) {
  return pImpl->runOneStage(inputs, outputs);
}
InterpreterState InterpreterState::clone() const {
  return InterpreterState(new InterpreterStateImpl(*pImpl));
}
InterpreterState::InterpreterState(InterpreterStateImpl * pImpl) : pImpl(pImpl) {}

}}
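To make the RegList -> ListHandle<T> change from the commit message concrete: a ListHandle is just an (offset, length) slice of one backing vector, and listInsert asserts that you are still appending to the most recently begun list. A small illustration reusing the helpers above (not part of the file, ordering of calls is the point):

    // Sketch: two lists carved out of the same contiguous backing store.
    ListHandle<int> ins, outs;
    listBegin(ins);
    listInsert(ins, 3);      // appends at int_data[ins.start + ins.size]
    listBegin(outs);         // outs now starts where ins ended
    listInsert(outs, 7);     // ok: outs is the currently "open" list
    // listInsert(ins, 4);   // would trip JIT_ASSERTM("another list already started"),
    //                       // because ins.start + ins.size no longer equals int_data.size()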
torch/csrc/jit/interpreter.h (new file, 53 lines)
@@ -0,0 +1,53 @@

#pragma once
#include <memory>
#include <vector>

namespace at {
  struct Tensor;
}
namespace torch { namespace jit {

struct NotImplementedException : public std::logic_error {
  NotImplementedException()
  : std::logic_error("Function not yet implemented.") {}
};

// The interpreter runs Graphs with Tensor inputs and Tensor outputs;
// a separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.

struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;

struct Code {
  Code()
  : pImpl(nullptr) {}
  Code(std::shared_ptr<Graph> & graph);
  ~Code();
  operator bool() const {
    return pImpl != nullptr;
  }
private:
  std::shared_ptr<CodeImpl> pImpl;
  friend class InterpreterStateImpl;
};

struct InterpreterState {
  InterpreterState(const Code & code);
  // advance the interpreter state by running one stage, returning the
  // outputs for that stage and suspending the computation.
  // Calling this function again continues computation where it left off.
  void runOneStage(
    const std::vector<at::Tensor> & inputs,
    std::vector<at::Tensor> & outputs);
  ~InterpreterState();
  // create a copy of InterpreterState with its current state,
  // used when retain_graph=True so that stages can be re-run
  InterpreterState clone() const;
private:
  InterpreterState(InterpreterStateImpl * pImpl);
  std::shared_ptr<InterpreterStateImpl> pImpl;
};

}}
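The stage-oriented API is what makes double-backward work: each runOneStage call consumes the next stage's inputs and produces its outputs, and clone() snapshots the state so a stage can be replayed when retain_graph=True. A hedged sketch of that pattern (the interpStageTest added below exercises the same idea on a two-stage LSTM graph; `code`, `forward_inputs`, and `grad_inputs` are assumed to exist):

    InterpreterState interp(code);                     // code built from a multi-stage Graph

    std::vector<at::Tensor> stage0_out, stage1_out;
    interp.runOneStage(forward_inputs, stage0_out);    // stage 0: forward

    InterpreterState replay = interp.clone();          // keep a copy if the next stage may run twice
    interp.runOneStage(grad_inputs, stage1_out);       // stage 1: e.g. the derivative stage

    // replay.runOneStage(grad_inputs, stage1_out);    // possible later because `replay` was cloned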
torch/csrc/jit/interpreter_autograd_function.h (new file, 36 lines)
@@ -0,0 +1,36 @@

#pragma once

#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
namespace torch { namespace jit {
struct InterpreterAutogradFunction : public autograd::Function {
  InterpreterAutogradFunction(const jit::Code & code)
  : interp_(code) {}
  InterpreterAutogradFunction(const InterpreterState & interp_, autograd::FunctionFlags && f)
  : autograd::Function(std::move(f)), interp_(interp_) {}

  virtual void willReleaseVariables() override {
    keep_graph = false;
  }
  virtual autograd::variable_list apply(const autograd::variable_list& inputs) override {
    std::vector<at::Tensor> tinputs;
    std::vector<at::Tensor> toutputs;
    for(auto & i : inputs) {
      tinputs.push_back(i.data());
    }
    InterpreterState interp = (keep_graph) ? interp_.clone() : interp_;
    keep_graph = true;
    interp.runOneStage(tinputs, toutputs);
    auto r = autograd::wrap_outputs(inputs, std::move(toutputs), [&](autograd::FunctionFlags f) {
      return std::make_shared<InterpreterAutogradFunction>(interp, std::move(f));
    });
    return r;
  }
private:
  bool keep_graph = true;
  InterpreterState interp_;
};

}}
@@ -854,7 +854,7 @@ struct PythonOp : public Node {
 
   // The Python object which contains the implementation of this function.
   // This is either a class (non-legacy) or an object (legacy). See
-  // TraceInterpreter for execution semantics.
+  // TraceInterpreterState for execution semantics.
   THPObjectPtr pyobj;
   // The calling convention for the Python function.
   // 's' -- python scalar argument
@@ -9,6 +9,8 @@
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/inplace_check.h"
 #include "torch/csrc/jit/python_arg_flatten.h"
+#include "torch/csrc/jit/interpreter.h"
+#include "torch/csrc/jit/interpreter_autograd_function.h"
 
 #include <algorithm>
 #include <functional>

@@ -54,7 +56,7 @@ struct CompiledFunction {
     , is_volatile_(is_volatile) {}
 
   bool ready() {
-    if (closure_) return true;
+    if (is_ready_) return true;
 
     // Remove expired traces
     traces_.erase(std::remove_if(traces_.begin(),

@@ -83,20 +85,30 @@ struct CompiledFunction {
       PeepholeOptimize(complete_trace->graph);
       FuseGraph(complete_trace->graph);
     }
+    try {
+      code_ = jit::Code(complete_trace->graph);
+    } catch(const jit::NotImplementedException & ex) {
     closure_ = std::make_shared<AutogradClosureFactory>(complete_trace.get());
+    }
+    is_ready_ = true;
     return true;
   }
 
   variable_list run(const variable_list& in_vars) {
-    JIT_ASSERT(closure_);
+    JIT_ASSERT(is_ready_);
     AutoNoGIL _gil_guard;
+    if(closure_) {
     auto fn = closure_->construct();
     return (*fn)(in_vars);
+    } else {
+      InterpreterAutogradFunction interp(code_);
+      interp.willReleaseVariables(); // forward pass is never reused, so it is safe to release anything it can
+      return interp.apply(in_vars);
+    }
   }
 
   PyObject* add_trace(PyObject *args, const variable_list& in_vars) {
-    JIT_ASSERT(!closure_);
+    JIT_ASSERT(!is_ready_);
     // Start tracing
     auto trace = tracer::enter(fmap<TraceInput>(in_vars), is_volatile_ ? 1 : (fn_.nderivs_ + 1));

@@ -120,7 +132,9 @@ struct CompiledFunction {
 
   CompiledFunction& fn_;
   std::string out_desc_;
+  bool is_ready_ = false;
   std::shared_ptr<AutogradClosureFactory> closure_;
+  jit::Code code_;
   std::vector<std::shared_ptr<TracingState>> traces_;
   bool is_volatile_;
 };
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/attributes.h"
 #include "torch/csrc/jit/interned_strings.h"
 #include <vector>
+#include "torch/csrc/jit/interpreter.h"
 
 namespace torch { namespace jit {
@@ -246,7 +247,208 @@ void internedStringsTests () {
 }
 
+
+at::Tensor t_use(at::Tensor x) {
+  return x;
+}
+at::Tensor t_def(at::Tensor x) {
+  return x.t();
+}
+
+// given the difference of output vs expected tensor, check whether the
+// difference is within a relative tolerance range. This is a standard way of
+// matching tensor values up to a certain precision
+bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
+  double maxValue = 0.0;
+  for (auto& tensor : inputs) {
+    maxValue = fmax(tensor.abs().max().toCFloat(), maxValue);
+  }
+  return diff.abs().max().toCFloat() < 2e-6 * maxValue;
+}
+bool almostEqual(const at::Tensor & a, const at::Tensor & b) {
+  return checkRtol(a - b, {a, b});
+}
+
+bool exactlyEqual(const at::Tensor & a, const at::Tensor & b) {
+  return (a - b).abs().max().toCFloat() == 0.f;
+}
+
+std::pair<at::Tensor, at::Tensor>
+lstm(at::Tensor input,
+     at::Tensor hx,
+     at::Tensor cx,
+     at::Tensor w_ih,
+     at::Tensor w_hh) {
+  auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));
+
+  auto chunked_gates = gates.chunk(4, 1);
+  auto ingate = chunked_gates[0];
+  auto forgetgate = chunked_gates[1];
+  auto cellgate = chunked_gates[2];
+  auto outgate = chunked_gates[3];
+
+  ingate = ingate.sigmoid();
+  outgate = outgate.sigmoid();
+  cellgate = cellgate.tanh();
+  forgetgate = forgetgate.sigmoid();
+
+  auto cy = (forgetgate * cx) + (ingate * cellgate);
+  auto hy = outgate * cy.tanh();
+
+  return {hy, cy};
+}
+
+Symbol sym(const char * str) {
+  return stringToSymbol(str);
+}
+
+Node * node(Graph& graph, const char * n, ArrayRef<Node*> inputs) {
+  return graph.appendNode(graph.create(sym(n), inputs));
+}
+
+Node * add(Graph & g, Node * a, Node * b) {
+  auto r = node(g, "add", {a, b});
+  r->t_(sym("alpha"), at::Scalar(1).toTensor());
+  return r;
+}
+
+std::tuple<Node*, Node*> build_lstm_body(
+  Graph & g,
+  Node * input,
+  Node * hx,
+  Node * cx,
+  Node * w_ih,
+  Node * w_hh) {
+  auto gates = add(g, node(g, "mm", { input, w_ih }), node(g, "mm", { hx, w_hh }));
+  auto chunked_gates = node(g, "chunk", { gates });
+  chunked_gates->i_(sym("chunks"), 4);
+  chunked_gates->i_(sym("dim"), 1);
+  auto ingate = g.appendNode(g.createSelect(chunked_gates, 0));
+  auto forgetgate = g.appendNode(g.createSelect(chunked_gates, 1));
+  auto cellgate = g.appendNode(g.createSelect(chunked_gates, 2));
+  auto outgate = g.appendNode(g.createSelect(chunked_gates, 3));
+  ingate = node(g, "sigmoid", { ingate });
+  outgate = node(g, "sigmoid", { outgate });
+  cellgate = node(g, "tanh", { cellgate });
+  forgetgate = node(g, "sigmoid", { forgetgate });
+
+  auto cy = add(g, node(g, "mul", { forgetgate, cx }), node(g, "mul", { ingate, cellgate }));
+  auto hy = node(g, "mul", { outgate, node(g, "tanh", { cy }) });
+
+  return std::make_tuple(hy, cy);
+}
+
+std::shared_ptr<Graph> build_lstm() {
+  auto r = std::make_shared<Graph>();
+  auto & g = *r;
+  Node * input = g.addInput();
+  Node * hx = g.addInput();
+  Node * cx = g.addInput();
+  Node * w_ih = g.addInput();
+  Node * w_hh = g.addInput();
+
+  Node * hy;
+  Node * cy;
+  std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);
+
+  g.registerOutput(hy);
+  g.registerOutput(cy);
+  g.lint();
+
+  return r;
+}
+
+std::shared_ptr<Graph> build_lstm_stages() {
+  auto r = std::make_shared<Graph>();
+  auto & g = *r;
+  Node * input = g.addInput();
+  Node * hx = g.addInput();
+  Node * cx = g.addInput();
+  Node * w_ih = g.addInput();
+  Node * w_hh = g.addInput();
+
+  Node * hy;
+  Node * cy;
+  std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);
+
+  // use some stuff from the previous stage as well
+  // as a new input
+  g.advanceStage();
+  hx = hy;
+  g.registerOutput(cy);
+  cx = g.addInput();
+
+  std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);
+
+  g.registerOutput(hy);
+  g.registerOutput(cy);
+  g.lint();
+
+  return r;
+}
+
+
+void interpTest() {
+  constexpr int batch_size = 4;
+  constexpr int input_size = 256;
+  constexpr int seq_len = 32;
+
+  int hidden_size = 2 * input_size;
+
+  auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
+  auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
+  auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
+
+  auto lstm_g = build_lstm();
+  Code lstm_function(lstm_g);
+  std::vector<at::Tensor> outputs;
+  InterpreterState lstm_interp(lstm_function);
+  lstm_interp.runOneStage({input[0], hx, cx, w_ih, w_hh}, outputs);
+  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
+
+  //std::cout << almostEqual(outputs[0],hx) << "\n";
+  JIT_ASSERT(exactlyEqual(outputs[0], hx));
+  JIT_ASSERT(exactlyEqual(outputs[1], cx));
+}
+
+void interpStageTest() {
+  constexpr int batch_size = 4;
+  constexpr int input_size = 256;
+  constexpr int seq_len = 32;
+
+  int hidden_size = 2 * input_size;
+  auto input = at::CUDA(at::kFloat).randn({seq_len, batch_size, input_size});
+  auto hx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto cx1 = at::CUDA(at::kFloat).randn({batch_size, hidden_size});
+  auto w_ih = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, input_size}));
+  auto w_hh = t_def(at::CUDA(at::kFloat).randn({4 * hidden_size, hidden_size}));
+
+
+  auto lstm_g = build_lstm_stages();
+  Code lstm_function(lstm_g);
+  std::vector<at::Tensor> outputs;
+  InterpreterState lstm_interp(lstm_function);
+  lstm_interp.runOneStage({input[0], hx, cx, w_ih, w_hh}, outputs);
+  auto cy0 = outputs[0];
+  lstm_interp.runOneStage({cx1}, outputs);
+  at::Tensor ihx = outputs[0];
+  at::Tensor icx = outputs[1];
+
+
+  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
+  std::tie(hx, cx) = lstm(input[0], hx, cx1, w_ih, w_hh);
+
+  //std::cout << almostEqual(outputs[0],hx) << "\n";
+  JIT_ASSERT(exactlyEqual(outputs[0], hx));
+  JIT_ASSERT(exactlyEqual(outputs[1], cx));
+}
+
 void runJITCPPTests() {
+  interpTest();
+  interpStageTest();
   codeTemplateTest();
   fusionTests();
   attributesTest();