pytorch/torch/csrc/jit/codegen/cuda/parser.cpp
jiej 1667aa6451 [CUDA_FUSER] Expand operation support for cuda fuser (#37849)
Summary:
This PR adds more supported operations to the CUDA fuser, covering the major pointwise operations supported in the legacy fuser.

In an attempt to adapt to the legacy executor:
1. added a naive shape propagation pass on the PyTorch JIT IR;
2. a small refactor of graph partitioning;
3. fallback to interpreter execution of the fusion group;
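As a rough usage sketch (not part of this commit), the executor side is expected to drive the two entry points this file exposes, isNodeParsible and parseJitIR; the wrapper function and its name below are illustrative assumptions:

    // Hypothetical caller, inside torch::jit::fuser::cuda like parser.cpp itself.
    void buildFusionFromSubgraph(
        std::shared_ptr<Graph> subgraph,
        Fusion& fusion,
        CudaKernel* cuda_kernel) {
      for (const Node* node : subgraph->block()->nodes()) {
        // Constants are handled inside the parser; every other node needs a
        // registered parse rule, otherwise we fall back to the interpreter (3.).
        if (node->kind() != prim::Constant && !isNodeParsible(node)) {
          return;
        }
      }
      // Lower the JIT graph into the fusion IR and schedule it for cuda_kernel.
      parseJitIR(subgraph, fusion, cuda_kernel);
    }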
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37849

Reviewed By: yf225

Differential Revision: D21444320

Pulled By: soumith

fbshipit-source-id: 712e18ab8497f8d58a07e6f8d200cdab52cf0d74
2020-05-07 09:21:09 -07:00


#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>
#include <unordered_map>
#include <utility>
namespace torch {
namespace jit {
typedef Value JitValue;
typedef Node JitOp;
namespace fuser {
namespace cuda {
constexpr auto NUM_UNARY_OPS = 31;
constexpr auto NUM_BINARY_OPS = 24;
constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4;
namespace {
typedef Val* CgValue;
typedef Expr* CgOp;
typedef void (
*ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue>&);
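// A parse rule lowers one matched JIT node into codegen IR, reading its
// inputs from and recording its outputs into the JitValue::unique() -> CgValue map.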
// TODO: add a mutex to make it thread safe.
class IrParser {
private:
static const int nthreads = 128;
static const int unroll_factor = 4;
public:
IrParser(
std::shared_ptr<Graph> graph,
Fusion& fusion,
CudaKernel* cuda_kernel)
: graph_(std::move(graph)), fusion_(&fusion), cuda_kernel_(cuda_kernel) {
if (init_registry_) {
registerJitOperator();
init_registry_ = false;
}
}
// Fuses pointwise ops, scheduling each output with loop unrolling
// (factor = 4 unless unrolling is disabled below).
void parse() {
FusionGuard fg(fusion_);
auto block = graph_->block();
// We don't support explicit broadcast, so all input tensors are
// converted/expanded to comply with the broadcasted size. This only covers
// a very limited case, which graph partitioning accommodates by merging
// only nodes with identical output shapes.
int broadcast_dim =
block->outputs()[0]->type()->cast<TensorType>()->dim().value();
// register all inputs;
// shape propagation during parsing is effectively done in parsing rules, as
// we only explicitly register inputs in the graph.
for (auto val : block->inputs()) {
TORCH_CHECK(registerValue(val, broadcast_dim));
fusion_->addInput(value_map_[val->unique()]);
}
// TODO: disable unroll to ensure rand_like generates identical output as
// with eager mode
bool disable_unroll = false;
// compose nodes in topo order;
for (const JitOp* node : block->nodes()) {
processJitNode(node);
if (node->kind() == aten::rand_like) {
disable_unroll = true;
}
}
// mark output;
for (auto jit_output : block->outputs()) {
TensorView* out =
static_cast<TensorView*>(value_map_[jit_output->unique()]);
fusion_->addOutput(out);
// Merge all dimensions because we're only supporting pointwise
while (out->nDims() > 1)
out->merge(0);
// Split into 128, which will be blockDim.x
out->split(0, nthreads);
// Split by another 4 which will be our unroll factor
auto ur_factor = disable_unroll ? 1 : unroll_factor;
out->split(0, ur_factor);
cuda_kernel_->unroll_factor_ = ur_factor;
// Map blocks/threads
out->axis(0)->parallelize(ParallelType::BIDx);
out->axis(1)->parallelize(ParallelType::Unroll);
out->axis(-1)->parallelize(ParallelType::TIDx);
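// After merging to 1-D and the two splits, each output's domain is three
// axes: [BIDx, Unroll (ur_factor), TIDx (nthreads)].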
}
// Run through outputs and inline each output's tensor inputs with
// computeAt to set the overall loop structure.
for (auto jit_output : block->outputs()) {
TensorView* out =
static_cast<TensorView*>(value_map_[jit_output->unique()]);
for (Val* inp : fusion_->inputsOf(out)) {
if (inp->getValType().value() == ValType::TensorView)
static_cast<TensorView*>(inp)->computeAt(out, 1);
}
}
// Run through intermediates, unroll, and bind their axes
for (auto val : fusion_->vals()) {
if (fusion_->hasInput(val) || fusion_->hasOutput(val))
continue;
if (val->getValType().value() != ValType::TensorView)
continue;
TensorView* tv = static_cast<TensorView*>(val);
tv->axis(-2)->parallelize(ParallelType::Unroll);
tv->axis(-1)->parallelize(ParallelType::TIDx);
}
}
static bool canParseNode(const Node* const node) {
if (init_registry_) {
// TODO: mutex this guy;
registerJitOperator();
init_registry_ = false;
}
// match signature.
auto iter = jit_operator_registry_.find(node->kind());
if (iter == jit_operator_registry_.end()) {
return false;
}
for (auto& pair_op_func : iter->second) {
if (node->matches(pair_op_func.first->schema())) {
return true;
}
}
return false;
}
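// Parse rules are registered per Symbol; overloads of the same op share a
// vector and are disambiguated by schema matching at parse time.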
static void registerParseRule(
std::shared_ptr<Operator>& op,
ParseFuncPtr fn) {
jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
.emplace_back(std::make_pair(op, fn));
}
private:
static void registerJitOperator() {
// Register parse-function for each JIT operator;
// This is a one-time look up, our hash registry indexes on the pointer in
// OperatorRegistry.
std::array<const char*, NUM_BINARY_OPS_WITH_ALPHA> BinaryOpWithAlpha = {
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
for (auto signature : BinaryOpWithAlpha) {
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<
Symbol,
std::pair<BinaryOpType, decltype(&add_alpha)>>
op_mapping({
{aten::add, std::make_pair(BinaryOpType::Add, &add_alpha)},
{aten::sub, std::make_pair(BinaryOpType::Sub, &sub_alpha)},
});
// TODO: handle scaling factor when it's not constant 1;
auto lhs = value_map[node->inputs()[0]->unique()];
auto rhs = value_map[node->inputs()[1]->unique()];
auto alpha = value_map[node->inputs()[2]->unique()];
if (alpha->isOneInt()) {
auto out = binaryOp(op_mapping[node->kind()].first, lhs, rhs);
value_map.emplace(node->output()->unique(), out);
} else {
auto out = op_mapping[node->kind()].second(lhs, rhs, alpha);
value_map.emplace(node->output()->unique(), out);
}
});
}
std::array<const char*, NUM_BINARY_OPS> BinaryOp = {
"aten::div(Tensor self, Tensor other) -> Tensor",
"aten::div(Tensor self, Scalar other) -> Tensor",
"aten::mul(Tensor self, Tensor other) -> Tensor",
"aten::mul(Tensor self, Scalar other) -> Tensor",
"aten::atan2(Tensor self, Tensor other) -> Tensor",
"aten::max(Tensor self, Tensor other) -> Tensor",
"aten::min(Tensor self, Tensor other) -> Tensor",
"aten::pow(Tensor self, Tensor exponent) -> Tensor",
"aten::pow(Tensor self, Scalar exponent) -> Tensor",
"aten::pow(Scalar self, Tensor exponent) -> Tensor",
"aten::remainder(Tensor self, Tensor other) -> Tensor",
"aten::fmod(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Scalar other) -> Tensor",
"aten::ne(Tensor self, Tensor other) -> Tensor",
"aten::ne(Tensor self, Scalar other) -> Tensor",
"aten::ge(Tensor self, Tensor other) -> Tensor",
"aten::ge(Tensor self, Scalar other) -> Tensor",
"aten::gt(Tensor self, Tensor other) -> Tensor",
"aten::gt(Tensor self, Scalar other) -> Tensor",
"aten::le(Tensor self, Tensor other) -> Tensor",
"aten::le(Tensor self, Scalar other) -> Tensor",
"aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Scalar other) -> Tensor"};
for (auto signature : BinaryOp) {
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, BinaryOpType> op_mapping(
{{aten::div, BinaryOpType::Div},
{aten::mul, BinaryOpType::Mul},
{aten::add, BinaryOpType::Add},
{aten::sub, BinaryOpType::Sub},
{aten::atan2, BinaryOpType::Atan2},
{aten::min, BinaryOpType::Min},
{aten::max, BinaryOpType::Max},
{aten::pow, BinaryOpType::Pow},
{aten::remainder, BinaryOpType::Remainder},
{aten::fmod, BinaryOpType::Fmod},
{aten::lt, BinaryOpType::LT},
{aten::le, BinaryOpType::LE},
{aten::gt, BinaryOpType::GT},
{aten::ge, BinaryOpType::GE},
{aten::ne, BinaryOpType::NE},
{aten::eq, BinaryOpType::Eq}});
auto lhs = value_map[node->inputs()[0]->unique()];
auto rhs = value_map[node->inputs()[1]->unique()];
auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
value_map.emplace(node->output()->unique(), out);
});
}
// TODO: cast operations should be merged in.
std::array<const char*, NUM_UNARY_OPS> UnaryOp = {
"aten::neg(Tensor self) -> Tensor",
"aten::abs(Tensor self) -> Tensor",
"aten::log(Tensor self) -> Tensor",
"aten::log10(Tensor self) -> Tensor",
"aten::log1p(Tensor self) -> Tensor",
"aten::log2(Tensor self) -> Tensor",
"aten::lgamma(Tensor self) -> Tensor",
"aten::exp(Tensor self) -> Tensor",
"aten::expm1(Tensor self) -> Tensor",
"aten::erf(Tensor self) -> Tensor",
"aten::erfc(Tensor self) -> Tensor",
"aten::cos(Tensor self) -> Tensor",
"aten::acos(Tensor self) -> Tensor",
"aten::cosh(Tensor self) -> Tensor",
"aten::sin(Tensor self) -> Tensor",
"aten::asin(Tensor self) -> Tensor",
"aten::sinh(Tensor self) -> Tensor",
"aten::tan(Tensor self) -> Tensor",
"aten::tanh(Tensor self) -> Tensor",
"aten::atan(Tensor self) -> Tensor",
"aten::sqrt(Tensor self) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
"aten::ceil(Tensor self) -> Tensor",
"aten::floor(Tensor self) -> Tensor",
"aten::round(Tensor self) -> Tensor",
"aten::trunc(Tensor self) -> Tensor",
"aten::frac(Tensor self) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
"aten::gelu(Tensor self) -> Tensor",
};
for (auto signature : UnaryOp) {
auto ptr_op = getOperatorForLiteral(signature);
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, UnaryOpType> op_mapping({
{aten::neg, UnaryOpType::Neg},
{aten::abs, UnaryOpType::Abs},
{aten::log, UnaryOpType::Log},
{aten::log10, UnaryOpType::Log10},
{aten::log1p, UnaryOpType::Log1p},
{aten::log2, UnaryOpType::Log2},
{aten::lgamma, UnaryOpType::Lgamma},
{aten::exp, UnaryOpType::Exp},
{aten::expm1, UnaryOpType::Expm1},
{aten::erf, UnaryOpType::Erf},
{aten::erfc, UnaryOpType::Erfc},
{aten::cos, UnaryOpType::Cos},
{aten::acos, UnaryOpType::Acos},
{aten::cosh, UnaryOpType::Cosh},
{aten::sin, UnaryOpType::Sin},
{aten::asin, UnaryOpType::Asin},
{aten::sinh, UnaryOpType::Sinh},
{aten::tan, UnaryOpType::Tan},
{aten::tanh, UnaryOpType::Tanh},
{aten::atan, UnaryOpType::Atan},
{aten::sqrt, UnaryOpType::Sqrt},
{aten::rsqrt, UnaryOpType::Rsqrt},
{aten::ceil, UnaryOpType::Ceil},
{aten::floor, UnaryOpType::Floor},
{aten::round, UnaryOpType::Round},
{aten::trunc, UnaryOpType::Trunc},
{aten::frac, UnaryOpType::Frac},
{aten::reciprocal, UnaryOpType::Reciprocal},
{aten::relu, UnaryOpType::Relu},
{aten::sigmoid, UnaryOpType::Sigmoid},
{aten::gelu, UnaryOpType::Gelu},
});
auto operand = value_map[node->input()->unique()];
auto out = unaryOp(op_mapping[node->kind()], operand);
value_map.emplace(node->output()->unique(), out);
});
}
{
auto ptr_op = getOperatorForLiteral(
"aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
auto out = unaryOp(UnaryOpType::RandLike, operand);
value_map.emplace(node->output()->unique(), out);
});
}
{
auto ptr_op = getOperatorForLiteral(
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
auto th = value_map[node->inputs()[1]->unique()];
auto value = value_map[node->inputs()[2]->unique()];
auto out = threshold(operand, th, value);
value_map.emplace(node->output()->unique(), out);
});
}
{
auto ptr_op = getOperatorForLiteral(
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
// TODO: we need to get a proper lower bound per dtype in operand.
// lowest(), not min(): min() is the smallest positive normal float.
auto low = value_map.count(node->inputs()[1]->unique()) != 0
? value_map[node->inputs()[1]->unique()]
: new Float(std::numeric_limits<float>::lowest());
auto high = value_map.count(node->inputs()[2]->unique()) != 0
? value_map[node->inputs()[2]->unique()]
: new Float(std::numeric_limits<float>::max());
auto out = clamp(operand, low, high);
value_map.emplace(node->output()->unique(), out);
});
}
{
auto ptr_op = getOperatorForLiteral(
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
registerParseRule(
ptr_op,
[](const Node* const node,
std::unordered_map<size_t, CgValue>& value_map) -> void {
auto condition = value_map[node->inputs()[0]->unique()];
auto x = value_map[node->inputs()[1]->unique()];
auto y = value_map[node->inputs()[2]->unique()];
auto out = where(condition, x, y);
value_map.emplace(node->output()->unique(), out);
});
}
}
void processJitNode(const JitOp* node) {
if (node->kind() == prim::Constant) {
// The partitioning pass doesn't pull constant nodes in explicitly, but it
// does copy constants into the subgraph, so we register them in codegen IR;
for (auto output : node->outputs()) {
TORCH_CHECK(registerScalar(output));
}
} else {
auto iter = IrParser::jit_operator_registry_.find(node->kind());
// make sure we have a parser for the op;
TORCH_CHECK(
iter != IrParser::jit_operator_registry_.end(),
"CudaFusionGroup Parser doesn't handle operator kind(): ",
node->kind().toDisplayString());
for (auto& pair_op_func : iter->second) {
if (node->matches(pair_op_func.first->schema())) {
pair_op_func.second(node, value_map_);
return;
}
}
TORCH_CHECK(
false,
"CudaFusionGroup Parser doesn't recognize operator overload:",
canonicalSchemaString(node->schema()));
}
}
bool registerValue(const JitValue* val, int broadcast_dim = -1) {
return registerTensor(val, broadcast_dim) || registerScalar(val);
}
bool registerScalar(const JitValue* val) {
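// Constant scalars are lowered to concrete Float/Int values; non-constant
// scalars become symbolic Float()/Int() placeholders.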
if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
CgValue cg_val;
if (auto ival = constant_as<float>(val)) {
cg_val = new Float(ival.value());
} else {
cg_val = new Float();
}
value_map_.emplace(val->unique(), cg_val);
return true;
} else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(IntType::get()))) {
CgValue cg_val;
if (auto ival = constant_as<int>(val)) {
cg_val = new Int(ival.value());
} else {
cg_val = new Int();
}
value_map_.emplace(val->unique(), cg_val);
return true;
} else if (val->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
// TODO: should we consider adding support for NoneType;
return true;
}
return false;
}
bool registerTensor(const JitValue* val, int broadcast_dim = -1) {
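// Tensor values become TensorView nodes; when broadcast_dim >= 0 the type
// is expanded to that rank so every fusion input matches the output rank.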
CgValue cg_val;
if (auto tensor_type = val->type()->cast<TensorType>()) {
if (broadcast_dim >= 0) {
tensor_type = tensor_type->withDim(broadcast_dim);
}
// TODO: make this a static function in Tensor class;
// create tensor;
cg_val = new TensorView(tensor_type);
value_map_.emplace(val->unique(), cg_val);
return true;
}
return false;
}
std::shared_ptr<Graph> graph_;
Fusion* fusion_;
CudaKernel* cuda_kernel_;
// maps from JitValue::unique() to fusion Val;
std::unordered_map<size_t, CgValue> value_map_;
// parsing rule registry.
static std::unordered_map<
Symbol,
std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
jit_operator_registry_;
static bool init_registry_;
};
std::unordered_map<
Symbol,
std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
IrParser::jit_operator_registry_;
bool IrParser::init_registry_ = true;
} // namespace
bool isNodeParsible(const Node* const node) {
return IrParser::canParseNode(node);
}
void parseJitIR(
std::shared_ptr<Graph>& graph,
Fusion& fusion,
CudaKernel* cuda_kernel) {
IrParser parser(graph, fusion, cuda_kernel);
parser.parse();
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch