Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30734

What are specialized lists? The IValues that hold List[int], List[Tensor], and List[AnythingElse] are different C++ types: e.g. List[int] holds a std::vector<int>, while List[AnythingElse] holds a std::vector<IValue>.

Why do we have specialized lists? When we first created the JIT we needed to bind the ATen C++ API, which takes std::vector<int> and std::vector<Tensor> as inputs. The easiest way to match this API was to make our IValues contain these same types, so conversion was just unwrapping the IValue: very easy and cheap.

What is the problem with specialized lists? We end up with significant special cases throughout the compiler. Other types, like Dict, are not specialized, so in the Pickler, for instance, there is a single piece of logic to handle their serialization; for Lists, we end up with multiple cases. Furthermore, specialization doesn't match Python, leading to problems along translation boundaries. Our pickle serialization is slightly different from Python's, so it is harder to load objects from our IValue serialization as Python values. Specialized lists also make it harder to provide an easy-to-use user API. We'd like to match pybind11 for C++ bindings to TorchScript. This would entail having a single torch::List class (untemplated) that can be used to construct inputs, which is made much harder if the underlying IValue needs to be different depending on the type inside the list. The ideal case would be to have a constructor like

```
template<typename T>
List(std::vector<T> foo);
```

which would set up the type tags correctly based on the type T, without the need for passing tags.

Do specialized lists improve perf? Not in a way we have been able to measure. Our major concern initially was having to translate a std::vector<IValue> to std::vector<int> to call ATen functions. This was especially a concern for aten::_convolution, which takes a number of mostly-constant lists of integers. However, when we measure the effect of actually having to do this conversion for an aten::_convolution, it does not take measurable time (benchmark results below). This is true even if you use a trivial convolution (e.g. 1x1x1) and comment out the actual convolution code.

What are the issues with removing them? This PR removes list specialization but keeps the serialization format and IValue APIs almost exactly the same. The only visible change is that toTensorListRef and family have turned into toTensorVector, because they now return by value a copy of the list as a vector. Further PRs can then clean up the complexity issues that arose from specialization. This will likely involve removing the isTensorList/isIntList functions and refactoring the code that used them to work generically. At some point we will also change serialization to no longer write specialized lists in the pickle binary. This is forward incompatible, so it will go in its own PR.
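To make the visible API change concrete, here is a minimal sketch (not part of this PR's diff; it assumes the c10::List/IValue constructors and the toTensorVector accessor behave as described above, and the tensor shapes/values are arbitrary placeholders):

```
#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>

int main() {
  // Build an IValue holding a List[Tensor].
  c10::List<at::Tensor> tensors({at::ones({2, 2}), at::zeros({2, 2})});
  c10::IValue iv(std::move(tensors));

  // Before this change: const auto& ref = iv.toTensorListRef();
  //   returned a reference to the specialized std::vector<at::Tensor>.
  // After this change, the generic list is copied out by value:
  std::vector<at::Tensor> copied = iv.toTensorVector();
  return copied.size() == 2 ? 0 : 1;
}
```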
Benchmark:
```
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, kernel_size=1)
        self.conv2 = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, x):
        for i in range(10):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
        return x

model = MnistNet()
x = torch.rand(1, 1, 1, 1)
r = torch.jit.trace(model, x)
r(x)
r(x)
r(x)
r(x)
print(torch.jit.last_executed_optimized_graph())
while True:
    b = time.time()
    for i in range(100):
        r(x)
    e = time.time()
    print(e - b)
```
Results (no observable difference):
```
Before (actual conv)
0.13251137733459473
0.13260436058044434
0.13276338577270508
0.1327497959136963
0.13250041007995605
0.13270330429077148
0.13290190696716309
0.13265132904052734
0.13274288177490234
0.1326758861541748
0.13253355026245117
0.13254785537719727
0.13260746002197266
0.13285017013549805
0.13264012336730957
0.132490873336792
0.13280034065246582
0.13243484497070312
0.1325232982635498
0.1326127052307129
0.13264131546020508
0.13274383544921875
0.13298296928405762
0.1326909065246582
-------------------
After (actual conv)
0.13127517700195312
0.13150334358215332
0.13092470169067383
0.13102364540100098
0.13134360313415527
0.13155555725097656
0.13314104080200195
0.13151955604553223
0.13160037994384766
0.1315293312072754
0.13137340545654297
0.13148093223571777
0.131455659866333
0.1327371597290039
0.13134026527404785
0.13152337074279785
0.13151192665100098
0.13165974617004395
0.13403725624084473
0.13251852989196777
0.13135504722595215
0.1315624713897705
0.1317615509033203
0.1314380168914795
0.13157200813293457
--------------------
The following replaces the convolution operator with a no-op, to show that even if the conv op were made faster, we still would not see a difference:

Before (fake conv)
0.0069539546966552734
0.0069522857666015625
0.007120847702026367
0.007344722747802734
0.007689952850341797
0.007932662963867188
0.00761723518371582
0.007501363754272461
0.007532835006713867
0.007141828536987305
0.007174253463745117
0.007114410400390625
0.007071495056152344
------------------
After (fake conv)
0.007458209991455078
0.007337093353271484
0.007268190383911133
0.007313251495361328
0.007306575775146484
0.007468700408935547
0.0073091983795166016
0.007308483123779297
0.007538318634033203
0.007356882095336914
0.007464170455932617
0.007372140884399414
```
Test Plan: Imported from OSS

Differential Revision: D18814702

Pulled By: zdevito

fbshipit-source-id: 0371c73b63068fdc12f24b801371ea90f23531a6
150 lines
4.1 KiB
C++
#include "interpreter.h"
#include <torch/csrc/jit/mobile/function.h>
#include <ATen/core/operator_name.h>

#if defined(PYTORCH_MOBILE_OBSERVER)
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/mobile/observer.h>
#endif

namespace torch {
namespace jit {
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
InterpreterState::InterpreterState(std::shared_ptr<Code> code) : code_(code) {
  registers_.resize(code_->register_size_);
}

namespace {
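// Builds a c10::List<dtype> from the top num_inputs values on the stack,
// dropping them and pushing the resulting list back in their place.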
template <typename dtype> // int64_t, bool, double
void listConstruct(Stack& stack, int num_inputs) {
  auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
  c10::List<dtype> vals(
      fmap(inputs, [](const IValue& v) { return v.to<dtype>(); }));
  drop(stack, num_inputs);
  push(stack, std::move(vals));
}
}

bool InterpreterState::run(Stack& stack) {
  size_t pc = 0;
  while (true) {
    Instruction inst = code_->instructions_[pc];

    // std::cout << "RUNNING " << pc << " " << code_->instructions_[pc];
    // if (inst.op == OP) {
    //   std::cout << ", " << code_->op_names_[inst.X].name << "." <<
    //       code_->op_names_[inst.X].overload_name;
    // }
    // std::cout << std::endl;
    // for (auto val : stack) {
    //   if (val.isTensor()) {
    //     std::cout << val.toTensor().sizes() << std::endl;
    //   } else {
    //     std::cout << val << std::endl;
    //   }
    // }
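    // Dispatch on the current opcode. Straight-line instructions advance pc
    // by one; control-flow instructions (JF, JMP, LOOP) adjust pc by inst.X.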
    switch (inst.op) {
      case OP: {
#if defined(PYTORCH_MOBILE_OBSERVER)
        if (auto debug_info = at::getThreadLocalDebugInfo()) {
          if (auto* mobile_debug_info = dynamic_cast<MobileDebugInfo*>(
                  debug_info.get())) {
            mobile_debug_info->setOpIdx(pc);
          }
        }
        RECORD_FUNCTION(code_->op_names_[inst.X].name, stack);
#endif

        c10::Dispatcher::singleton().callBoxed(
            *code_->operators_[inst.X], &stack);
        ++pc;
      } break;
      case OPN: {
        code_->vararg_operators_[inst.X](inst.N, stack);
        ++pc;
      } break;
      case LOAD:
        stack.emplace_back(reg(inst.X));
        ++pc;
        break;
      case MOVE:
        stack.emplace_back(std::move(reg(inst.X)));
        ++pc;
        break;
      case STORE:
        reg(inst.X) = pop(stack);
        ++pc;
        break;
      case STOREN:
        for (size_t i = inst.N; i > 0; --i) {
          reg(inst.X + i - 1) = pop(stack);
        }
        ++pc;
        break;
      case DROP:
        pop(stack);
        ++pc;
        break;
      case DROPR:
        reg(inst.X) = IValue();
        ++pc;
        break;
      case LOADC:
        stack.emplace_back(code_->constants_[inst.X]);
        ++pc;
        break;
      case GET_ATTR: {
        auto userObj = pop(stack).toObject();
        auto value = userObj->getSlot(inst.X);
        push(stack, std::move(value));
        ++pc;
      } break;
      case SET_ATTR: {
        auto v = pop(stack);
        auto userObj = pop(stack).toObject();
        userObj->setSlot(inst.X, std::move(v));
        ++pc;
      } break;
      case JF:
        pc += (pop(stack).toBool()) ? 1 : inst.X;
        break;
      case JMP:
        pc += inst.X;
        break;
      case LOOP: {
        // stack: iteration_count, max_iter, cond, loop_carried_deps...
        auto frame = stack.end() - (inst.N + 1);
        int64_t trip_count = frame[0].toInt();
        int64_t max_trip_count = frame[1].toInt();
        bool cond = frame[2].toBool();
        if (trip_count < max_trip_count && cond) {
          frame[2] = trip_count;
          frame[0] = trip_count + 1;
          ++pc;
        } else {
          size_t n_loop_carried = inst.N - 2;
          for (size_t i = 0; i < n_loop_carried; ++i) {
            frame[i] = std::move(frame[i + 3]);
          }
          drop(stack, 3); // iteration_count, max_iter, cond
          pc += inst.X;
        }
      } break;
      case RET:
        return false;
      default:
        AT_ERROR(toString(inst.op), " is invalid.");
    }
  }
  return false;
}

IValue& InterpreterState::reg(size_t reg) {
  return *(registers_.end() - reg);
}

} // namespace mobile
} // namespace jit
} // namespace torch