Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: The INTERFACE_CALL opcode for the mobile TorchScript interpreter contained an out-of-bounds read leading to memory corruption. This change adds an explicit check that the number of inputs (inst.N) consumed when handling the INTERFACE_CALL opcode is valid and within the bounds of the stack.

Test Plan: contbuild + OSS signals

Differential Revision: D49739450

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110301
Approved by: https://github.com/dbort
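For context, below is a condensed excerpt of the bounds checking in the INTERFACE_CALL case of InterpreterState::run(), as it appears in the file that follows; the inst.N guard is the check described above, and the inline comments here are editorial rather than part of the source:

    case INTERFACE_CALL: {
      // Validate the constant-pool index that names the interface method.
      if (inst.X < 0 ||
          static_cast<size_t>(inst.X) >= code.constants_.size()) {
        TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
      }
      // Validate the declared input count so that peek(stack, 0, inst.N)
      // cannot read past the live portion of the stack.
      if (inst.N == 0 || inst.N > stack.size()) {
        TORCH_CHECK(
            false,
            "INTERFACE_CALL N=", inst.N,
            " not in range [1, ", stack.size(), "]");
      }
      // ... resolve the method and dispatch via callFunction(method, stack)
    } break;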
404 lines | 13 KiB | C++
#include <torch/csrc/jit/mobile/interpreter.h>

#include <ATen/core/class_type.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/operator_name.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace torch {
namespace jit {
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
InterpreterState::InterpreterState(const Code& code) {
  enterFrame(code);
}

namespace {
static thread_local std::vector<DebugHandle> exception_debug_handles_;
void createObject(Stack& stack, const at::ClassTypePtr& type) {
  auto userObj = c10::ivalue::Object::create(
      c10::StrongTypePtr(type->compilation_unit(), type),
      type->numAttributes());
  push(stack, std::move(userObj));
}

void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  at::TypePtr ty = pop(stack).type<c10::DynamicType>();
  for (const at::TypePtr& candidate : types) {
    if (ty->isSubtypeOf(*candidate)) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}
} // namespace

using namespace at;

const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles() {
  return exception_debug_handles_;
}

void InterpreterState::enterFrame(const Code& code) {
  frames_.emplace_back(code);
  registers_.resize(registers_.size() + code.register_size_);
}

void InterpreterState::leaveFrame() {
  registers_.resize(
      registers_.size() - frames_.back().getCode().register_size_);
  frames_.pop_back();
}

void InterpreterState::saveExceptionDebugHandles() {
  std::vector<DebugHandle> exception_debug_handles;
  for (auto frame = frames_.crbegin(); frame != frames_.crend(); frame++) {
    size_t pc = frame->getPC() - (frame != frames_.crbegin() ? 1 : 0);
    if (auto handle = frame->getDebugHandle(pc)) {
      exception_debug_handles.push_back(*handle);
    } else {
      exception_debug_handles.push_back(-1);
    }
  }
  exception_debug_handles_ = std::move(exception_debug_handles);
}

void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
  bool newFrame =
      f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
  (frames_.rbegin() + (newFrame ? 1 : 0))->step();
}

bool InterpreterState::run(Stack& stack) {
  while (true) {
    try {
      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();
      // If no valid debug handle found then just log pc.
      // This is possible when we did not save debug handles

      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }

      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
      // if (inst.op == OP) {
      //   std::cout << ", " << code.op_names_[inst.X].name;
      //   if (!code.op_names_[inst.X].overload_name.empty()) {
      //     std::cout << "." << code.op_names_[inst.X].overload_name;
      //   }
      // }
      // std::cout << std::endl;

      // TODO(iliacher): remove the workaround after RecordFunction is in
      // Dispatcher
      // Check with iliacher if has been done.
      // Plus this is not safe as if you throw exception record function will be
      // left enabled. That is a TODO
      // NOTE: this recordFunction logic takes up ~2-3% of cpu cycles in some
      // workflows. do we need it and/or can we opt-out of
      // isRecordFunctionEnabled with a macro? if we delete it, things appear to
      // work just fine.
      bool prev_value = isRecordFunctionEnabled();
      if (!prev_value) {
        // enable only for the RecordFunction
        enableRecordFunction(true);
      }
      switch (inst.op) {
        case OP: {
          if (at::hasGlobalCallbacks()) {
            if (auto* mobile_debug_info = static_cast<MobileDebugInfo*>(
                    c10::ThreadLocalDebugInfo::get(
                        c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
              mobile_debug_info->setOpIdx(pc);
            }
          }
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case OPN: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.op_names_.size() ||
              static_cast<size_t>(inst.X) >= code.operators_.size()) {
            TORCH_CHECK(false, "Can't load op with index: ", inst.X);
          }
          stack.emplace_back(inst.N);
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              code.op_names_[inst.X].name, debug_handle, stack);
          code.operators_[inst.X](stack);
          frame.step();
        } break;
        case CALL: {
          auto& function = *frame.getCode().functions_.at(inst.X);
          callFunction(function, stack);
        } break;
        case INTERFACE_CALL: {
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          if (inst.N == 0 || inst.N > stack.size()) {
            TORCH_CHECK(
                false,
                "INTERFACE_CALL N=",
                inst.N,
                " not in range [1, ",
                stack.size(),
                "]");
          }
          torch::jit::Function& method =
              peek(stack, 0, inst.N)
                  .toObject()
                  ->type()
                  ->getMethod(code.constants_[inst.X].toStringRef());
          RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
              method.name(), debug_handle, stack);
          callFunction(method, stack);
        } break;
        case LOAD:
          stack.emplace_back(reg(inst.X));
          frame.step();
          break;
        case MOVE:
          stack.emplace_back(std::move(reg(inst.X)));
          frame.step();
          break;
        case STORE:
          reg(inst.X) = pop(stack);
          frame.step();
          break;
        case STOREN:
          for (size_t i = inst.N; i > 0; --i) {
            reg(inst.X + i - 1) = pop(stack);
          }
          frame.step();
          break;
        case DROP:
          pop(stack);
          frame.step();
          break;
        case DROPR:
          reg(inst.X) = IValue();
          frame.step();
          break;
        case LOADC:
          if (inst.X < 0 ||
              static_cast<size_t>(inst.X) >= code.constants_.size()) {
            TORCH_CHECK(false, "Can't load constant with index: ", inst.X);
          }
          stack.emplace_back(code.constants_[inst.X]);
          frame.step();
          break;
        case GET_ATTR: {
          auto userObj = pop(stack).toObject();
          auto value = userObj->getSlot(inst.X);
          push(stack, std::move(value));
          frame.step();
        } break;
        case SET_ATTR: {
          auto v = pop(stack);
          auto userObj = pop(stack).toObject();
          // Mobile only: since the number of slots is not known, resize the
          // numAttributes before setSlot.
          while (static_cast<int>(userObj->type()->numAttributes()) <= inst.X) {
            std::stringstream ss;
            ss << userObj->type()->numAttributes();
            userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
          }
          userObj->setSlot(inst.X, std::move(v));
          frame.step();
        } break;
        case JF:
          frame.jump(pop(stack).toBool() ? 1 : inst.X);
          break;
        case JMP:
          frame.jump(inst.X);
          break;
        case LOOP: {
          // stack: iteration_count, max_iter, cond, loop_carried_deps...
          auto sframe = stack.end() - (inst.N + 1);
          int64_t trip_count = sframe[0].toInt();
          int64_t max_trip_count = sframe[1].toInt();
          bool cond = sframe[2].toBool();
          if (trip_count < max_trip_count && cond) {
            sframe[2] = trip_count;
            sframe[0] = trip_count + 1;
            frame.step();
          } else {
            size_t n_loop_carried = inst.N - 2;
            for (const auto i : c10::irange(n_loop_carried)) {
              sframe[i] = std::move(sframe[i + 3]);
            }
            drop(stack, 3); // iteration_count, max_iter, cond
            frame.jump(inst.X);
          }
        } break;
        case RET:
          leaveFrame();
          if (!frames_.empty()) {
            continue;
          }
          return false;
        case LIST_CONSTRUCT: {
          listConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case LIST_UNPACK: {
          listUnpack(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_CONSTRUCT: {
          tupleConstruct(stack, inst.X);
          frame.step();
        } break;
        case TUPLE_SLICE: {
          tupleSlice(stack, inst.X, inst.X + inst.N);
          frame.step();
        } break;
        case TUPLE_INDEX: {
          tupleIndex(stack);
          frame.step();
        } break;
        case RAISE_EXCEPTION: {
          raiseExceptionWithMessage(stack);
          frame.step();
        } break;
        case __IS__: {
          is(stack);
          frame.step();
        } break;
        case UN_INITIALIZED: {
          unInitialized(stack);
          frame.step();
        } break;
        case __ISNOT__: {
          isNot(stack);
          frame.step();
        } break;
        case FORMAT: {
          format(stack, inst.X);
          frame.step();
        } break;
        case DEVICE: {
          device(stack);
          frame.step();
        } break;
        case DTYPE: {
          dtype(stack);
          frame.step();
        } break;
        case DIM: {
          dim(stack);
          frame.step();
        } break;
        case __NOT__: {
          _not(stack);
          frame.step();
        } break;
        case DICT_INDEX: {
          dictIndex(stack);
          frame.step();
        } break;
        case TO_LIST: {
          toList(stack);
          frame.step();
        } break;
        case NUM_TO_TENSOR: {
          numToTensorScalar(stack);
          frame.step();
        } break;
        case IS_CUDA: {
          isCuda(stack);
          frame.step();
        } break;
        case DICT_CONSTRUCT: {
          dictConstruct(stack, *code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case NAMED_TUPLE_CONSTRUCT: {
          namedTupleConstruct(stack, code.types_.at(inst.X), inst.N);
          frame.step();
        } break;
        case CREATE_OBJECT: {
          auto type = code.types_.at(inst.X)->expect<c10::ClassType>();
          createObject(stack, type);
          frame.step();
        } break;
        case ISINSTANCE: {
          at::ArrayRef<TypePtr> types(&code.types_.at(inst.X), inst.N);
          isinstance(stack, types);
          frame.step();
        } break;
        case WARN: {
          drop(stack, 1);
          // Note: Please don't move the pop(stack) code below into the
          // TORCH_WARN macro since TORCH_WARN fails to evaluate its arguments
          // when STRIP_ERROR_MESSAGES is defined (which happens for production
          // mobile builds). This will cause the stack to be in an inconsistent
          // state. It has previously resulted in a SEV (S22350).
          const auto& sref = stack.back().toStringRef();
          TORCH_WARN(sref);
          stack.pop_back();
          frame.step();
        } break;
        default:
          AT_ERROR(toString(inst.op), " is invalid.");
      }

      if (!prev_value) {
        enableRecordFunction(false);
      }
      // This exception must be caught first as it derived from c10::Error
    } catch (c10::BackendRuntimeException& e) {
      saveExceptionDebugHandles();
      TORCH_RETHROW(e);
    } catch (c10::Error& error) {
      // Reason for catching and rethrowing the error is so that we can
      // set the exception pc that is queried later
      saveExceptionDebugHandles();
      TORCH_RETHROW(error);
    } catch (...) {
      saveExceptionDebugHandles();
      throw;
    }
    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
  }
  return false;
}

IValue& InterpreterState::reg(size_t reg) {
  return *(registers_.end() - reg);
}

} // namespace mobile
} // namespace jit
} // namespace torch