mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[jit][edge] Enable lite interpreter to correctly handle INTERFACE_CALL instruction. (#65972)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65972 ghstack-source-id: 141842336 Test Plan: buck test mode/dev //caffe2/test:mobile -- --exact 'caffe2/test:mobile - test_stacktrace_interface_call (mobile.test_lite_script_module.TestLiteScriptModule)' Reviewed By: qihqi Differential Revision: D31326147 fbshipit-source-id: 338ff4ce8ddc9502ffe0add49057b33b52a24955
This commit is contained in:
parent
d6b15bfcbd
commit
12ede84dbb
|
|
@ -5,6 +5,7 @@ import torch.utils.bundled_inputs
|
||||||
import io
|
import io
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
import inspect
|
import inspect
|
||||||
|
from torch.testing import FileCheck
|
||||||
|
|
||||||
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
|
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
|
|
@ -422,6 +423,55 @@ class TestLiteScriptModule(TestCase):
|
||||||
# torch::jit::JITException to extend c10::Error.
|
# torch::jit::JITException to extend c10::Error.
|
||||||
self.assertTrue('self.val and val are same' in error_message)
|
self.assertTrue('self.val and val are same' in error_message)
|
||||||
|
|
||||||
|
def test_stacktrace_interface_call(self):
|
||||||
|
@torch.jit.interface
|
||||||
|
class Forward(torch.nn.Module):
|
||||||
|
def forward(self, x) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forwardError(self, x) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class B(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forwardError(self, x):
|
||||||
|
return self.call() + x
|
||||||
|
|
||||||
|
def call(self):
|
||||||
|
return torch.ones(-1)
|
||||||
|
|
||||||
|
class A(torch.nn.Module):
|
||||||
|
b : Forward
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.b = B()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
self.b.forward(torch.ones(1))
|
||||||
|
self.b.forwardError(torch.ones(1))
|
||||||
|
|
||||||
|
a = torch.jit.script(A())
|
||||||
|
torch._C._enable_mobile_interface_call_export()
|
||||||
|
buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
|
||||||
|
buffer.seek(0)
|
||||||
|
mobile_module = _load_for_lite_interpreter(buffer)
|
||||||
|
try:
|
||||||
|
mobile_module()
|
||||||
|
self.assertTrue(False)
|
||||||
|
except RuntimeError as exp:
|
||||||
|
FileCheck().check("Trying to create tensor with negative dimension") \
|
||||||
|
.check("Traceback of TorchScript") \
|
||||||
|
.check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
|
||||||
|
.check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
|
||||||
|
.check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
|
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -193,8 +193,8 @@ const std::shared_ptr<Code> Function::get_code() const {
|
||||||
return code_;
|
return code_;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Function::getExceptionDebugHandle() const {
|
const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
|
||||||
return getInterpretersExceptionDebugHandle();
|
return getInterpretersExceptionDebugHandles();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class TORCH_API Function : public torch::jit::Function {
|
||||||
// Returns the debug handle corresponding to where the execution
|
// Returns the debug handle corresponding to where the execution
|
||||||
// is halted due to exception.
|
// is halted due to exception.
|
||||||
// If no corresponding debug handle is found then -1 is returned.
|
// If no corresponding debug handle is found then -1 is returned.
|
||||||
int64_t getExceptionDebugHandle() const;
|
const std::vector<int64_t>& getExceptionDebugHandles() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
c10::QualifiedName name_;
|
c10::QualifiedName name_;
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ InterpreterState::InterpreterState(const Code& code) {
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
static thread_local DebugHandle exception_debug_handle_{-1};
|
static thread_local std::vector<DebugHandle> exception_debug_handles_;
|
||||||
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
||||||
auto userObj = c10::ivalue::Object::create(
|
auto userObj = c10::ivalue::Object::create(
|
||||||
c10::StrongTypePtr(type->compilation_unit(), type),
|
c10::StrongTypePtr(type->compilation_unit(), type),
|
||||||
|
|
@ -45,8 +45,8 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
|
||||||
|
|
||||||
using namespace at;
|
using namespace at;
|
||||||
|
|
||||||
int64_t getInterpretersExceptionDebugHandle() {
|
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
|
||||||
return exception_debug_handle_;
|
return exception_debug_handles_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void InterpreterState::enterFrame(const Code& code) {
|
void InterpreterState::enterFrame(const Code& code) {
|
||||||
|
|
@ -60,12 +60,18 @@ void InterpreterState::leaveFrame() {
|
||||||
frames_.pop_back();
|
frames_.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
void InterpreterState::saveExceptionDebugHandle() {
|
void InterpreterState::saveExceptionDebugHandles() {
|
||||||
const auto& frame = frames_.back();
|
std::vector<DebugHandle> exception_debug_handles;
|
||||||
if (auto handle = frame.getDebugHandle()) {
|
for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
|
||||||
exception_debug_handle_ = *handle;
|
size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
|
||||||
|
if (auto handle = frame->getDebugHandle(pc)) {
|
||||||
|
exception_debug_handles.push_back(*handle);
|
||||||
|
} else {
|
||||||
|
exception_debug_handles.push_back(-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
exception_debug_handles_ = std::move(exception_debug_handles);
|
||||||
|
}
|
||||||
|
|
||||||
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
|
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
|
||||||
bool newFrame =
|
bool newFrame =
|
||||||
|
|
@ -142,8 +148,7 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
->getMethod(code.constants_[inst.X].toStringRef());
|
->getMethod(code.constants_[inst.X].toStringRef());
|
||||||
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
||||||
method.name(), debug_handle, stack);
|
method.name(), debug_handle, stack);
|
||||||
method.run(stack);
|
callFunction(method, stack);
|
||||||
frame.step();
|
|
||||||
} break;
|
} break;
|
||||||
case LOAD:
|
case LOAD:
|
||||||
stack.emplace_back(reg(inst.X));
|
stack.emplace_back(reg(inst.X));
|
||||||
|
|
@ -285,15 +290,15 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
}
|
}
|
||||||
// This exception must be caught first as it derived from c10::Error
|
// This exception must be caught first as it derived from c10::Error
|
||||||
} catch (c10::BackendRuntimeException& e) {
|
} catch (c10::BackendRuntimeException& e) {
|
||||||
saveExceptionDebugHandle();
|
saveExceptionDebugHandles();
|
||||||
TORCH_RETHROW(e);
|
TORCH_RETHROW(e);
|
||||||
} catch (c10::Error& error) {
|
} catch (c10::Error& error) {
|
||||||
// Reason for catching and rethrowing the error is so that we can
|
// Reason for catching and rethrowing the error is so that we can
|
||||||
// set the exception pc that is queried later
|
// set the exception pc that is queried later
|
||||||
saveExceptionDebugHandle();
|
saveExceptionDebugHandles();
|
||||||
TORCH_RETHROW(error);
|
TORCH_RETHROW(error);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
saveExceptionDebugHandle();
|
saveExceptionDebugHandles();
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
// for (auto val : stack) {
|
// for (auto val : stack) {
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ struct InterpreterState {
|
||||||
private:
|
private:
|
||||||
void enterFrame(const Code&);
|
void enterFrame(const Code&);
|
||||||
void leaveFrame();
|
void leaveFrame();
|
||||||
void saveExceptionDebugHandle();
|
void saveExceptionDebugHandles();
|
||||||
void callFunction(torch::jit::Function& f, Stack& stack);
|
void callFunction(torch::jit::Function& f, Stack& stack);
|
||||||
|
|
||||||
c10::IValue& reg(size_t reg);
|
c10::IValue& reg(size_t reg);
|
||||||
|
|
@ -24,14 +24,7 @@ struct InterpreterState {
|
||||||
std::vector<Frame> frames_;
|
std::vector<Frame> frames_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Interpreter executes instruction in a loop one by one
|
const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles();
|
||||||
// from a list of instructions. PC is a program counter pointer
|
|
||||||
// pointing to the current instruction being executed.
|
|
||||||
// This function returns the current PC.
|
|
||||||
// Note that this is set only when exception occurs.
|
|
||||||
// since this is a thread local variable and setting it for
|
|
||||||
// every instruction will add overhead of thread local variable access.
|
|
||||||
DebugHandle getInterpretersExceptionDebugHandle();
|
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -204,7 +204,7 @@ void Method::run(Stack& stack) const {
|
||||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||||
if (error_message.empty()) {
|
if (error_message.empty()) {
|
||||||
error_message = owner_->getDebugTable().getSourceDebugString(
|
error_message = owner_->getDebugTable().getSourceDebugString(
|
||||||
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
|
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
@ -226,7 +226,9 @@ void Method::run(Stack& stack) const {
|
||||||
// This exception must be caught first as it derived from c10::Error
|
// This exception must be caught first as it derived from c10::Error
|
||||||
} catch (c10::BackendRuntimeException& e) {
|
} catch (c10::BackendRuntimeException& e) {
|
||||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||||
e.pushDebugHandle(function_->getExceptionDebugHandle());
|
for (auto handle : function_->getExceptionDebugHandles()) {
|
||||||
|
e.pushDebugHandle(handle);
|
||||||
|
}
|
||||||
// symbolicate all handles
|
// symbolicate all handles
|
||||||
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
||||||
e.getDebugHandles(), getTopModuleTypeName(*owner_));
|
e.getDebugHandles(), getTopModuleTypeName(*owner_));
|
||||||
|
|
@ -237,7 +239,7 @@ void Method::run(Stack& stack) const {
|
||||||
} catch (c10::Error& error) {
|
} catch (c10::Error& error) {
|
||||||
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
|
||||||
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
auto debug_string = owner_->getDebugTable().getSourceDebugString(
|
||||||
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
|
function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
|
||||||
error.add_context(debug_string);
|
error.add_context(debug_string);
|
||||||
#endif
|
#endif
|
||||||
error_message = error.what();
|
error_message = error.what();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user