pytorch/torch/csrc/jit/codegen/fuser/fallback.cpp
Michael Suo dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00

#include <torch/csrc/jit/codegen/fuser/fallback.h>

#include <ATen/core/functional.h> // fmap
#include <ATen/core/stack.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <stdexcept>

namespace torch {
namespace jit {
namespace fuser {

namespace {
c10::OperatorOptions aliasAnalysisIsSpecialCase() {
  c10::OperatorOptions options;
  options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE);
  return options;
}
} // namespace

// Registers fused operators so that fused graphs can properly generate fallback
// code.
RegisterOperators reg_fused_operators({Operator(
    prim::FusedConcat,
    [](const Node* node) -> Operation {
      int64_t dim = node->i(attr::dim);
      int64_t num_inputs = node->inputs().size();
      return [dim, num_inputs](Stack& stack) {
        // Concatenate the top num_inputs tensors on the stack along dim,
        // then replace them with the single result tensor.
        auto result = at::cat(
            fmap(
                last(stack, num_inputs),
                [](const IValue& i) { return i.toTensor(); }),
            dim);
        drop(stack, num_inputs);
        pack(stack, std::move(result));
        return 0;
      };
    },
    aliasAnalysisIsSpecialCase())});

void runFallback(int64_t key, Stack& stack) {
  // Retrieve the fusion spec cached under this key and interpret its
  // original (unfused) code as the fallback path.
  auto maybe_spec = retrieve(key);
  if (!maybe_spec)
    throw std::runtime_error("Failed to find fusion spec to run fallback.");

  InterpreterState{(*maybe_spec)->code()}.run(stack);
}

} // namespace fuser
} // namespace jit
} // namespace torch
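
// A minimal sketch of how this fallback path is exercised (illustrative only;
// 'key' is assumed to come from fuser::registerFusion() on a prim::FusionGroup
// node): when the fused kernel cannot be compiled or launched, the runtime
// pushes the group's inputs onto a Stack and calls runFallback, which
// interprets the original, unfused subgraph:
//
//   torch::jit::Stack stack;
//   stack.emplace_back(at::randn({2, 3})); // inputs of the fusion group
//   stack.emplace_back(at::randn({2, 3}));
//   torch::jit::fuser::runFallback(key, stack);
//   at::Tensor out = stack.back().toTensor();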