pytorch/torch/csrc/jit/mobile/interpreter.cpp
Martin Yuan 01edb7450f [Lite Trainer] Add necessary registrations for MNIST model (#33717)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33717

Because of the special treatment of operator names in the lite interpreter, all the operators it uses are still prefixed with "_". This diff adds the necessary registrations for the MNIST model. All the ops with autograd capability are included in torch_mobile_train. After rebasing, the selective build from D19649074 can be used to strip the unused ops.
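
For illustration, a registration under a "_"-prefixed name might look like the sketch below (the op name, schema, and kernel are hypothetical, in the style of `torch::RegisterOperators`; this is not the literal code added in this diff):
```
#include <torch/script.h>

namespace {
// Hypothetical kernel, for illustration only.
at::Tensor my_identity(const at::Tensor& self) {
  return self;
}
// The "_" prefix matches the lite interpreter's operator-name lookup.
static auto registry = torch::RegisterOperators().op(
    "_aten::my_identity(Tensor self) -> Tensor", &my_identity);
} // namespace
```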

Note that this diff is a feasibility test; training accuracy is not covered by the test.
ghstack-source-id: 97780066

Test Plan:
```
buck run xplat/caffe2/fb/lite_trainer:lite_trainer -c pt.disable_gen_tracing=1 -c pt.static_dispatch=0 -- --model=/path/MnistModel.bc
```
{F227898221}

Reviewed By: dreiss

Differential Revision: D19743201

fbshipit-source-id: cacadd76f3729faa0018d147a69466bbf54312fd
2020-03-06 15:49:03 -08:00

#include "interpreter.h"
#include <torch/csrc/jit/mobile/function.h>
#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#if defined(PYTORCH_MOBILE_OPERATOR_OBSERVER)
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/mobile/observer.h>
#endif
namespace torch {
namespace jit {
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);

namespace mobile {
InterpreterState::InterpreterState(std::shared_ptr<Code> code)
    : code_(std::move(code)) {
  registers_.resize(code_->register_size_);
}
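
// Usage sketch (illustrative): given a mobile::Code loaded from a bytecode
// model, execution is driven as
//   InterpreterState state(code);
//   state.run(stack); // stack holds the inputs on entry, outputs on return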
bool InterpreterState::run(Stack& stack) {
  size_t pc = 0;
  while (true) {
    Instruction inst = code_->instructions_[pc];
    // std::cout << "RUNNING " << pc << " " << code_->instructions_[pc];
    // if (inst.op == OP) {
    //   std::cout << ", " << code_->op_names_[inst.X].name << "." <<
    //       code_->op_names_[inst.X].overload_name;
    // }
    // std::cout << std::endl;
    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
    switch (inst.op) {
      case OP: {
#if defined(PYTORCH_MOBILE_OPERATOR_OBSERVER)
        if (auto debug_info = at::getThreadLocalDebugInfo()) {
          if (auto* mobile_debug_info =
                  dynamic_cast<MobileDebugInfo*>(debug_info.get())) {
            mobile_debug_info->setOpIdx(pc);
          }
        }
        RECORD_FUNCTION(code_->op_names_[inst.X].name, stack);
#endif
        code_->operators_[inst.X](stack);
        ++pc;
      } break;
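      // OPN: variadic operator. Push the instruction's N (the number of
      // inputs) on top of the stack so the vararg kernel knows how many
      // values to consume.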
      case OPN: {
        stack.push_back(inst.N);
        code_->operators_[inst.X](stack);
        ++pc;
      } break;
      case LOAD:
        stack.emplace_back(reg(inst.X));
        ++pc;
        break;
      case MOVE:
        stack.emplace_back(std::move(reg(inst.X)));
        ++pc;
        break;
      case STORE:
        reg(inst.X) = pop(stack);
        ++pc;
        break;
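      // STOREN: pop N values into registers X..X+N-1, walking backwards so
      // the registers end up in the original (bottom-to-top) stack order.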
      case STOREN:
        for (size_t i = inst.N; i > 0; --i) {
          reg(inst.X + i - 1) = pop(stack);
        }
        ++pc;
        break;
      case DROP:
        pop(stack);
        ++pc;
        break;
      case DROPR:
        reg(inst.X) = IValue();
        ++pc;
        break;
      case LOADC:
        stack.emplace_back(code_->constants_[inst.X]);
        ++pc;
        break;
      case GET_ATTR: {
        auto userObj = pop(stack).toObject();
        auto value = userObj->getSlot(inst.X);
        push(stack, std::move(value));
        ++pc;
      } break;
      case SET_ATTR: {
        auto v = pop(stack);
        auto userObj = pop(stack).toObject();
        // Mobile only: the total number of slots is not known here, so grow
        // the object's type with placeholder attributes until slot inst.X
        // exists before calling setSlot.
        while (userObj->type()->numAttributes() <= inst.X) {
          std::stringstream ss;
          ss << userObj->type()->numAttributes();
          userObj->type()->addAttribute(ss.str(), c10::NoneType::create());
        }
        userObj->setSlot(inst.X, std::move(v));
        ++pc;
      } break;
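      // JF: conditional jump. Pop the condition; fall through to the next
      // instruction when true, jump X instructions forward when false.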
      case JF:
        pc += (pop(stack).toBool()) ? 1 : inst.X;
        break;
      case JMP:
        pc += inst.X;
        break;
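      // LOOP: re-executed once per iteration. While trip_count <
      // max_trip_count and cond holds, expose the current trip count to the
      // loop body and advance the counter; otherwise drop the three
      // loop-control values, shift the carried deps down, and jump X
      // instructions past the loop.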
      case LOOP: {
        // stack: iteration_count, max_iter, cond, loop_carried_deps...
        auto frame = stack.end() - (inst.N + 1);
        int64_t trip_count = frame[0].toInt();
        int64_t max_trip_count = frame[1].toInt();
        bool cond = frame[2].toBool();
        if (trip_count < max_trip_count && cond) {
          frame[2] = trip_count;
          frame[0] = trip_count + 1;
          ++pc;
        } else {
          size_t n_loop_carried = inst.N - 2;
          for (size_t i = 0; i < n_loop_carried; ++i) {
            frame[i] = std::move(frame[i + 3]);
          }
          drop(stack, 3); // iteration_count, max_iter, cond
          pc += inst.X;
        }
      } break;
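      // RET: end of the function. The mobile interpreter has no suspension
      // points, so returning false simply signals that execution finished.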
      case RET:
        return false;
      case LIST_CONSTRUCT: {
        auto type = code_->types_[inst.X]->expect<at::ListType>();
        listConstruct(stack, type, inst.N);
        ++pc;
      } break;
      case TUPLE_CONSTRUCT: {
        tupleConstruct(stack, inst.X);
        ++pc;
      } break;
      case WARN: {
        drop(stack, 1);
        AT_WARN(pop(stack).toStringRef());
        ++pc;
      } break;
      default:
        AT_ERROR(toString(inst.op), " is invalid.");
    }
  }
  return false;
}
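
// Registers are 1-indexed from the end of the register file: reg(1) is the
// last element of registers_.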
IValue& InterpreterState::reg(size_t reg) {
  return *(registers_.end() - reg);
}
} // namespace mobile
} // namespace jit
} // namespace torch