Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary: This PR adds more supported operations to the CUDA fuser, covering the major point-wise operations supported in the legacy fuser. To adapt to the legacy executor it: 1. adds a naive shape propagation pass on the PyTorch JIT IR; 2. slightly refactors graph partitioning; 3. adds fallback interpreter execution of the fusion group. Pull Request resolved: https://github.com/pytorch/pytorch/pull/37849 Reviewed By: yf225 Differential Revision: D21444320 Pulled By: soumith fbshipit-source-id: 712e18ab8497f8d58a07e6f8d200cdab52cf0d74
513 lines
19 KiB
C++
#include <torch/csrc/jit/codegen/cuda/parser.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>

#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>

#include <unordered_map>
#include <utility>

namespace torch {
namespace jit {

typedef Value JitValue;
typedef Node JitOp;

namespace fuser {
namespace cuda {

constexpr auto NUM_UNARY_OPS = 31;
constexpr auto NUM_BINARY_OPS = 24;
constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4;

namespace {

typedef Val* CgValue;
typedef Expr* CgOp;

typedef void (
    *ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue>&);
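
// A parse rule (ParseFuncPtr) translates a single JIT node into fusion IR:
// it looks up the node's inputs in the value map (keyed by Value::unique()),
// emits the corresponding fusion expression, and records the node's output
// back into the same map.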

// TODO: add a mutex to make it thread safe.
class IrParser {
 private:
  static const int nthreads = 128;
  static const int unroll_factor = 4;

 public:
  IrParser(
      std::shared_ptr<Graph> graph,
      Fusion& fusion,
      CudaKernel* cuda_kernel)
      : graph_(std::move(graph)), fusion_(&fusion), cuda_kernel_(cuda_kernel) {
    if (init_registry_) {
      registerJitOperator();
      init_registry_ = false;
    }
  }

  // Fuses pointwise ops with loop unrolling (factor = 4).
  void parse() {
    FusionGuard fg(fusion_);
    auto block = graph_->block();

    // In case of broadcast: we don't support explicit broadcast, so we need to
    // convert/expand all input tensors to comply with the broadcasted size.
    // This supports only a very limited set of cases, which we try to
    // accommodate in graph partitioning by only merging nodes with identical
    // output shapes.
    int broadcast_dim =
        block->outputs()[0]->type()->cast<TensorType>()->dim().value();

    // Register all inputs;
    // shape propagation during parsing is effectively done in the parsing
    // rules, as we only explicitly register the inputs of the graph.
    for (auto val : block->inputs()) {
      TORCH_CHECK(registerValue(val, broadcast_dim));
      fusion_->addInput(value_map_[val->unique()]);
    }

    // TODO: disable unroll to ensure rand_like generates identical output as
    // with eager mode
    bool disable_unroll = false;
    // Compose nodes in topological order;
    for (const JitOp* node : block->nodes()) {
      processJitNode(node);
      if (node->kind() == aten::rand_like) {
        disable_unroll = true;
      }
    }

    // Mark outputs;
    for (auto jit_output : block->outputs()) {
      TensorView* out =
          static_cast<TensorView*>(value_map_[jit_output->unique()]);
      fusion_->addOutput(out);

      // Merge all dimensions because we're only supporting pointwise
      while (out->nDims() > 1)
        out->merge(0);
      // Split into 128 which will be blockDim.x
      out->split(0, nthreads);
      // Split by another 4 which will be our unroll factor
      auto ur_factor = disable_unroll ? 1 : unroll_factor;
      out->split(0, ur_factor);
      cuda_kernel_->unroll_factor_ = ur_factor;

      // Map blocks/threads
      out->axis(0)->parallelize(ParallelType::BIDx);
      out->axis(1)->parallelize(ParallelType::Unroll);
      out->axis(-1)->parallelize(ParallelType::TIDx);
    }
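
    // After the merge/split above, each output with N total elements is viewed
    // roughly as [ceil(N / (nthreads * ur_factor)), ur_factor, nthreads], with
    // axis 0 bound to blockIdx.x, axis 1 unrolled, and the innermost axis
    // bound to threadIdx.x.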

    // Run through the outputs, grab all inputs of the outputs, and
    // squeeze them with computeAt to set the overall structure.
    for (auto jit_output : block->outputs()) {
      TensorView* out =
          static_cast<TensorView*>(value_map_[jit_output->unique()]);

      for (Val* inp : fusion_->inputsOf(out)) {
        if (inp->getValType().value() == ValType::TensorView)
          static_cast<TensorView*>(inp)->computeAt(out, 1);
      }
    }
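
    // computeAt(out, 1) positions each input's use inside the outermost
    // (blockIdx.x) loop of its output; tensors along the way inherit the
    // merged/split schedule above, so intermediates are produced tile by tile
    // rather than materialized in full.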

    // Run through intermediates, unroll, and bind their axes
    for (auto val : fusion_->vals()) {
      if (fusion_->hasInput(val) || fusion_->hasOutput(val))
        continue;
      if (val->getValType().value() != ValType::TensorView)
        continue;
      TensorView* tv = static_cast<TensorView*>(val);
      tv->axis(-2)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  static bool canParseNode(const Node* const node) {
    if (init_registry_) {
      // TODO: mutex this guy;
      registerJitOperator();
      init_registry_ = false;
    }

    // match signature.
    auto iter = jit_operator_registry_.find(node->kind());
    if (iter == jit_operator_registry_.end()) {
      return false;
    }
    for (auto& pair_op_func : iter->second) {
      if (node->matches(pair_op_func.first->schema())) {
        return true;
      }
    }
    return false;
  }

  static void registerParseRule(
      std::shared_ptr<Operator>& op,
      ParseFuncPtr fn) {
    jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
        .emplace_back(std::make_pair(op, fn));
  }
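
  // Note: several schemas can map to the same Symbol (e.g. the Tensor and
  // Scalar overloads of aten::add), so each registry entry keeps a vector of
  // (Operator, parse function) pairs and the exact overload is selected with
  // node->matches(schema) at parse time.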

 private:
  static void registerJitOperator() {
    // Register a parse function for each JIT operator;
    // this is a one-time lookup; our hash registry indexes on the pointer in
    // the OperatorRegistry.

    std::array<const char*, NUM_BINARY_OPS_WITH_ALPHA> BinaryOpWithAlpha = {
        "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
        "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
        "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
        "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
    for (auto signature : BinaryOpWithAlpha) {
      auto ptr_op = getOperatorForLiteral(signature);
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            static std::unordered_map<
                Symbol,
                std::pair<BinaryOpType, decltype(&add_alpha)>>
                op_mapping({
                    {aten::add, std::make_pair(BinaryOpType::Add, &add_alpha)},
                    {aten::sub, std::make_pair(BinaryOpType::Sub, &sub_alpha)},
                });
            // TODO: handle scaling factor when it's not constant 1;
            auto lhs = value_map[node->inputs()[0]->unique()];
            auto rhs = value_map[node->inputs()[1]->unique()];
            auto alpha = value_map[node->inputs()[2]->unique()];

            if (alpha->isOneInt()) {
              auto out = binaryOp(op_mapping[node->kind()].first, lhs, rhs);
              value_map.emplace(node->output()->unique(), out);
            } else {
              auto out = op_mapping[node->kind()].second(lhs, rhs, alpha);
              value_map.emplace(node->output()->unique(), out);
            }
          });
    }
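
    // In ATen, alpha scales the second operand (add: self + alpha * other;
    // sub: self - alpha * other). When alpha is the constant 1 the rule above
    // emits a plain binary op; otherwise it defers to add_alpha / sub_alpha.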

    std::array<const char*, NUM_BINARY_OPS> BinaryOp = {
        "aten::div(Tensor self, Tensor other) -> Tensor",
        "aten::div(Tensor self, Scalar other) -> Tensor",
        "aten::mul(Tensor self, Tensor other) -> Tensor",
        "aten::mul(Tensor self, Scalar other) -> Tensor",
        "aten::atan2(Tensor self, Tensor other) -> Tensor",
        "aten::max(Tensor self, Tensor other) -> Tensor",
        "aten::min(Tensor self, Tensor other) -> Tensor",
        "aten::pow(Tensor self, Tensor exponent) -> Tensor",
        "aten::pow(Tensor self, Scalar exponent) -> Tensor",
        "aten::pow(Scalar self, Tensor exponent) -> Tensor",
        "aten::remainder(Tensor self, Tensor other) -> Tensor",
        "aten::fmod(Tensor self, Tensor other) -> Tensor",
        "aten::eq(Tensor self, Tensor other) -> Tensor",
        "aten::eq(Tensor self, Scalar other) -> Tensor",
        "aten::ne(Tensor self, Tensor other) -> Tensor",
        "aten::ne(Tensor self, Scalar other) -> Tensor",
        "aten::ge(Tensor self, Tensor other) -> Tensor",
        "aten::ge(Tensor self, Scalar other) -> Tensor",
        "aten::gt(Tensor self, Tensor other) -> Tensor",
        "aten::gt(Tensor self, Scalar other) -> Tensor",
        "aten::le(Tensor self, Tensor other) -> Tensor",
        "aten::le(Tensor self, Scalar other) -> Tensor",
        "aten::lt(Tensor self, Tensor other) -> Tensor",
        "aten::lt(Tensor self, Scalar other) -> Tensor"};
    for (auto signature : BinaryOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                {{aten::div, BinaryOpType::Div},
                 {aten::mul, BinaryOpType::Mul},
                 {aten::add, BinaryOpType::Add},
                 {aten::sub, BinaryOpType::Sub},
                 {aten::atan2, BinaryOpType::Atan2},
                 {aten::min, BinaryOpType::Min},
                 {aten::max, BinaryOpType::Max},
                 {aten::pow, BinaryOpType::Pow},
                 {aten::remainder, BinaryOpType::Remainder},
                 {aten::fmod, BinaryOpType::Fmod},
                 {aten::lt, BinaryOpType::LT},
                 {aten::le, BinaryOpType::LE},
                 {aten::gt, BinaryOpType::GT},
                 {aten::ge, BinaryOpType::GE},
                 {aten::ne, BinaryOpType::NE},
                 {aten::eq, BinaryOpType::Eq}});
            auto lhs = value_map[node->inputs()[0]->unique()];
            auto rhs = value_map[node->inputs()[1]->unique()];

            auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
            value_map.emplace(node->output()->unique(), out);
          });
    }

    // TODO: cast operations should be merged in.
    std::array<const char*, NUM_UNARY_OPS> UnaryOp = {
        "aten::neg(Tensor self) -> Tensor",
        "aten::abs(Tensor self) -> Tensor",
        "aten::log(Tensor self) -> Tensor",
        "aten::log10(Tensor self) -> Tensor",
        "aten::log1p(Tensor self) -> Tensor",
        "aten::log2(Tensor self) -> Tensor",
        "aten::lgamma(Tensor self) -> Tensor",
        "aten::exp(Tensor self) -> Tensor",
        "aten::expm1(Tensor self) -> Tensor",
        "aten::erf(Tensor self) -> Tensor",
        "aten::erfc(Tensor self) -> Tensor",
        "aten::cos(Tensor self) -> Tensor",
        "aten::acos(Tensor self) -> Tensor",
        "aten::cosh(Tensor self) -> Tensor",
        "aten::sin(Tensor self) -> Tensor",
        "aten::asin(Tensor self) -> Tensor",
        "aten::sinh(Tensor self) -> Tensor",
        "aten::tan(Tensor self) -> Tensor",
        "aten::tanh(Tensor self) -> Tensor",
        "aten::atan(Tensor self) -> Tensor",
        "aten::sqrt(Tensor self) -> Tensor",
        "aten::rsqrt(Tensor self) -> Tensor",
        "aten::ceil(Tensor self) -> Tensor",
        "aten::floor(Tensor self) -> Tensor",
        "aten::round(Tensor self) -> Tensor",
        "aten::trunc(Tensor self) -> Tensor",
        "aten::frac(Tensor self) -> Tensor",
        "aten::reciprocal(Tensor self) -> Tensor",
        "aten::relu(Tensor self) -> Tensor",
        "aten::sigmoid(Tensor self) -> Tensor",
        "aten::gelu(Tensor self) -> Tensor",
    };
    for (auto signature : UnaryOp) {
      auto ptr_op = getOperatorForLiteral(signature);
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            static std::unordered_map<Symbol, UnaryOpType> op_mapping({
                {aten::neg, UnaryOpType::Neg},
                {aten::abs, UnaryOpType::Abs},
                {aten::log, UnaryOpType::Log},
                {aten::log10, UnaryOpType::Log10},
                {aten::log1p, UnaryOpType::Log1p},
                {aten::log2, UnaryOpType::Log2},
                {aten::lgamma, UnaryOpType::Lgamma},
                {aten::exp, UnaryOpType::Exp},
                {aten::expm1, UnaryOpType::Expm1},
                {aten::erf, UnaryOpType::Erf},
                {aten::erfc, UnaryOpType::Erfc},
                {aten::cos, UnaryOpType::Cos},
                {aten::acos, UnaryOpType::Acos},
                {aten::cosh, UnaryOpType::Cosh},
                {aten::sin, UnaryOpType::Sin},
                {aten::asin, UnaryOpType::Asin},
                {aten::sinh, UnaryOpType::Sinh},
                {aten::tan, UnaryOpType::Tan},
                {aten::tanh, UnaryOpType::Tanh},
                {aten::atan, UnaryOpType::Atan},
                {aten::sqrt, UnaryOpType::Sqrt},
                {aten::rsqrt, UnaryOpType::Rsqrt},
                {aten::ceil, UnaryOpType::Ceil},
                {aten::floor, UnaryOpType::Floor},
                {aten::round, UnaryOpType::Round},
                {aten::trunc, UnaryOpType::Trunc},
                {aten::frac, UnaryOpType::Frac},
                {aten::reciprocal, UnaryOpType::Reciprocal},
                {aten::relu, UnaryOpType::Relu},
                {aten::sigmoid, UnaryOpType::Sigmoid},
                {aten::gelu, UnaryOpType::Gelu},
            });
            auto operand = value_map[node->input()->unique()];

            auto out = unaryOp(op_mapping[node->kind()], operand);
            value_map.emplace(node->output()->unique(), out);
          });
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            auto operand = value_map[node->inputs()[0]->unique()];

            auto out = unaryOp(UnaryOpType::RandLike, operand);
            value_map.emplace(node->output()->unique(), out);
          });
    }
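
    // Note: parse() disables loop unrolling when it encounters rand_like so
    // that the generated RNG sequence can match eager mode (see the TODO in
    // parse()).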
    {
      auto ptr_op = getOperatorForLiteral(
          "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            auto operand = value_map[node->inputs()[0]->unique()];
            auto th = value_map[node->inputs()[1]->unique()];
            auto value = value_map[node->inputs()[2]->unique()];

            auto out = threshold(operand, th, value);
            value_map.emplace(node->output()->unique(), out);
          });
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            auto operand = value_map[node->inputs()[0]->unique()];
            // TODO: we need to get a proper lower bound per dtype in operand.
            // Note that numeric_limits<float>::min() is the smallest positive
            // normal float, not the most negative value, so this default is
            // only a placeholder rather than a true lower bound.
            auto low = value_map.count(node->inputs()[1]->unique()) != 0
                ? value_map[node->inputs()[1]->unique()]
                : new Float(std::numeric_limits<float>::min());
            auto high = value_map.count(node->inputs()[2]->unique()) != 0
                ? value_map[node->inputs()[2]->unique()]
                : new Float(std::numeric_limits<float>::max());

            auto out = clamp(operand, low, high);
            value_map.emplace(node->output()->unique(), out);
          });
    }

    {
      auto ptr_op = getOperatorForLiteral(
          "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
      registerParseRule(
          ptr_op,
          [](const Node* const node,
             std::unordered_map<size_t, CgValue>& value_map) -> void {
            auto condition = value_map[node->inputs()[0]->unique()];
            auto x = value_map[node->inputs()[1]->unique()];
            auto y = value_map[node->inputs()[2]->unique()];

            auto out = where(condition, x, y);
            value_map.emplace(node->output()->unique(), out);
          });
    }
  }

  void processJitNode(const JitOp* node) {
    if (node->kind() == prim::Constant) {
      // The partitioning pass doesn't take constant nodes explicitly, but it
      // does copy constants into the subgraph, so we need to register the
      // constants in the codegen IR;
      for (auto output : node->outputs()) {
        TORCH_CHECK(registerScalar(output));
      }
    } else {
      auto iter = IrParser::jit_operator_registry_.find(node->kind());
      // make sure we have a parser for the op;
      TORCH_CHECK(
          iter != IrParser::jit_operator_registry_.end(),
          "CudaFusionGroup Parser doesn't handle operator kind(): ",
          node->kind().toDisplayString());
      for (auto& pair_op_func : iter->second) {
        if (node->matches(pair_op_func.first->schema())) {
          pair_op_func.second(node, value_map_);
          return;
        }
      }
      TORCH_CHECK(
          false,
          "CudaFusionGroup Parser doesn't recognize operator overload: ",
          canonicalSchemaString(node->schema()));
    }
  }

  bool registerValue(const JitValue* val, int broadcast_dim = -1) {
    return registerTensor(val, broadcast_dim) || registerScalar(val);
  }

  bool registerScalar(const JitValue* val) {
    if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
      CgValue cg_val;
      if (auto ival = constant_as<float>(val)) {
        cg_val = new Float(ival.value());
      } else {
        cg_val = new Float();
      }
      value_map_.emplace(val->unique(), cg_val);
      return true;
    } else if (val->type()->isSubtypeOf(
                   static_cast<c10::TypePtr>(IntType::get()))) {
      CgValue cg_val;
      if (auto ival = constant_as<int>(val)) {
        cg_val = new Int(ival.value());
      } else {
        cg_val = new Int();
      }
      value_map_.emplace(val->unique(), cg_val);
      return true;
    } else if (val->type()->isSubtypeOf(
                   static_cast<c10::TypePtr>(NoneType::get()))) {
      // TODO: should we consider adding support for NoneType?
      return true;
    }
    return false;
  }
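
  // Scalars without a compile-time constant are registered as empty Float() /
  // Int() placeholders; presumably these become runtime parameters of the
  // generated kernel rather than baked-in constants.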
  bool registerTensor(const JitValue* val, int broadcast_dim = -1) {
    CgValue cg_val;
    if (auto tensor_type = val->type()->cast<TensorType>()) {
      if (broadcast_dim >= 0) {
        tensor_type = tensor_type->withDim(broadcast_dim);
      }
      // TODO: make this a static function in Tensor class;
      // create tensor;
      cg_val = new TensorView(tensor_type);
      value_map_.emplace(val->unique(), cg_val);
      return true;
    }
    return false;
  }

  std::shared_ptr<Graph> graph_;
  Fusion* fusion_;
  CudaKernel* cuda_kernel_;

  // maps from JitValue::unique() to fusion Val;
  std::unordered_map<size_t, CgValue> value_map_;
  // parsing rule registry.
  static std::unordered_map<
      Symbol,
      std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
      jit_operator_registry_;
  static bool init_registry_;
};

std::unordered_map<
    Symbol,
    std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
    IrParser::jit_operator_registry_;
bool IrParser::init_registry_ = true;

} // namespace

bool isNodeParsible(const Node* const node) {
  return IrParser::canParseNode(node);
}

void parseJitIR(
    std::shared_ptr<Graph>& graph,
    Fusion& fusion,
    CudaKernel* cuda_kernel) {
  IrParser parser(graph, fusion, cuda_kernel);
  parser.parse();
}
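
// Usage sketch (an assumption based on the PR summary, not shown in this
// file): the graph partitioning pass calls isNodeParsible() to decide which
// nodes may join a CudaFusionGroup, and the fuser executor later calls
// parseJitIR() on the group's subgraph to build the Fusion that is lowered
// to a CUDA kernel.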
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch