pytorch/torch/csrc/jit/runtime/static/impl.h
Hao Lu 8538a79bfe [jit][static] Basic executor (#43647)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43647

Nothing fancy, just a basic implementation of the graph executor without using a stack machine.

Reviewed By: bwasti

Differential Revision: D23208413

fbshipit-source-id: e483bb6ad7ba8591bbe1767e669654d82f42c356
2020-08-28 23:20:07 -07:00

52 lines
1.2 KiB
C++

#pragma once
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>
#ifdef FBCODE_CAFFE2
#include <folly/container/F14Map.h>
#endif
namespace torch {
namespace jit {
class TORCH_API StaticRuntime {
public:
explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
: graph_(std::move(g)) {}
explicit StaticRuntime(const torch::jit::Module& m);
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
#ifdef FBCODE_CAFFE2
using ConstantMap = folly::F14FastMap<Value*, IValue>;
#else
using ConstantMap = std::unordered_map<Value*, IValue>;
#endif
private:
torch::jit::Module module_;
std::shared_ptr<torch::jit::Graph> graph_;
// Static runtime states
// Constant table (including weights)
ConstantMap constant_table_;
// The nodes we need to run
std::vector<std::pair<Node*, Operation>> nodes_;
void getInputIValues(
Node* node,
const ConstantMap& ws,
std::vector<IValue>& stack) const;
void runNodes(ConstantMap& ws_) const;
};
} // namespace jit
} // namespace torch