mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34160 I constructed the patch by deleting OperatorOptions and then rerouting all queries for AliasAnalysisKind to FunctionSchema. Some of the behavior is kind of bogus: we really shouldn't be mutating FunctionSchema after the fact, but that won't get fixed until we actually switch to true schema merging. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D20282846 Pulled By: ezyang fbshipit-source-id: ba7bca6e8adc3365789639b88e54c4e881b1692e
53 lines
1.4 KiB
C++
53 lines
1.4 KiB
C++
#include <torch/csrc/jit/codegen/fuser/fallback.h>
|
|
|
|
#include <ATen/core/functional.h> //fmap
|
|
#include <ATen/core/stack.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
|
|
#include <torch/csrc/jit/runtime/interpreter.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
#include <stdexcept>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
|
|
namespace {
|
|
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
|
|
return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
|
|
}
|
|
} // namespace
|
|
|
|
// Registers fused operators so that fused graphs can properly generate fallback
|
|
// code.
|
|
RegisterOperators reg_fused_operators({Operator(
|
|
prim::FusedConcat,
|
|
[](const Node* node) -> Operation {
|
|
int64_t dim = node->i(attr::dim);
|
|
int64_t num_inputs = node->inputs().size();
|
|
return [dim, num_inputs](Stack& stack) {
|
|
auto result = at::cat(
|
|
fmap(
|
|
last(stack, num_inputs),
|
|
[](const IValue& i) { return i.toTensor(); }),
|
|
dim);
|
|
drop(stack, num_inputs);
|
|
pack(stack, std::move(result));
|
|
return 0;
|
|
};
|
|
},
|
|
aliasAnalysisIsSpecialCase())});
|
|
|
|
void runFallback(int64_t key, Stack& stack) {
|
|
auto maybe_spec = retrieve(key);
|
|
if (!maybe_spec)
|
|
throw std::runtime_error("Failed to find fusion spec to run fallback.");
|
|
|
|
InterpreterState{(*maybe_spec)->code()}.run(stack);
|
|
}
|
|
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|