Summary:
Separating the CUDA fuser from the CPU fuser.
1. New node in the IR - prim::CudaFusionGroup:
This enables the CUDA fuser to coexist alongside the old fuser and allows us
to build and expand the CUDA fuser incrementally.
2. Copied the FuseGraph optimization passes to CudaFuserGraph:
We will refactor and reuse the Chunk/Concat handling from the old fuser, which
currently lives in the optimization pass. Unfortunately, much of the code in
the pass is tightly bound to the legacy fuser, which makes code sharing
difficult.
CudaFusionGraph will support only a subset of operations compared to the
legacy fuser (CUDA only). It is registered as a custom post-fusion pass via
```torch._C._jit_register_cuda_fuser()```
For it to take effect, you should also turn off fusion on the GPU via
```torch._C._jit_override_can_fuse_on_gpu(False)``` (see the snippet below).
3. We don't have codegen in this PR yet (WIP). Currently we just fall back to
the old fuser.
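Putting the two calls named above together, a minimal illustrative sketch of enabling the new pass from Python (not part of this diff; availability of the bindings depends on the build):
```python
import torch

# Register prim::CudaFusionGroup's pass as a custom post-fusion pass.
torch._C._jit_register_cuda_fuser()

# Disable the legacy GPU fuser so the new pass takes over GPU fusion.
torch._C._jit_override_can_fuse_on_gpu(False)
```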
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33527
Differential Revision: D20171598
Pulled By: ZolotukhinM
fbshipit-source-id: 9a3c0f06f46da7eaa80ae7551c04869f5b03ef71
352 lines · 11 KiB · C++
#include <ATen/ATen.h>
#include <ATen/core/alias_info.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/frontend/edit_distance.h>

#include <queue>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

namespace {
using OperatorMap =
    std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
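// Process-wide registry of JIT operators. All access is serialized by `lock`;
// operators added via registerOperator are queued in to_register and are only
// indexed lazily, the next time a lookup calls registerPendingOperators().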
struct OperatorRegistry {
 private:
  std::mutex lock;
  OperatorMap operators;
  // list of operators whose schema have not yet been parsed, and must
  // be registered before any call to lookup an operator
  std::vector<std::shared_ptr<Operator>> to_register;
  // Those two maps are used to implement lookupByLiteral, which is needed for
  // the n->match(...) calls. Basically, every function schema is assigned a
  // unique string you can use to match it. However, parsing those strings or
  // comparing and hashing them character by character would be very slow, so
  // we use a trick here! Every string literal in your program is guaranteed
  // to have static storage duration and so its address won't change at
  // runtime. This allows us to memoize answers for every pointer, which is
  // done by the operators_by_sig_literal map. Still, this map is initially
  // empty, and so we still need to do the complete string matching at the
  // first time, which is implemented by performing a lookup in the
  // operators_by_sig map.
  std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
  std::unordered_map<const char*, std::shared_ptr<Operator>>
      operators_by_sig_literal;

  // XXX - caller must be holding lock
  void registerPendingOperators() {
    for (const auto& op : to_register) {
      Symbol sym = Symbol::fromQualString(op->schema().name());
      operators[sym].push_back(op);
      operators_by_sig[canonicalSchemaString(op->schema())] = op;
    }
    to_register.clear();
  }

 public:
  void registerOperator(Operator&& op) {
    std::lock_guard<std::mutex> guard(lock);
    to_register.push_back(std::make_shared<Operator>(std::move(op)));
  }

  const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    auto it = operators_by_sig_literal.find(name);
    if (it == operators_by_sig_literal.end()) {
      auto op_ptr_it =
          operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
      // Handy debugging code that dumps all operators we know about on
      // mismatch
#if 0
      if (op_ptr_it == operators_by_sig.end()) {
        for (auto& entry : operators_by_sig) {
          std::cout << entry.first << std::endl;
        }
      }
#endif
      TORCH_CHECK(
          op_ptr_it != operators_by_sig.end(),
          "Couldn't find an operator for ",
          name,
          ". Do you have to update a set of hardcoded JIT ops?");
      it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
    }
    return it->second;
  }

  const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    static std::vector<std::shared_ptr<Operator>> empty;
    auto it = operators.find(name);
    if (it != operators.end())
      return it->second;
    return empty;
  }

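  // Collects registered symbols whose qualified name is within edit distance
  // MAX_EDIT_DIST of input_op, returned closest-first (the priority queue
  // pops the smallest edit distance first). Handy for suggesting close
  // operator names in error messages.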
  std::vector<Symbol> findSimilarOperators(Symbol input_op) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();

    using EntryPair = std::pair<int64_t, Symbol>;
    auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
      return lhs.first > rhs.first;
    };

    std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
        rankings(cmp);
    static constexpr size_t MAX_EDIT_DIST = 2u;
    for (const auto& op : operators) {
      auto edit_dist = script::ComputeEditDistance(
          input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
      if (edit_dist <= MAX_EDIT_DIST) {
        rankings.emplace(edit_dist, op.first);
      }
    }
    std::vector<Symbol> ret;
    while (!rankings.empty()) {
      ret.push_back(rankings.top().second);
      rankings.pop();
    }
    return ret;
  }

  const std::vector<std::shared_ptr<Operator>> getAllOperators() {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    std::vector<std::shared_ptr<Operator>> values;
    values.clear();
    for (auto& kv : operators) {
      values.insert(values.end(), kv.second.begin(), kv.second.end());
    }
    return values;
  }
};

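// Meyers singleton: the registry is constructed on first use, and C++11
// guarantees that initialization of the function-local static is thread-safe.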
OperatorRegistry& getRegistry() {
  static OperatorRegistry r;
  return r;
}

bool printerHasSpecialCaseFor(Symbol sym) {
  using namespace at;
  // WARNING: by adding a value to this set, you are asserting
  // that you have also added special handling of this symbol to
  // the python_print.cpp. Not adding handling will cause import and export
  // of modules with this new operator to fail. This is only required
  // for operators without schema. Prefer registering your operator with
  // schema to editing this list here. These cases should only be things
  // that require special handling because they do not fit normal schema
  const static std::unordered_set<Symbol> handled = {
      prim::Constant,
      prim::Uninitialized,
      prim::fork,
      prim::ListConstruct,
      prim::DictConstruct,
      prim::ListUnpack,
      prim::Print,
      prim::PythonOp,
      prim::TupleConstruct,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::TupleUnpack,
      prim::CreateObject,
      prim::GetAttr,
      prim::SetAttr,
      prim::CallFunction,
      prim::isinstance,
      prim::unchecked_cast,
      prim::tolist,
      prim::rpc_async,
  };

  // WARNING: by adding a value to this set, you are asserting that your
  // primitive is only ever added during optimization and does not need
  // to be correctly printed for export (a process that happens before
  // optimization passes run)
  const static std::unordered_set<Symbol> unneeded = {
      c10::onnx::Reshape, // only used in onnx
      c10::onnx::Shape, // only used in onnx
      prim::AutogradZero, // temporarily inserted by autograd
      prim::AutogradAnyNonZero, // temporarily inserted by autograd
      prim::AutogradAdd, // temporarily inserted by autograd
      prim::ConstantChunk, // optimization pass adds it
      prim::DifferentiableGraph, // optimization pass adds it
      prim::BroadcastSizes, // optimization pass (fuser) adds it
      prim::ChunkSizes, // optimization pass (fuser) adds it
      prim::Drop, // used in interpreter only
      prim::FusedConcat, // optimization pass adds it
      prim::FusionGroup, // optimization pass adds it
      prim::CudaFusionGroup, // optimization pass adds it
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
      prim::Store, // used in interpreter only
      prim::profile, // used in interpreter only
  };

  // These namespaces are required to have Python printers unless
  // otherwise noted in unneeded.
  const static std::unordered_set<Symbol> required_namespaces = {
      c10::namespaces::prim,
      c10::namespaces::aten,
      c10::namespaces::onnx,
  };

  return handled.count(sym) || unneeded.count(sym) ||
      !required_namespaces.count(sym.ns());
}

} // anonymous namespace

bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
  using namespace at;
  // WARNING: by adding a case to this list, you are asserting that you have
  // added a case for the unschematized node in AliasDb::analyze
  const static std::unordered_set<Symbol> handled = {
      prim::If,
      prim::Loop,
      prim::FusionGroup,
      prim::CudaFusionGroup,
      prim::DifferentiableGraph,
      prim::Constant,
      prim::Uninitialized,
      prim::DictConstruct,
      prim::ListConstruct,
      prim::TupleConstruct,
      prim::AutogradZero,
      prim::FusedConcat,
      prim::GradOf,
      prim::MMTreeReduce,
      prim::MMBatchSide,
      prim::BroadcastSizes,
      prim::ChunkSizes,
      prim::Function,
      prim::TupleUnpack,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::ListUnpack,
      prim::PythonOp,
      prim::ConstantChunk,
      prim::BroadcastingChunk,
      prim::fork,
      prim::CreateObject,
      prim::AutogradAdd,
      prim::GetAttr,
      prim::SetAttr,
      prim::profile,
      prim::Print,
      prim::CallFunction,
      prim::CallMethod,
      aten::wait,
      prim::isinstance,
      prim::unchecked_cast,
      prim::tolist,
      prim::rpc_async,
  };

  // Operators that should not be used by alias analysis
  const static std::unordered_set<Symbol> purposefully_not_handled = {
      prim::Load,
      prim::Store,
      prim::Drop,
      at::onnx::Reshape,
      at::onnx::Shape,
      prim::AutogradAdd,
  };

  return handled.count(symbol) || purposefully_not_handled.count(symbol);
}

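// Registration entry point. Operators without a real schema (is_varret)
// bypass the generic serialization and alias-analysis machinery, so they must
// be special-cased in the Python printer and, depending on the alias analysis
// kind, in AliasDb; fail loudly at registration time if those special cases
// are missing or inconsistent.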
void registerOperator(Operator&& op) {
  if (op.schema().is_varret()) {
    Symbol s = Symbol::fromQualString(op.schema().name());
    if (!printerHasSpecialCaseFor(s)) {
      AT_ERROR(
          "Missing special case in python printer for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (!aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) {
      AT_ERROR(
          "Missing special case in alias analysis for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) {
      AT_ERROR(
          "The operator ",
          op.schema().name(),
          " is special cased and cannot use explicit alias analysis.");
    }
  }
  getRegistry().registerOperator(std::move(op));
}

const std::vector<std::shared_ptr<Operator>> getAllOperators() {
  return getRegistry().getAllOperators();
}

const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  return getRegistry().getOperators(name);
}

std::shared_ptr<Operator> findOperatorFor(const c10::OperatorName& full_name) {
  for (const auto& op : getRegistry().getOperators(
           Symbol::fromQualString(full_name.name))) {
    if (op->schema().overload_name() == full_name.overload_name) {
      return op;
    }
  }
  return nullptr;
}

std::vector<Symbol> findSimilarOperators(Symbol input_op) {
  return getRegistry().findSimilarOperators(input_op);
}

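// Looks up an operator from a schema string. Intended to be called with a
// string literal, e.g.
//   getOperatorForLiteral("aten::relu(Tensor self) -> Tensor")
// so that OperatorRegistry::lookupByLiteral can memoize the answer by the
// literal's address after the first (expensive) parse.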
std::shared_ptr<Operator> getOperatorForLiteral(const char* signature) {
  return getRegistry().lookupByLiteral(signature);
}

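// Renders a schema in the canonical form used as the key of operators_by_sig,
// e.g. "aten::relu(Tensor self) -> Tensor". The first keyword-only argument
// is preceded by "*, ", and multiple returns are wrapped in parentheses.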
std::string canonicalSchemaString(const FunctionSchema& schema) {
  std::ostringstream out;

  out << schema.name();
  out << "(";

  bool seen_kwarg_only = false;
  for (size_t i = 0; i < schema.arguments().size(); ++i) {
    if (i > 0)
      out << ", ";
    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
      out << "*, ";
      seen_kwarg_only = true;
    }
    const auto& arg = schema.arguments()[i];
    out << arg.type()->str() << " " << arg.name();
  }

  out << ") -> ";
  if (schema.returns().size() == 1) {
    out << schema.returns().at(0).type()->str();
  } else if (schema.returns().size() > 1) {
    out << "(";
    for (size_t i = 0; i < schema.returns().size(); ++i) {
      if (i > 0)
        out << ", ";
      out << schema.returns()[i].type()->str();
    }
    out << ")";
  }
  return out.str();
}

} // namespace jit
} // namespace torch