pytorch/torch/csrc/jit/mobile/function.cpp
Kimish Patel f4a921600a [PyTorch, Mobile] Serialization format change for source range (#54284)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54284

In order to bring mobile deployment, via the lite interpreter, to feature
parity with JIT with respect to model-level debug information, we must make
that information available to the mobile runtime.
At the moment, model-level debug information is stored in SourceRange,
which associates nodes of the graph with where they come from in the
original Python source code.
This information is serialized as part of debug_pkl and deserialized
when JIT loads the model and reads the model code.
On the lite interpreter, we do not have access to all of JIT's
functionality, and hence we cannot load the model the same way JIT does,
by reading code, constructing the module hierarchy, building graphs for the
corresponding module methods, etc. Instead, in the lite interpreter, only
the bytecode corresponding to the compiled graph, Code, is saved.
Thus, in order to annotate ops in the bytecode with equivalent
SourceRange information, we do the following (a simplified sketch follows
the list):
1. During model serialization, we create a unique tag for each source
range of the model.
2. Create a map of <SourceRange, tag>
3. During debug_pkl serialization we save the tag along with the
SourceRange, in addition to the byte offset.
4. During bytecode generation, the methods of the top module are
lowered. During this process methods are inlined. In the inlined graph,
when a node of the graph is lowered to bytecode, we query the node's
source range and look it up against the map.
5. Resulting source range tag is serialized in module_debug_info.
6. During model deserialization, we read all the debug_pkl records in
the archive and create a map of <tag, SourceRange>.
7. This map can be used to find source code information.
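
For illustration, a minimal self-contained sketch of this tagging scheme.
SourceRange here is a simplified stand-in for torch::jit::SourceRange, and
SourceRangeTagger/DebugPklRecord are hypothetical names, not the actual
serializer types:

#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>

// Simplified stand-in for torch::jit::SourceRange: a span in a source file.
struct SourceRange {
  std::string file;
  size_t start;
  size_t end;
  bool operator<(const SourceRange& other) const {
    return std::tie(file, start, end) <
        std::tie(other.file, other.start, other.end);
  }
};

// Steps 1-2: assign a unique tag to each distinct source range encountered
// while serializing the model, and remember the <SourceRange, tag> mapping.
class SourceRangeTagger {
 public:
  int64_t tag(const SourceRange& range) {
    auto it = tags_.find(range);
    if (it != tags_.end()) {
      return it->second;
    }
    int64_t t = next_tag_++;
    tags_.emplace(range, t);
    return t;
  }

 private:
  std::map<SourceRange, int64_t> tags_;
  int64_t next_tag_ = 0;
};

// Step 3: each debug_pkl entry now carries the tag next to the byte offset;
// step 5 then serializes the same tag into module_debug_info for each op.
struct DebugPklRecord {
  int64_t byte_offset;
  int64_t tag;
  SourceRange range;
};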

During mobile runtime:
1. We read all the debug_pkl records and create a <tag=debug_handle,
SourceRange> map.
   1.1 This map, MobileDebugInfo, is a member of the mobile Module.
2. The interpreter catches the appropriate exceptions, sets the
thread-local debug handle, and rethrows the exception.
3. In Function's run method we catch the exception and query the current
debug handle, i.e. where the exception happened.
4. We query MobileDebugInfo with the debug handle to retrieve the source
range and augment the error with source range info (a sketch of this
catch/re-throw flow follows the list).
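
A rough, self-contained sketch of the catch/re-throw flow above.
exception_debug_handle, interpret_op, and run_with_debug_info are
illustrative stand-ins for the interpreter's exception-PC bookkeeping
(exposed via getInterpretersExceptionPC) and Function::run, not the real
interfaces:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Step 2: the interpreter notes where the failure happened in a thread-local
// slot before re-throwing.
thread_local int64_t exception_debug_handle = -1;

void interpret_op(int64_t debug_handle) {
  try {
    throw std::runtime_error("op failed"); // some op throws
  } catch (...) {
    exception_debug_handle = debug_handle; // remember where we were
    throw; // re-throw to the caller
  }
}

// Steps 1, 3, and 4: the runtime owns a <debug handle, source range text> map
// (MobileDebugInfo in the real code); Function::run-style code catches the
// exception, looks the handle up, and augments the error message.
void run_with_debug_info(
    int64_t debug_handle,
    const std::unordered_map<int64_t, std::string>& mobile_debug_info) {
  try {
    interpret_op(debug_handle);
  } catch (const std::exception& e) {
    auto it = mobile_debug_info.find(exception_debug_handle);
    const std::string where =
        it != mobile_debug_info.end() ? it->second : "<unknown source range>";
    throw std::runtime_error(std::string(e.what()) + "\n  at " + where);
  }
}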

This information is still incomplete, as it does not contain the entire
callstack.

In the following diffs we will serialize InlinedCallStack directly.

Note that compilation is gated by the SYMBOLICATE_MOBILE_DEBUG_HANDLE macro,
so that mobile builds can avoid building MobileDebugInfo, source ranges,
and the source range pickler/unpickler. Later we will add a path where, if
building without debug support, the stack trace will contain only debug
handles; they can be symbolicated later.
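
As an illustration of the kind of gating this describes (describeFailure and
symbolicate are hypothetical helpers; only the macro name comes from this
change, and the real error-reporting path looks different):

#include <cstdint>
#include <string>

// Hypothetical helper that maps a debug handle back to source range text via
// MobileDebugInfo; only available when built with debug support.
std::string symbolicate(int64_t debug_handle);

std::string describeFailure(int64_t debug_handle) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  // Full build: resolve the handle to a source range in-process.
  return "error at " + symbolicate(debug_handle);
#else
  // Lean mobile build: report only the raw handle; it can be symbolicated
  // offline against the serialized debug records later.
  return "error at debug handle " + std::to_string(debug_handle);
#endif
}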

Test Plan:
Ported a bunch of source range tests from test_jit.py. Added one more test
in test_lite_interpreter.py.

Imported from OSS

Reviewed By: raziel

Differential Revision: D27174722

fbshipit-source-id: a7b7c6088ce16dec37e823c7fefa4f0b61047e12
2021-05-04 09:19:27 -07:00

#include <torch/csrc/jit/mobile/function.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>
namespace torch {
namespace jit {
char const* toString(OpCode op);
namespace mobile {
Function::Function(c10::QualifiedName name)
    : name_(std::move(name)), code_(std::make_shared<Code>()) {}

const c10::QualifiedName& Function::qualname() const {
  return name_;
}

const std::string& Function::name() const {
  return name_.name();
}

void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in mobile module.");
  code_->instructions_.emplace_back(op, X, N);
  code_->debug_handles_.emplace_back(dbg_handle);
}
bool Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    int64_t model_version) {
  // Keep the original opname in code_
  code_->op_names_.emplace_back(name, overload_name);
  auto opname = code_->op_names_.back();
  const auto& opname_c10 = opname;

  std::function<void(Stack&)> fn;
  auto jit_op = findOperatorFor(opname);
  if (jit_op) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
    if (op.has_value()) {
      fn = [op](Stack& stack) { op->callBoxed(&stack); };
    } else {
      return false;
    }
  }
  if (model_version == 0x3LL &&
      opname == c10::OperatorName("aten::_convolution", "")) {
    // Since byte-code version 0x4L, convolution has an additional
    // default-value argument (allow_tf32=True, see
    // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
    // backward compatibility with models of byte-code version <= 0x3L, where
    // this bool argument does not yet exist.
    fn = [fn](Stack& stack) {
      stack.push_back(true);
      fn(stack);
    };
  }

  code_->operators_.emplace_back(fn);
  return true;
}
void Function::set_module_debug_info_list_size(size_t size) {
  pc_to_module_debug_info_.resize(size);
  for (size_t i = 0; i < size; ++i) {
    pc_to_module_debug_info_[i] = "<no module info>";
  }
}

void Function::set_module_info(const std::string& module_info, size_t pc) {
  TORCH_CHECK(
      pc < pc_to_module_debug_info_.size(),
      "Module debug info index out of boundary.");
  pc_to_module_debug_info_[pc] = module_info;
}

void Function::append_constant(const c10::IValue& constant) {
  code_->constants_.push_back(constant);
}

void Function::append_type(const at::TypePtr& type) {
  code_->types_.push_back(type);
}

void Function::set_register_size(size_t size) {
  code_->register_size_ = size;
}

std::string Function::get_module_debug_info(size_t pc) const {
  TORCH_CHECK(
      pc < pc_to_module_debug_info_.size(),
      "Module debug info index out of boundary.");
  return pc_to_module_debug_info_[pc];
}

void Function::setSchema(c10::FunctionSchema schema) {
  schema_ = std::move(schema);
}

const at::optional<c10::FunctionSchema>& Function::getSchema() const {
  return schema_;
}

bool Function::run(Stack& stack) const {
  const auto& schema = getSchema();
  if (schema) { // if we have a schema then resolve optional args if any
    schema->checkAndNormalizeInputs(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
  }
  InterpreterState interp_state(code_);
  return interp_state.run(stack);
}

c10::IValue Function::operator()(Stack& stack) const {
  run(stack);
  return stack.front();
}

const std::shared_ptr<Code> Function::get_code() const {
  return code_;
}

int64_t Function::getExceptionDebugHandle() const {
  size_t pc = getInterpretersExceptionPC();
  return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
}
} // namespace mobile
} // namespace jit
} // namespace torch