mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33717 Because of the special treatment of operator names in the lite interpreter, all operators used in the lite interpreter are still prefixed with "_". Add the necessary registrations for the MNIST model. All ops with autograd capability are included in torch_mobile_train. After rebasing, the selective build from D19649074 can be used to strip unused ops. Note that this diff is a feasibility test; training accuracy is not covered by the test. ghstack-source-id: 97780066 Test Plan: ``` buck run xplat/caffe2/fb/lite_trainer:lite_trainer -c pt.disable_gen_tracing=1 -c pt.static_dispatch=0 -- --model=/path/MnistModel.bc ``` {F227898221} Reviewed By: dreiss Differential Revision: D19743201 fbshipit-source-id: cacadd76f3729faa0018d147a69466bbf54312fd
161 lines
4.5 KiB
C++
161 lines
4.5 KiB
C++
#include "interpreter.h"
|
|
#include <torch/csrc/jit/mobile/function.h>
|
|
#include <ATen/core/operator_name.h>
|
|
#include <torch/csrc/jit/runtime/vararg_functions.h>
|
|
|
|
#if defined(PYTORCH_MOBILE_OPERATOR_OBSERVER)
|
|
#include <torch/csrc/autograd/record_function.h>
|
|
#include <torch/csrc/jit/mobile/observer.h>
|
|
#endif
|
|
|
|
namespace torch{
|
|
namespace jit{
|
|
char const * toString(OpCode op);
|
|
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
|
namespace mobile {
|
|
// Takes shared ownership of the compiled function `code` and sizes the
// register file to code_->register_size_ default-constructed entries.
InterpreterState::InterpreterState(std::shared_ptr<Code> code)
    : code_(std::move(code)) {
  // `code` has been moved into code_ at this point, so read through code_.
  const auto register_count = code_->register_size_;
  registers_.resize(register_count);
}
|
|
|
|
// Interprets the bytecode in code_->instructions_ against `stack`, starting at
// instruction 0 and stopping when a RET instruction executes.
//
// Each Instruction carries an opcode (`op`) and two operands (`X`, `N`) whose
// meaning depends on the opcode; see the per-case comments below. Registers
// are accessed through reg(), constants through code_->constants_, types
// through code_->types_, and operators through code_->operators_.
//
// @param stack  Operand stack that instructions and invoked operators read
//               from and write to.
// @return false once RET is executed.
// Errors (reported via AT_ERROR): an opcode with no case below, or a program
// counter that runs past the last instruction without reaching RET.
bool InterpreterState::run(Stack& stack) {
  const auto& instructions = code_->instructions_;
  size_t pc = 0;
  while (true) {
    // Guard against malformed bytecode (e.g. a truncated or corrupt model
    // file, or a jump offset that leaves the program): without this check
    // `pc` could index past the end of the instruction stream.
    if (pc >= instructions.size()) {
      AT_ERROR(
          "Program counter ",
          pc,
          " is out of range: the bytecode has ",
          instructions.size(),
          " instructions and no RET was reached.");
    }
    Instruction inst = instructions[pc];

    switch (inst.op) {
      case OP: {
#if defined(PYTORCH_MOBILE_OPERATOR_OBSERVER)
        // If a MobileDebugInfo is attached to this thread, tell it which
        // instruction is running, and record the call under the operator's
        // name via RECORD_FUNCTION.
        if (auto debug_info = at::getThreadLocalDebugInfo()) {
          if (auto* mobile_debug_info =
                  dynamic_cast<MobileDebugInfo*>(debug_info.get())) {
            mobile_debug_info->setOpIdx(pc);
          }
        }
        RECORD_FUNCTION(code_->op_names_[inst.X].name, stack);
#endif
        // Invoke operator X on the stack.
        code_->operators_[inst.X](stack);
        ++pc;
      } break;
      case OPN: {
        // Push N (presumably the input count of a vararg operator), then
        // invoke operator X on the stack.
        stack.push_back(inst.N);
        code_->operators_[inst.X](stack);
        ++pc;
      } break;
      case LOAD:
        // Copy register X onto the stack.
        stack.emplace_back(reg(inst.X));
        ++pc;
        break;
      case MOVE:
        // Move register X onto the stack, leaving the register moved-from.
        stack.emplace_back(std::move(reg(inst.X)));
        ++pc;
        break;
      case STORE:
        // Pop the top of the stack into register X.
        reg(inst.X) = pop(stack);
        ++pc;
        break;
      case STOREN:
        // Pop N values into registers X .. X+N-1; the value on top of the
        // stack lands in register X+N-1.
        for (size_t i = inst.N; i > 0; --i) {
          reg(inst.X + i - 1) = pop(stack);
        }
        ++pc;
        break;
      case DROP:
        // Discard the top of the stack.
        pop(stack);
        ++pc;
        break;
      case DROPR:
        // Reset register X to a default-constructed IValue.
        reg(inst.X) = IValue();
        ++pc;
        break;
      case LOADC:
        // Push a copy of constant X.
        stack.emplace_back(code_->constants_[inst.X]);
        ++pc;
        break;
      case GET_ATTR: {
        // Pop an object and push the value held in its slot X.
        auto userObj = pop(stack).toObject();
        auto value = userObj->getSlot(inst.X);
        push(stack, std::move(value));
        ++pc;
      } break;
      case SET_ATTR: {
        // Pop the value, then the object, and store the value in slot X.
        auto v = pop(stack);
        auto userObj = pop(stack).toObject();
        // Mobile only: the number of slots is not known up front, so grow the
        // type's attribute list with None-typed placeholders (named after
        // their index) until slot X exists.
        while (userObj->type()->numAttributes() <= inst.X) {
          std::stringstream ss;
          ss << userObj->type()->numAttributes();
          userObj->type()->addAttribute(ss.str(), c10::NoneType::create());
        }
        userObj->setSlot(inst.X, std::move(v));
        ++pc;
      } break;
      case JF:
        // Pop a bool: if true, continue with the next instruction; otherwise
        // jump by X instructions relative to this one.
        pc += (pop(stack).toBool()) ? 1 : inst.X;
        break;
      case JMP:
        // Unconditional jump by X instructions relative to this one.
        pc += inst.X;
        break;
      case LOOP: {
        // The top N+1 stack values form the loop frame:
        //   iteration_count, max_iter, cond, loop_carried_deps...
        auto frame = stack.end() - (inst.N + 1);
        int64_t trip_count = frame[0].toInt();
        int64_t max_trip_count = frame[1].toInt();
        bool cond = frame[2].toBool();
        if (trip_count < max_trip_count && cond) {
          // Run another iteration: the cond slot receives the current trip
          // count, the counter is advanced, and execution enters the body.
          frame[2] = trip_count;
          frame[0] = trip_count + 1;
          ++pc;
        } else {
          // Exit: shift the N-2 loop-carried values down over the three
          // header slots, then drop three entries from the top so only the
          // loop-carried values remain, and jump by X.
          size_t n_loop_carried = inst.N - 2;
          for (size_t i = 0; i < n_loop_carried; ++i) {
            frame[i] = std::move(frame[i + 3]);
          }
          drop(stack, 3); // iteration_count, max_iter, cond
          pc += inst.X;
        }
      } break;
      case RET:
        return false;
      case LIST_CONSTRUCT: {
        // Build a list typed by types_[X] via listConstruct, passing N
        // (presumably the element count).
        auto type = code_->types_[inst.X]->expect<at::ListType>();
        listConstruct(stack, type, inst.N);
        ++pc;
      } break;
      case TUPLE_CONSTRUCT: {
        // Build a tuple via tupleConstruct, passing X (presumably the
        // element count).
        tupleConstruct(stack, inst.X);
        ++pc;
      } break;
      case WARN: {
        // Discard the top value (presumably the warning's stacklevel), then
        // pop the message string and emit it as a warning.
        drop(stack, 1);
        AT_WARN(pop(stack).toStringRef());
        ++pc;
      } break;
      default:
        AT_ERROR(toString(inst.op), " is invalid.");
    }
  }
  return false;
}
|
|
|
|
// Returns a reference to register `reg`.
//
// Registers are numbered from the back of registers_ and start at 1:
// reg(1) is the last entry and reg(registers_.size()) is the first.
// No bounds checking is performed; callers must pass a value in
// [1, registers_.size()].
IValue& InterpreterState::reg(size_t reg) {
  const auto index_from_front = registers_.size() - reg;
  return registers_[index_from_front];
}
|
|
|
|
} // namespace mobile
|
|
} // namespace jit
|
|
} // namespace torch
|