pytorch/torch/csrc/jit/mobile/function.cpp
Martin Yuan 3003c5f91b OPN ops TupleConstruct/Unpack and format. (#29635)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29635

TupleConstruct/Unpack as OPN ops.

Test Plan: Imported from OSS

Differential Revision: D18499602

fbshipit-source-id: 389b21d3ea532ef6fa729d67ce34214d86700cd2
2019-11-15 16:22:42 -08:00

93 lines
3.3 KiB
C++

#include "function.h"
#include "interpreter.h"
#include <torch/csrc/jit/instruction.h>
#include <torch/csrc/jit/vararg_functions.h>
#include <ATen/core/op_registration/op_registration.h>
namespace torch{
namespace jit{
namespace {
// Different from the implementation in register_prim_ops.cpp,
// where named tuple is not supported yet.
void tupleConstructFunc(int num_inputs, Stack& stack) {
std::vector<IValue> elems{
std::make_move_iterator(stack.end() - num_inputs),
std::make_move_iterator(stack.end())};
drop(stack, num_inputs);
push(stack, c10::ivalue::Tuple::create(std::move(elems)));
}
}
char const * toString(OpCode op);
namespace mobile {
Function::Function(c10::QualifiedName name)
: name_(name), code_(std::make_shared<Code>()) {}
void Function::append_instruction(OpCode op, int X, int N) {
TORCH_CHECK(isOpSupportedInMobile(op), toString(op),
" is not supported in mobile module.");
code_->instructions_.emplace_back(op, X, N);
}
void Function::append_operator(const std::string& name,
const std::string& overload_name) {
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
auto opname = code_->op_names_.back();
// Add "_" prefix to work around the double registration both of jit/generated
// and here. TODO: remove it when we have separate build for lite interpreter.
opname.name = "_" + opname.name;
auto op = c10::Dispatcher::singleton().findSchema(opname);
TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
code_->operators_.emplace_back(op);
}
void Function::build_vararg_operator_table() {
for (auto& ins : code_->instructions_) {
if (ins.op == OPN) {
auto opname = code_->op_names_[ins.X];
if (opname.name == "prim::ListConstruct") {
if (opname.overload_name == "int") {
code_->vararg_operators_.emplace_back(listConstructFunc<int64_t>);
} else if (opname.overload_name == "float") {
code_->vararg_operators_.emplace_back(listConstructFunc<double>);
} else if (opname.overload_name == "bool") {
code_->vararg_operators_.emplace_back(listConstructFunc<bool>);
} else if (opname.overload_name == "Tensor") {
code_->vararg_operators_.emplace_back(tensorListConstructFunc);
} else {
AT_ERROR("Type of ListConstruct is not supported.");
}
} else if (opname.name == "prim::TupleConstruct") {
code_->vararg_operators_.emplace_back(tupleConstructFunc);
} else if (opname.name == "prim::TupleUnpack") {
code_->vararg_operators_.emplace_back(tupleUnpackFunc);
} else if (opname.name == "aten::format") {
code_->vararg_operators_.emplace_back(formatFunc);
}
else {
AT_ERROR("OPN operator ", opname.name, " is not supported.");
}
ins.X = code_->vararg_operators_.size() - 1;
}
}
}
void Function::append_constant(const c10::IValue& constant) {
code_->constants_.push_back(constant);
}
void Function::set_register_size(size_t size) {
code_->register_size_ = size;
}
bool Function::run(Stack& stack) const {
InterpreterState interp_state(code_);
return interp_state.run(stack);
}
} // namespace mobile
} // namespace torch
} // namespace jit