mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: The premise of this approach is that a small subset of neural networks is well represented by a data flow graph. The README contains more information. The name is subject to change, but I thought it was a cute reference to fire. suo, let me know if you'd prefer this in a different spot. Since it lowers a JIT'd module directly, I assumed the JIT folder would be appropriate. There is no exposed Python interface yet (but one is mocked up in `test_accelerant.py`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/42753 Reviewed By: zou3519 Differential Revision: D23043771 Pulled By: bwasti fbshipit-source-id: 5353731e3aae31c08b5b49820815da98113eb551
46 lines
1.2 KiB
C++
46 lines
1.2 KiB
C++
#include <torch/csrc/jit/runtime/static/impl.h>

#include <stdexcept>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Prepares `g` for execution by the static runtime.
//
// Stores a deep copy of `m` (pushed as the first argument by run() when the
// graph's first input is a module). Then it inlines function calls in the
// graph and runs constant propagation over it.
//
// Throws std::runtime_error("Cannot accelerate unfrozen graphs") if any node
// in the graph is a prim::GetAttr. This includes nodes nested inside
// sub-blocks such as the bodies of control-flow nodes, not only the top
// level. Previously only top-level nodes were checked, so attribute reads
// inside an If/Loop body went undetected.
StaticRuntime::StaticRuntime(
    const torch::jit::Module& m,
    std::shared_ptr<torch::jit::Graph> g)
    : graph_(std::move(g)), module_(m.deepcopy()) {
  Inline(*graph_);
  ConstantPropagation(graph_);

  // Resolve the symbol once instead of re-parsing the string for every node.
  const c10::Symbol get_attr = c10::Symbol::fromQualString("prim::GetAttr");

  // Iterative depth-first walk over the graph's block and all nested blocks.
  std::vector<torch::jit::Block*> pending{graph_->block()};
  while (!pending.empty()) {
    torch::jit::Block* block = pending.back();
    pending.pop_back();
    for (torch::jit::Node* n : block->nodes()) {
      if (n->kind() == get_attr) {
        throw std::runtime_error("Cannot accelerate unfrozen graphs");
      }
      for (torch::jit::Block* sub_block : n->blocks()) {
        pending.push_back(sub_block);
      }
    }
  }
}
|
|
|
|
std::vector<at::Tensor> StaticRuntime::run(
|
|
const std::vector<at::Tensor>& inps) const {
|
|
std::vector<torch::jit::IValue> stack;
|
|
if (graph_->inputs().at(0)->type()->is_module()) {
|
|
stack.emplace_back(module_._ivalue());
|
|
}
|
|
for (const auto& inp : inps) {
|
|
stack.emplace_back(inp);
|
|
}
|
|
torch::jit::Code code(graph_, "");
|
|
torch::jit::InterpreterState interp(code);
|
|
interp.run(stack);
|
|
std::vector<at::Tensor> out;
|
|
for (const auto& v : stack) {
|
|
if (v.isTuple()) {
|
|
auto t = v.toTuple();
|
|
for (const auto& el : t->elements()) {
|
|
out.emplace_back(el.toTensor());
|
|
}
|
|
} else {
|
|
out.emplace_back(v.toTensor());
|
|
}
|
|
}
|
|
return out;
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|