pytorch/torch/csrc/jit/mobile/parse_bytecode.cpp
Martin Yuan 30a7c768d7 [RFC] Modularize functions of parsing bytecode (#61862)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61862

Modularize the functions that parse bytecode tables so that they can be used as needed in situations other than the mobile lite interpreter.
* The decoupled functions are re-used by the current lite interpreter loader (see the usage sketch below).
* The bytecode can be serialized to and deserialized from other formats.
* The decoupled functions have minimal dependencies on other PyTorch components.
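A minimal sketch of how a custom loader might drive the decoupled parsers. The deserialized inputs `ins_list`, `debug_handles_m_tuple`, `consts_list`, `types_list`, and `register_size` are hypothetical placeholders for fields read from some bytecode table; only the `parse*` functions and `mobile::Function` come from this change:

    #include <torch/csrc/jit/mobile/function.h>
    #include <torch/csrc/jit/mobile/parse_bytecode.h>

    using namespace torch::jit;

    // Inputs below are assumed to have been deserialized from a bytecode
    // table by a custom reader (hypothetical in this sketch).
    mobile::Function fn(c10::QualifiedName("forward"));
    mobile::parseInstructions("forward", ins_list, debug_handles_m_tuple, &fn);
    mobile::parseConstants(consts_list, &fn);
    mobile::parseTypes(types_list, &fn);
    mobile::parseRegisterSize(register_size, &fn);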

Next:
Build a driver binary that includes the parser and interpreter but has only the necessary dependencies on other PyTorch components.
ghstack-source-id: 137867287

Test Plan:
As an example, the added unit test `RunTimeTest:ParseBytecode` parses a simple bytecode into a mobile function and runs it directly. The bytecode contains basic control flow (if/else) and basic data orchestration (list construction).
CI
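A rough sketch of the shape of such a test. The instruction triples and stack values here are illustrative stand-ins, simpler than the actual test body (which also exercises if/else and list construction):

    TEST(RunTimeTest, ParseBytecode) {
      // Each instruction is an (op_name, X, N) tuple; this tiny program
      // stores two inputs into registers and returns one of them.
      std::vector<c10::IValue> ins_list{
          c10::ivalue::Tuple::create(
              std::vector<c10::IValue>{std::string("STOREN"), 1, 2}),
          c10::ivalue::Tuple::create(
              std::vector<c10::IValue>{std::string("MOVE"), 1, 0}),
          c10::ivalue::Tuple::create(
              std::vector<c10::IValue>{std::string("RET"), 0, 0}),
      };
      std::vector<c10::IValue> debug_handles_m_tuple; // empty: no debug info
      torch::jit::mobile::Function fn(c10::QualifiedName("test_fn"));
      torch::jit::mobile::parseInstructions(
          "test_fn", ins_list, debug_handles_m_tuple, &fn);
      torch::jit::mobile::parseRegisterSize(2, &fn);
      std::vector<c10::IValue> stack{1, 2};
      fn.run(stack);
      EXPECT_EQ(stack.size(), 1); // RET leaves the moved value on the stack
    }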

Reviewed By: larryliu0820

Differential Revision: D29798382

Pulled By: iseeyuan

fbshipit-source-id: 1c173a5f5d37097e3a97baec3f3e48e1eea1400f
2021-09-11 22:24:05 -07:00

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/custom_class_detail.h>
namespace torch {
namespace jit {
OpCode parseOpCode(const char* str);
using c10::IValue;
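// Reads row `entry` of a serialized table, checks that the row's
// (name, value) pair is named `expected_name`, and returns the value.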
IValue expect_field(
    std::vector<IValue>& elements,
    const std::string& expected_name,
    size_t entry) {
  auto row = std::move(elements.at(entry)).toTuple();
  TORCH_INTERNAL_ASSERT(
      row->elements().at(0).toStringRef() == expected_name,
      "Expected ",
      expected_name,
      " found ",
      row->elements().at(0).toStringRef());
  return std::move(row->elements().at(1));
}
namespace mobile {
namespace {} // namespace
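// Appends each (op, X, N) instruction triple in ins_list to the function,
// attaching per-instruction debug handles when a debug info table is given.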
void parseInstructions(
    const std::string& function_name,
    const std::vector<IValue>& ins_list,
    std::vector<IValue>& debug_handles_m_tuple,
    mobile::Function* function) {
  c10::List<int64_t> debug_handles_list;
  if (!debug_handles_m_tuple.empty()) {
    const std::string& debug_info_function_name =
        debug_handles_m_tuple[0].toStringRef();
    TORCH_CHECK(
        debug_info_function_name == function_name,
        "The function names in the bytecode table and the debug info table do not match.");
    IValue& debug_handles_table = debug_handles_m_tuple[1];
    debug_handles_list = (expect_field(
                              std::move(debug_handles_table).toTuple()->elements(),
                              "function_debug_handles",
                              BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
                              .toTuple()
                              ->elements())[0]
                             .toIntList();
    TORCH_CHECK(
        debug_handles_list.size() == ins_list.size(),
        "The numbers of instructions and debug handles strings do not match.");
  }
  for (const auto j : c10::irange(ins_list.size())) {
    std::vector<IValue> ins_item =
        std::move(*std::move(ins_list[j]).toTuple()).elements();
    TORCH_CHECK(
        ins_item.size() == 3,
        "There should be three parts in an instruction. The function name is ",
        function_name);
    OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
    int X = ins_item[1].toInt();
    int N = ins_item[2].toInt();
    if (!debug_handles_list.empty()) {
      int64_t debug_handle = debug_handles_list[j];
      function->append_instruction(op_code, X, N, debug_handle);
    } else {
      function->append_instruction(op_code, X, N);
    }
  }
}
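// Copies every entry of the bytecode constants table into the function.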
void parseConstants(
    const std::vector<IValue>& consts_list,
    mobile::Function* function) {
  for (const auto& constant : consts_list) {
    function->append_constant(constant);
  }
}
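// Resolves each type string: names under "__torch__.torch.classes" must be
// registered custom classes; everything else goes through the type parser.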
void parseTypes(
    const std::vector<IValue>& types_list,
    mobile::Function* function) {
  static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
  for (const auto& t : types_list) {
    c10::QualifiedName qn(t.toStringRef());
    if (classPrefix.isPrefixOf(qn)) {
      auto classType = getCustomClass(qn.qualifiedName());
      TORCH_CHECK(
          classType,
          "The implementation of class ",
          qn.qualifiedName(),
          " cannot be found.");
      function->append_type(classType);
    } else {
      function->append_type(c10::parseType(t.toStringRef()));
    }
  }
}
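// Records how many registers the function's instructions require.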
void parseRegisterSize(size_t rsize, mobile::Function* function) {
  function->set_register_size(rsize);
}
} // namespace mobile
} // namespace jit
} // namespace torch