[jit][edge] Enable lite interpreter to correctly handle INTERFACE_CALL instruction. (#65972)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65972

ghstack-source-id: 141842336

Test Plan: buck test mode/dev //caffe2/test:mobile -- --exact 'caffe2/test:mobile - test_stacktrace_interface_call (mobile.test_lite_script_module.TestLiteScriptModule)'

Reviewed By: qihqi

Differential Revision: D31326147

fbshipit-source-id: 338ff4ce8ddc9502ffe0add49057b33b52a24955
This commit is contained in:
Zhengxu Chen 2021-10-29 13:11:41 -07:00 committed by Facebook GitHub Bot
parent d6b15bfcbd
commit 12ede84dbb
6 changed files with 77 additions and 27 deletions

View File

@@ -5,6 +5,7 @@ import torch.utils.bundled_inputs
import io
from typing import Dict, List
import inspect
from torch.testing import FileCheck
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.testing._internal.common_utils import TestCase, run_tests
@@ -422,6 +423,55 @@ class TestLiteScriptModule(TestCase):
# torch::jit::JITException to extend c10::Error.
self.assertTrue('self.val and val are same' in error_message)
def test_stacktrace_interface_call(self):
@torch.jit.interface
class Forward(torch.nn.Module):
def forward(self, x) -> torch.Tensor:
pass
def forwardError(self, x) -> torch.Tensor:
pass
class B(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def forwardError(self, x):
return self.call() + x
def call(self):
return torch.ones(-1)
class A(torch.nn.Module):
b : Forward
def __init__(self):
super().__init__()
self.b = B()
def forward(self):
self.b.forward(torch.ones(1))
self.b.forwardError(torch.ones(1))
a = torch.jit.script(A())
torch._C._enable_mobile_interface_call_export()
buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
try:
mobile_module()
self.assertTrue(False)
except RuntimeError as exp:
FileCheck().check("Trying to create tensor with negative dimension") \
.check("Traceback of TorchScript") \
.check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
.check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
.check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):

View File

@@ -193,8 +193,8 @@ const std::shared_ptr<Code> Function::get_code() const {
return code_;
}
int64_t Function::getExceptionDebugHandle() const {
return getInterpretersExceptionDebugHandle();
const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
return getInterpretersExceptionDebugHandles();
}
} // namespace mobile

View File

@@ -48,7 +48,7 @@ class TORCH_API Function : public torch::jit::Function {
// Returns the debug handle corresponding to where the execution
// is halted due to exception.
// If no corresponding debug handle is found then -1 is returned.
int64_t getExceptionDebugHandle() const;
const std::vector<int64_t>& getExceptionDebugHandles() const;
private:
c10::QualifiedName name_;

View File

@@ -23,7 +23,7 @@ InterpreterState::InterpreterState(const Code& code) {
}
namespace {
static thread_local DebugHandle exception_debug_handle_{-1};
static thread_local std::vector<DebugHandle> exception_debug_handles_;
void createObject(Stack& stack, const at::ClassTypePtr& type) {
auto userObj = c10::ivalue::Object::create(
c10::StrongTypePtr(type->compilation_unit(), type),
@@ -45,8 +45,8 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
using namespace at;
int64_t getInterpretersExceptionDebugHandle() {
return exception_debug_handle_;
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
return exception_debug_handles_;
}
void InterpreterState::enterFrame(const Code& code) {
@@ -60,12 +60,18 @@ void InterpreterState::leaveFrame() {
frames_.pop_back();
}
void InterpreterState::saveExceptionDebugHandle() {
const auto& frame = frames_.back();
if (auto handle = frame.getDebugHandle()) {
exception_debug_handle_ = *handle;
void InterpreterState::saveExceptionDebugHandles() {
std::vector<DebugHandle> exception_debug_handles;
for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
if (auto handle = frame->getDebugHandle(pc)) {
exception_debug_handles.push_back(*handle);
} else {
exception_debug_handles.push_back(-1);
}
}
exception_debug_handles_ = std::move(exception_debug_handles);
}
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
bool newFrame =
@@ -142,8 +148,7 @@ bool InterpreterState::run(Stack& stack) {
->getMethod(code.constants_[inst.X].toStringRef());
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
method.name(), debug_handle, stack);
method.run(stack);
frame.step();
callFunction(method, stack);
} break;
case LOAD:
stack.emplace_back(reg(inst.X));
@@ -285,15 +290,15 @@ bool InterpreterState::run(Stack& stack) {
}
// This exception must be caught first as it derived from c10::Error
} catch (c10::BackendRuntimeException& e) {
saveExceptionDebugHandle();
saveExceptionDebugHandles();
TORCH_RETHROW(e);
} catch (c10::Error& error) {
// Reason for catching and rethrowing the error is so that we can
// set the exception pc that is queried later
saveExceptionDebugHandle();
saveExceptionDebugHandles();
TORCH_RETHROW(error);
} catch (...) {
saveExceptionDebugHandle();
saveExceptionDebugHandles();
throw;
}
// for (auto val : stack) {

View File

@@ -16,7 +16,7 @@ struct InterpreterState {
private:
void enterFrame(const Code&);
void leaveFrame();
void saveExceptionDebugHandle();
void saveExceptionDebugHandles();
void callFunction(torch::jit::Function& f, Stack& stack);
c10::IValue& reg(size_t reg);
@@ -24,14 +24,7 @@ struct InterpreterState {
std::vector<Frame> frames_;
};
// Interpreter executes instruction in a loop one by one
// from a list of instructions. PC is a program counter pointer
// pointing to the current instruction being executed.
// This function returns the current PC.
// Note that this is set only when exception occurs.
// since this is a thread local variable and setting it for
// every instruction will add overhead of thread local variable access.
DebugHandle getInterpretersExceptionDebugHandle();
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles();
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@@ -204,7 +204,7 @@ void Method::run(Stack& stack) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
if (error_message.empty()) {
error_message = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
}
#endif
@@ -226,7 +226,9 @@ void Method::run(Stack& stack) const {
// This exception must be caught first as it derived from c10::Error
} catch (c10::BackendRuntimeException& e) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
e.pushDebugHandle(function_->getExceptionDebugHandle());
for (auto handle : function_->getExceptionDebugHandles()) {
e.pushDebugHandle(handle);
}
// symbolicate all handles
auto debug_string = owner_->getDebugTable().getSourceDebugString(
e.getDebugHandles(), getTopModuleTypeName(*owner_));
@@ -237,7 +239,7 @@ void Method::run(Stack& stack) const {
} catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
auto debug_string = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
error.add_context(debug_string);
#endif
error_message = error.what();