Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63130

Extend `type_parser` to handle the `NamedTuple` type. It can be extended to handle other types when needed. The custom type follows this format:
```
"qualified_name[
  NamedTuple, [
    [field_name_1, field_type_1],
    [field_name_2, field_type_2]
  ]
]"
```
For example:
```
"__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
  NamedTuple, [
    [float_features, Tensor],
    [id_list_features, List[Tensor]],
    [label, Tensor],
    [weight, Tensor],
  ]
]"
```
For nested types, the order of the type strings in the type table should be:
```
std::string type_1 = "__torch__.C[NamedTuple, [[field_name_c_1, Tensor], [field_name_c_2, Tuple[Tensor, Tensor]]]]";
std::string type_2 = "__torch__.B[NamedTuple, [[field_name_b, __torch__.C]]]";
std::string type_3 = "__torch__.A[NamedTuple, [[field_name_a, __torch__.B]]]";
std::vector<std::string> type_strs = {type_1, type_2, type_3};
std::vector<TypePtr> type_ptrs = c10::parseType(type_strs);
```
NamedTuples from both `collections` and `typing` are supported:
```
from typing import NamedTuple
from collections import namedtuple
```
This change only adds the parser, so the new runtime can read the above format.

ghstack-source-id: 141293658

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatibleCustomType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatibleCustomType'
```

Reviewed By: iseeyuan

Differential Revision: D30261547

fbshipit-source-id: 68a9974338464e320b39a5c613dc048f6c5adeb5
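As a minimal illustrative sketch (not part of this diff), a single non-nested type string in the format above could be parsed as follows; the type name `__torch__.Point`, its fields, and the exact whitespace are made-up assumptions, while the `c10::parseType` overload taking a vector of strings is the one used by `parseTypes` below:
```
#include <torch/csrc/jit/mobile/type_parser.h>

#include <string>
#include <vector>

int main() {
  // Hypothetical NamedTuple type string, following the documented format.
  std::vector<std::string> type_strs = {
      "__torch__.Point[NamedTuple, [[x, Tensor], [y, Tensor]]]"};
  // parseType returns one TypePtr per input string, in the same order.
  std::vector<c10::TypePtr> type_ptrs = c10::parseType(type_strs);
  return type_ptrs.size() == 1 ? 0 : 1;
}
```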
149 lines
5.0 KiB
C++
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {
OpCode parseOpCode(const char* str);
using c10::IValue;

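// Looks up one row of a bytecode table: asserts that the (name, value)
// pair at `entry` is named `expected_name` and returns its value.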
IValue expect_field(
    c10::ivalue::TupleElements& elements,
    const std::string& expected_name,
    size_t entry) {
  auto row = std::move(elements.at(entry)).toTuple();
  TORCH_INTERNAL_ASSERT(
      row->elements().at(0).toStringRef() == expected_name,
      "Expected ",
      expected_name,
      " found ",
      row->elements().at(0).toStringRef());
  return std::move(row->elements().at(1));
}

namespace mobile {

namespace {
#define COUNT_OPCODE(_, _a) 1 +
constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
#undef COUNT_OPCODE

// Pickled strings are memoized, so we can cache a mapping from
// pointers to parsed OpCodes to speed up parsing.
class OpCodeCache {
 private:
  // We store as void* to emphasize that we care only about the
  // address and should not be dereferencing these pointers.
  std::array<const void*, numOpcodes> keys_{};
  std::array<OpCode, numOpcodes> values_{};
  size_t usedEntries_ = 0;

 public:
  OpCodeCache() {
    memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
  }

  OpCode parse(const c10::ivalue::ConstantString& s) {
    const auto endIt = keys_.begin() + usedEntries_;
    auto it = std::find_if(
        keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
    if (it == endIt) {
      OpCode result = parseOpCode(s.string().c_str());
      if (usedEntries_ < numOpcodes) {
        keys_[usedEntries_] = &s;
        values_[usedEntries_++] = result;
      }
      return result;
    }

    // NOTE: I tried implementing the transpose heuristic here to
    // speed up the search, but it removed the benefit of this cache.
    return values_[it - keys_.begin()];
  }
};
} // namespace

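// Parses the instruction list of a single function from the bytecode table.
// Each entry in `ins_list` is an (op_name, X, N) triple; when a debug-handle
// table is provided, the matching handle is attached to each instruction.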
void parseInstructions(
    const std::string& function_name,
    c10::ivalue::TupleElements&& ins_list,
    c10::ivalue::TupleElements& debug_handles_m_tuple,
    mobile::Function* function) {
  c10::List<int64_t> debug_handles_list;
  if (!debug_handles_m_tuple.empty()) {
    const std::string& debug_info_function_name =
        debug_handles_m_tuple[0].toStringRef();
    TORCH_CHECK(
        debug_info_function_name == function_name,
        "The function names in the bytecode table and the debug info table do not match.");
    IValue& debug_handles_table = debug_handles_m_tuple[1];
    auto debugHandlesTableElements =
        std::move(*std::move(debug_handles_table).toTuple()).elements();
    debug_handles_list = (expect_field(
                              debugHandlesTableElements,
                              "function_debug_handles",
                              BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
                              .toTuple()
                              ->elements())[0]
                             .toIntList();
    TORCH_CHECK(
        debug_handles_list.size() == ins_list.size(),
        "The numbers of instructions and debug handles strings do not match.");
  }

  // NOTE: this won't perform particularly well if the ins_list IValue
  // didn't come from the unpickler and thus doesn't have its strings
  // interned. Consider adding a flag to bypass the cache if that
  // becomes an important use case.
  OpCodeCache opCodeCache;
  for (const auto j : c10::irange(ins_list.size())) {
    auto ins_tuple = std::move(ins_list[j]).toTuple();
    c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
    TORCH_CHECK(
        ins_item.size() == 3,
        "There should be three parts in an instruction. The function name is ",
        function_name);
    OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
    int X = ins_item[1].toInt();
    int N = ins_item[2].toInt();
    if (!debug_handles_list.empty()) {
      int64_t debug_handle = debug_handles_list[j];
      function->append_instruction(op_code, X, N, debug_handle);
    } else {
      function->append_instruction(op_code, X, N);
    }
  }
}

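// Appends every constant from the bytecode constant table to the
// function's constant table as-is.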
void parseConstants(
    const c10::ivalue::TupleElements& consts_list,
    mobile::Function* function) {
  for (const auto& constant : consts_list) {
    function->append_constant(constant);
  }
}

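// Parses the serialized type strings of a function (including the
// NamedTuple format described in the summary above) into TypePtrs via
// c10::parseType and registers each one on the function.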
void parseTypes(
    const c10::ivalue::TupleElements& types_list,
    mobile::Function* function) {
  std::vector<std::string> types_string_list;
  types_string_list.resize(types_list.size());
  for (size_t i = 0; i < types_list.size(); i++) {
    types_string_list[i] = types_list[i].toString()->string();
  }

  std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
  for (auto& type_ptr : types_ptr_list) {
    function->append_type(type_ptr);
  }
}

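// Records how many registers the function's bytecode requires.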
void parseRegisterSize(size_t rsize, mobile::Function* function) {
  function->set_register_size(rsize);
}

} // namespace mobile
} // namespace jit
} // namespace torch