mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25187 The bytecode export flow: dump the bytecode format for the light weighted interpreter. * The bytecode is generated without input spec optimization. It would be more generic (input independent) with no obvious performance degradation (to be tested). * Main API: torch::jit::script::Module::save(filename, extra_files, bool *bytecode_format* = false). * Both bytecode and module object are exported in pickle format. * The module object (in data.pkl) is the same as the original JIT model. * The serializer is dependent on pickle only (no protobuf or Json). * The major functionality is forked in ScriptModuleSerializer2::serialize(). * The test loader is test_bc_export.cpp. * Simple APIs are added in Code and its implementation to get necessary information (instructions, operators and constants). * Since there's no dependency on graph/node, GetAttr is promoted from an operator to first-class instruction (https://github.com/pytorch/pytorch/pull/25151) . * Some definitions (instructions, writeArchive, etc) that are shared by full JIT and bytecode are pulled out of the local namespace (https://github.com/pytorch/pytorch/pull/25148). The output layout looks like: * folders of methods. * In each method folder (for example, forward/): * bytecode.pkl: instructions and operators * constants{.pkl,/}: constant list in constants.pkl. If there are tensors in constants, the binary tensor files in constants/ folder. * data{.pkl,/}: the module object, with binary tensor files in data/ folder. The same as in torchscript. Test Plan: Imported from OSS Differential Revision: D17076411 fbshipit-source-id: 46eb298e7320d1e585b0101effc0fcfd09219046
121 lines
3.3 KiB
C++
121 lines
3.3 KiB
C++
#include "interpreter.h"
|
|
#include <torch/csrc/jit/mobile/function.h>
|
|
#include <aten/src/ATen/core/operator_name.h>
|
|
|
|
namespace torch{
|
|
namespace jit{
|
|
// Declaration only: no definition of toString appears in this file.
// InterpreterState::run() below passes the returned C string to
// TORCH_CHECK / AT_ERROR as the opcode's text in its error messages.
char const * toString(OpCode op);
|
|
namespace mobile {
|
|
// Constructs interpreter state for `code`.
//
// `code` is taken by value and moved into code_, so the caller's reference
// is transferred rather than copied (no extra shared_ptr refcount
// increment/decrement). registers_ is then resized to hold
// code_->register_size_ entries.
InterpreterState::InterpreterState(std::shared_ptr<Code> code)
    : code_(std::move(code)) {
  // Must read through code_ here: `code` has been moved from.
  registers_.resize(code_->register_size_);
}
|
|
|
|
//InterpreterState::InterpreterState(Function* function)
|
|
// : function_(function) {
|
|
// registers_.resize(function->register_size());
|
|
//}
|
|
|
|
// Executes the instructions in code_ one at a time, starting at index 0 and
// using `stack` as the operand stack, until a RET instruction is reached.
//
// Returns false when RET executes; no path in this function returns true.
// Reports an error (via TORCH_CHECK / AT_ERROR) when:
//   - the program counter leaves the instruction list,
//   - isOpSupportedInMobile() rejects the opcode, or
//   - the opcode has no case in the switch below.
bool InterpreterState::run(Stack& stack) {
  size_t pc = 0;
  while (true) {
    // Bytecode without a RET, or with a bad jump offset, would otherwise
    // index past the end of the instruction list.
    TORCH_CHECK(
        pc < code_->instructions_.size(),
        "Program counter ", pc, " is out of range; the bytecode has ",
        code_->instructions_.size(), " instructions.");
    const Instruction inst = code_->instructions_[pc];
    TORCH_CHECK(isOpSupportedInMobile(inst.op), toString(inst.op),
                " is not supported in mobile module.");
    switch (inst.op) {
      case OP: {
        // Call operator number X through the dispatcher, handing it the
        // operand stack by pointer.
        c10::Dispatcher::singleton().callBoxed(
            *code_->operators_[inst.X], &stack);
        ++pc;
      } break;
      case LOAD:
        // Push a copy of register X.
        stack.emplace_back(reg(inst.X));
        ++pc;
        break;
      case MOVE:
        // Push register X by moving out of it.
        stack.emplace_back(std::move(reg(inst.X)));
        ++pc;
        break;
      case STORE:
        // Pop the top of the stack into register X.
        reg(inst.X) = pop(stack);
        ++pc;
        break;
      case STOREN:
        // Pop N values into registers X .. X+N-1; the top of the stack goes
        // to the highest-numbered register.
        for (size_t i = inst.N; i > 0; --i) {
          reg(inst.X + i - 1) = pop(stack);
        }
        ++pc;
        break;
      case DROP:
        // Discard the top of the stack.
        pop(stack);
        ++pc;
        break;
      case DROPR:
        // Reset register X to a default-constructed IValue.
        reg(inst.X) = IValue();
        ++pc;
        break;
      case LOADC:
        // Push constant number X.
        stack.emplace_back(code_->constants_[inst.X]);
        ++pc;
        break;
      case GET_ATTR: {
        // Pop an object and push the value held in its slot X.
        auto userObj = pop(stack).toObject();
        auto value = userObj->getSlot(inst.X);
        push(stack, std::move(value));
        ++pc;
      } break;
      case SET_ATTR: {
        // The value is on top of the stack, the object just beneath it.
        auto v = pop(stack);
        auto userObj = pop(stack).toObject();
        userObj->setSlot(inst.X, std::move(v));
        ++pc;
      } break;
      case JF:
        // Pop a bool: fall through to the next instruction when it is true,
        // otherwise jump by X.
        pc += (pop(stack).toBool()) ? 1 : inst.X;
        break;
      case JMP:
        // Unconditional relative jump by X.
        pc += inst.X;
        break;
      case LOOP: {
        // stack: iteration_count, max_iter, cond, loop_carried_deps...
        auto frame = stack.end() - (inst.N + 1);
        int64_t trip_count = frame[0].toInt();
        int64_t max_trip_count = frame[1].toInt();
        bool cond = frame[2].toBool();
        if (trip_count < max_trip_count && cond) {
          // Continue: overwrite the cond slot with the current trip count,
          // bump the stored counter, and proceed to the next instruction.
          frame[2] = trip_count;
          frame[0] = trip_count + 1;
          ++pc;
        } else {
          // Exit: shift the loop-carried values down over the three
          // bookkeeping slots, drop the three leftover entries at the top of
          // the stack, and jump by X.
          size_t n_loop_carried = inst.N - 2;
          for (size_t i = 0; i < n_loop_carried; ++i) {
            frame[i] = std::move(frame[i + 3]);
          }
          drop(stack, 3); // iteration_count, max_iter, cond
          pc += inst.X;
        }
      } break;
      case RET:
        return false;
      default:
        AT_ERROR(toString(inst.op), " is invalid.");
    }
  }
  // Not reached: the loop above has no loop-level break. Kept so that every
  // syntactic path of this non-void function ends in a return.
  return false;
}
|
|
|
|
// Returns a mutable reference to a register, addressed backwards from the
// end of registers_: reg == 1 is the last element, reg == 2 the one before
// it, and so on. No range check is performed, so `reg` must be between 1 and
// the number of registers inclusive.
IValue& InterpreterState::reg(size_t reg) {
  const auto past_last = registers_.end();
  return *(past_last - reg);
}
|
|
|
|
} // namespace mobile
|
|
} // namespace jit
} // namespace torch
|