mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[jit][edge] Enable lite interpreter to correctly handle INTERFACE_CALL instruction. (#65972)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65972 ghstack-source-id: 141842336 Test Plan: buck test mode/dev //caffe2/test:mobile -- --exact 'caffe2/test:mobile - test_stacktrace_interface_call (mobile.test_lite_script_module.TestLiteScriptModule)' Reviewed By: qihqi Differential Revision: D31326147 fbshipit-source-id: 338ff4ce8ddc9502ffe0add49057b33b52a24955
This commit is contained in:
parent
d6b15bfcbd
commit
12ede84dbb
|
|
@ -5,6 +5,7 @@ import torch.utils.bundled_inputs
|
|||
import io
|
||||
from typing import Dict, List
|
||||
import inspect
|
||||
from torch.testing import FileCheck
|
||||
|
||||
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
|
@ -422,6 +423,55 @@ class TestLiteScriptModule(TestCase):
|
|||
# torch::jit::JITException to extend c10::Error.
|
||||
self.assertTrue('self.val and val are same' in error_message)
|
||||
|
||||
def test_stacktrace_interface_call(self):
|
||||
@torch.jit.interface
|
||||
class Forward(torch.nn.Module):
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def forwardError(self, x) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
class B(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
def forwardError(self, x):
|
||||
return self.call() + x
|
||||
|
||||
def call(self):
|
||||
return torch.ones(-1)
|
||||
|
||||
class A(torch.nn.Module):
|
||||
b : Forward
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.b = B()
|
||||
|
||||
def forward(self):
|
||||
self.b.forward(torch.ones(1))
|
||||
self.b.forwardError(torch.ones(1))
|
||||
|
||||
a = torch.jit.script(A())
|
||||
torch._C._enable_mobile_interface_call_export()
|
||||
buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
|
||||
buffer.seek(0)
|
||||
mobile_module = _load_for_lite_interpreter(buffer)
|
||||
try:
|
||||
mobile_module()
|
||||
self.assertTrue(False)
|
||||
except RuntimeError as exp:
|
||||
FileCheck().check("Trying to create tensor with negative dimension") \
|
||||
.check("Traceback of TorchScript") \
|
||||
.check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
|
||||
.check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
|
||||
.check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
|
||||
|
||||
|
||||
|
||||
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
|
||||
|
||||
|
|
|
|||
|
|
@ -193,8 +193,8 @@ const std::shared_ptr<Code> Function::get_code() const {
|
|||
return code_;
|
||||
}
|
||||
|
||||
int64_t Function::getExceptionDebugHandle() const {
|
||||
return getInterpretersExceptionDebugHandle();
|
||||
const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
|
||||
return getInterpretersExceptionDebugHandles();
|
||||
}
|
||||
|
||||
} // namespace mobile
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TORCH_API Function : public torch::jit::Function {
|
|||
// Returns the debug handle corresponding to where the execution
|
||||
// is halted due to exception.
|
||||
// If no corresponding debug handle is found then -1 is returned.
|
||||
int64_t getExceptionDebugHandle() const;
|
||||
const std::vector<int64_t>& getExceptionDebugHandles() const;
|
||||
|
||||
private:
|
||||
c10::QualifiedName name_;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ InterpreterState::InterpreterState(const Code& code) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
static thread_local DebugHandle exception_debug_handle_{-1};
|
||||
static thread_local std::vector<DebugHandle> exception_debug_handles_;
|
||||
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
||||
auto userObj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(type->compilation_unit(), type),
|
||||
|
|
@ -45,8 +45,8 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
|
|||
|
||||
using namespace at;
|
||||
|
||||
int64_t getInterpretersExceptionDebugHandle() {
|
||||
return exception_debug_handle_;
|
||||
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
|
||||
return exception_debug_handles_;
|
||||
}
|
||||
|
||||
void InterpreterState::enterFrame(const Code& code) {
|
||||
|
|
@ -60,11 +60,17 @@ void InterpreterState::leaveFrame() {
|
|||
frames_.pop_back();
|
||||
}
|
||||
|
||||
void InterpreterState::saveExceptionDebugHandle() {
|
||||
const auto& frame = frames_.back();
|
||||
if (auto handle = frame.getDebugHandle()) {
|
||||
exception_debug_handle_ = *handle;
|
||||
void InterpreterState::saveExceptionDebugHandles() {
|
||||
std::vector<DebugHandle> exception_debug_handles;
|
||||
for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
|
||||
size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
|
||||
if (auto handle = frame->getDebugHandle(pc)) {
|
||||
exception_debug_handles.push_back(*handle);
|
||||
} else {
|
||||
exception_debug_handles.push_back(-1);
|
||||
}
|
||||
}
|
||||
exception_debug_handles_ = std::move(exception_debug_handles);
|
||||
}
|
||||
|
||||
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
|
||||
|
|
@ -142,8 +148,7 @@ bool InterpreterState::run(Stack& stack) {
|
|||
->getMethod(code.constants_[inst.X].toStringRef());
|
||||
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
||||
method.name(), debug_handle, stack);
|
||||
method.run(stack);
|
||||
frame.step();
|
||||
callFunction(method, stack);
|
||||
} break;
|
||||
case LOAD:
|
||||
stack.emplace_back(reg(inst.X));
|
||||
|
|
@ -285,15 +290,15 @@ bool InterpreterState::run(Stack& stack) {
|
|||
}
|
||||
// This exception must be caught first as it derived from c10::Error
|
||||
} catch (c10::BackendRuntimeException& e) {
|
||||
saveExceptionDebugHandle();
|
||||
saveExceptionDebugHandles();
|
||||
TORCH_RETHROW(e);
|
||||
} catch (c10::Error& error) {
|
||||
// Reason for catching and rethrowing the error is so that we can
|
||||
// set the exception pc that is queried later
|
||||
saveExceptionDebugHandle();
|
||||
saveExceptionDebugHandles();
|
||||
TORCH_RETHROW(error);
|
||||
} catch (...) {
|
||||
saveExceptionDebugHandle();
|
||||
saveExceptionDebugHandles();
|
||||
throw;
|
||||
}
|
||||
// for (auto val : stack) {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ struct InterpreterState {
|
|||
private:
|
||||
void enterFrame(const Code&);
|
||||
void leaveFrame();
|
||||
void saveExceptionDebugHandle();
|
||||
void saveExceptionDebugHandles();
|
||||
void callFunction(torch::jit::Function& f, Stack& stack);
|
||||
|
||||
c10::IValue& reg(size_t reg);
|
||||
|
|
@ -24,14 +24,7 @@ struct InterpreterState {
|
|||
std::vector<Frame> frames_;
|
||||
};
|
||||
|
||||
// Interpreter executes instruction in a loop one by one
|
||||
// from a list of instructions. PC is a program counter pointer
|
||||
// pointing to the current instruction being executed.
|
||||
// This function returns the current PC.
|
||||
// Note that this is set only when exception occurs.
|
||||
// since this is a thread local variable and setting it for
|
||||
// every instruction will add overhead of thread local variable access.
|
||||
DebugHandle getInterpretersExceptionDebugHandle();
|
||||
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles();
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ void Method::run(Stack& stack) const {
|
|||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||
if (error_message.empty()) {
|
||||
error_message = owner_->getDebugTable().getSourceDebugString(
|
||||
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
|
||||
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -226,7 +226,9 @@ void Method::run(Stack& stack) const {
|
|||
// This exception must be caught first as it derived from c10::Error
|
||||
} catch (c10::BackendRuntimeException& e) {
|
||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||
e.pushDebugHandle(function_->getExceptionDebugHandle());
|
||||
for (auto handle : function_->getExceptionDebugHandles()) {
|
||||
e.pushDebugHandle(handle);
|
||||
}
|
||||
// symbolicate all handles
|
||||
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
||||
e.getDebugHandles(), getTopModuleTypeName(*owner_));
|
||||
|
|
@ -237,7 +239,7 @@ void Method::run(Stack& stack) const {
|
|||
} catch (c10::Error& error) {
|
||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
||||
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
|
||||
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
|
||||
error.add_context(debug_string);
|
||||
#endif
|
||||
error_message = error.what();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user