#include #include #include #include #include #include #include #include namespace torch { namespace jit { typedef Value JitValue; typedef Node JitOp; namespace fuser { namespace cuda { constexpr auto NUM_UNARY_OPS = 31; constexpr auto NUM_BINARY_OPS = 24; constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4; constexpr auto NUM_LERP_OPS = 2; namespace { typedef Val* CgValue; typedef Expr* CgOp; typedef void ( *ParseFuncPtr)(const Node* const, std::unordered_map&); // TODO: add a mutex to make it thread safe. class IrParser { private: static const int nthreads = 128; static const int unroll_factor = 4; public: IrParser(std::shared_ptr graph, CudaKernel* cuda_kernel) : graph_(std::move(graph)), cuda_kernel_(cuda_kernel) { if (init_registry_) { registerJitOperator(); init_registry_ = false; } } // Fuses pointwise ops with loop unrolling (factor = 4). void parse() { FusionGuard fg(cuda_kernel_->fusion_.get()); auto block = graph_->block(); // in case of broadcast, we don't support explicit broadcast, so we need to // convert/expand all inputs tensors to comply to the broadcasted size. // This supports very limited case, which we try to accomodate in graph // partition, that we only merge nodes with identical output shapes. int broadcast_dim = block->outputs()[0]->type()->cast()->dim().value(); // register all inputs; // shape propagation during parsing is effctively done in parsing rules, as // we only explicitly register inputs in the graph. for (auto val : block->inputs()) { TORCH_CHECK(registerValue(val, broadcast_dim)); cuda_kernel_->fusion_->addInput(value_map_[val->unique()]); auto opt_dtype = value_map_[val->unique()]->getDataType(); // computation promotion, we cast fp16 inputs to fp32 and use promoted // type in the computation. if (opt_dtype.has_value() && opt_dtype.value() == DataType::Half) { Val* promoted_val = castOp(DataType::Float, value_map_[val->unique()]); value_map_[val->unique()] = promoted_val; } } // TODO: disable unroll to ensure rand_like generates identical output as // with eager mode bool disable_unroll = false; // compose nodes in topo order; for (const JitOp* node : block->nodes()) { processJitNode(node); if (node->kind() == aten::rand_like) { disable_unroll = true; } } // mark output; for (auto jit_output : block->outputs()) { TensorView* out = static_cast(value_map_[jit_output->unique()]); // demote output dtype to be match PyTorch JIT graph. auto tensor_type = jit_output->type()->cast(); TORCH_INTERNAL_ASSERT( tensor_type, "output of fusion group is not TensorType."); if (tensor_type->scalarType() == at::ScalarType::Half) { // No need to update value_map_ after this point. out = static_cast(castOp(DataType::Half, out)); } cuda_kernel_->fusion_->addOutput(out); // Merge all dimensions because we're only supporting pointwise while (out->nDims() > 1) out->merge(0, 1); // Split into 128 which will be bockDim.x out->split(0, nthreads); // Split by another 4 which will be our unroll factor auto ur_factor = disable_unroll ? 1 : unroll_factor; if (!disable_unroll) { out->split(0, ur_factor); cuda_kernel_->unroll_factor_ = ur_factor; } } // Run through outputs, grab all inputs of outputs // squeeze with computeAt to set overall structure. for (auto output : cuda_kernel_->fusion_->outputs()) { if (output->getValType() != ValType::TensorView) continue; TensorView* out_tv = static_cast(output); for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) { if (inp->getValType().value() == ValType::TensorView) static_cast(inp)->computeAt(out_tv, 1); } out_tv->axis(0)->parallelize(ParallelType::BIDx); } // Run through intermediates, unroll, and bind their axes for (auto val : cuda_kernel_->fusion_->vals()) { if (val->getValType().value() != ValType::TensorView) continue; TensorView* tv = static_cast(val); // Should be true for all intermediates, but if one isn't hooked // up right, skip it and hope for the best for now if (!disable_unroll && tv->nDims() == 3) { tv->axis(-2)->parallelize(ParallelType::Unroll); tv->axis(-1)->parallelize(ParallelType::TIDx); } else { if (tv->nDims() == 2) tv->axis(-1)->parallelize(ParallelType::TIDx); } } } static bool canParseNode(const Node* const node) { if (init_registry_) { // TODO: mutex this guy; registerJitOperator(); init_registry_ = false; } // match signature. auto iter = jit_operator_registry_.find(node->kind()); if (iter == jit_operator_registry_.end()) { return false; } for (auto& pair_op_func : iter->second) { if (node->matches(pair_op_func.first->schema())) { return true; } } return false; } static void registerParseRule( std::shared_ptr& op, ParseFuncPtr fn) { jit_operator_registry_[Symbol::fromQualString(op->schema().name())] .emplace_back(std::make_pair(op, fn)); } private: static void registerJitOperator() { // Register parse-function for each JIT operator; // This is a one-time look up, our hash registry indexes on the pointer in // OperatorRegistry. std::array BinaryOpWithAlpha = { "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; for (auto signature : BinaryOpWithAlpha) { auto ptr_op = getOperatorForLiteral(signature); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*); static std::unordered_map< Symbol, std::pair> op_mapping( {{aten::add, std::make_pair( BinaryOpType::Add, static_cast(&add_alpha))}, {aten::sub, std::make_pair( BinaryOpType::Sub, static_cast(&sub_alpha))}}); // TODO: handle scaling factor when it's not constant 1; auto lhs = value_map[node->inputs()[0]->unique()]; auto rhs = value_map[node->inputs()[1]->unique()]; auto alpha = value_map[node->inputs()[2]->unique()]; if (alpha->isOneInt()) { auto out = binaryOp(op_mapping[node->kind()].first, lhs, rhs); value_map.emplace(node->output()->unique(), out); } else { auto out = op_mapping[node->kind()].second(lhs, rhs, alpha); value_map.emplace(node->output()->unique(), out); } }); } std::array BinaryOp = { "aten::div(Tensor self, Tensor other) -> Tensor", "aten::div(Tensor self, Scalar other) -> Tensor", "aten::mul(Tensor self, Tensor other) -> Tensor", "aten::mul(Tensor self, Scalar other) -> Tensor", "aten::atan2(Tensor self, Tensor other) -> Tensor", "aten::max(Tensor self, Tensor other) -> Tensor", "aten::min(Tensor self, Tensor other) -> Tensor", "aten::pow(Tensor self, Tensor exponent) -> Tensor", "aten::pow(Tensor self, Scalar exponent) -> Tensor", "aten::pow(Scalar self, Tensor exponent) -> Tensor", "aten::remainder(Tensor self, Tensor other) -> Tensor", "aten::fmod(Tensor self, Tensor other) -> Tensor", "aten::eq(Tensor self, Tensor other) -> Tensor", "aten::eq(Tensor self, Scalar other) -> Tensor", "aten::ne(Tensor self, Tensor other) -> Tensor", "aten::ne(Tensor self, Scalar other) -> Tensor", "aten::ge(Tensor self, Tensor other) -> Tensor", "aten::ge(Tensor self, Scalar other) -> Tensor", "aten::gt(Tensor self, Tensor other) -> Tensor", "aten::gt(Tensor self, Scalar other) -> Tensor", "aten::le(Tensor self, Tensor other) -> Tensor", "aten::le(Tensor self, Scalar other) -> Tensor", "aten::lt(Tensor self, Tensor other) -> Tensor", "aten::lt(Tensor self, Scalar other) -> Tensor"}; for (auto signature : BinaryOp) { auto ptr_op = getOperatorForLiteral(signature); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { static std::unordered_map op_mapping( {{aten::div, BinaryOpType::Div}, {aten::mul, BinaryOpType::Mul}, {aten::add, BinaryOpType::Add}, {aten::sub, BinaryOpType::Sub}, {aten::atan2, BinaryOpType::Atan2}, {aten::min, BinaryOpType::Min}, {aten::max, BinaryOpType::Max}, {aten::pow, BinaryOpType::Pow}, {aten::remainder, BinaryOpType::Remainder}, {aten::fmod, BinaryOpType::Fmod}, {aten::lt, BinaryOpType::LT}, {aten::le, BinaryOpType::LE}, {aten::gt, BinaryOpType::GT}, {aten::ge, BinaryOpType::GE}, {aten::ne, BinaryOpType::NE}, {aten::eq, BinaryOpType::Eq}}); auto lhs = value_map[node->inputs()[0]->unique()]; auto rhs = value_map[node->inputs()[1]->unique()]; auto out = binaryOp(op_mapping[node->kind()], lhs, rhs); value_map.emplace(node->output()->unique(), out); }); } // TODO: cast operations should be merged in. std::array UnaryOp = { "aten::neg(Tensor self) -> Tensor", "aten::abs(Tensor self) -> Tensor", "aten::log(Tensor self) -> Tensor", "aten::log10(Tensor self) -> Tensor", "aten::log1p(Tensor self) -> Tensor", "aten::log2(Tensor self) -> Tensor", "aten::lgamma(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::expm1(Tensor self) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", "aten::cos(Tensor self) -> Tensor", "aten::acos(Tensor self) -> Tensor", "aten::cosh(Tensor self) -> Tensor", "aten::sin(Tensor self) -> Tensor", "aten::asin(Tensor self) -> Tensor", "aten::sinh(Tensor self) -> Tensor", "aten::tan(Tensor self) -> Tensor", "aten::tanh(Tensor self) -> Tensor", "aten::atan(Tensor self) -> Tensor", "aten::sqrt(Tensor self) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::ceil(Tensor self) -> Tensor", "aten::floor(Tensor self) -> Tensor", "aten::round(Tensor self) -> Tensor", "aten::trunc(Tensor self) -> Tensor", "aten::frac(Tensor self) -> Tensor", "aten::reciprocal(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::gelu(Tensor self) -> Tensor", }; for (auto signature : UnaryOp) { auto ptr_op = getOperatorForLiteral(signature); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { static std::unordered_map op_mapping({ {aten::neg, UnaryOpType::Neg}, {aten::abs, UnaryOpType::Abs}, {aten::log, UnaryOpType::Log}, {aten::log10, UnaryOpType::Log10}, {aten::log1p, UnaryOpType::Log1p}, {aten::log2, UnaryOpType::Log2}, {aten::lgamma, UnaryOpType::Lgamma}, {aten::exp, UnaryOpType::Exp}, {aten::expm1, UnaryOpType::Expm1}, {aten::erf, UnaryOpType::Erf}, {aten::erfc, UnaryOpType::Erfc}, {aten::cos, UnaryOpType::Cos}, {aten::acos, UnaryOpType::Acos}, {aten::cosh, UnaryOpType::Cosh}, {aten::sin, UnaryOpType::Sin}, {aten::asin, UnaryOpType::Asin}, {aten::sinh, UnaryOpType::Sinh}, {aten::tan, UnaryOpType::Tan}, {aten::tanh, UnaryOpType::Tanh}, {aten::atan, UnaryOpType::Atan}, {aten::sqrt, UnaryOpType::Sqrt}, {aten::rsqrt, UnaryOpType::Rsqrt}, {aten::ceil, UnaryOpType::Ceil}, {aten::floor, UnaryOpType::Floor}, {aten::round, UnaryOpType::Round}, {aten::trunc, UnaryOpType::Trunc}, {aten::frac, UnaryOpType::Frac}, {aten::reciprocal, UnaryOpType::Reciprocal}, {aten::relu, UnaryOpType::Relu}, {aten::sigmoid, UnaryOpType::Sigmoid}, {aten::gelu, UnaryOpType::Gelu}, }); auto operand = value_map[node->input()->unique()]; auto out = unaryOp(op_mapping[node->kind()], operand); value_map.emplace(node->output()->unique(), out); }); } { auto ptr_op = getOperatorForLiteral( "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto operand = value_map[node->inputs()[0]->unique()]; auto out = unaryOp(UnaryOpType::RandLike, operand); value_map.emplace(node->output()->unique(), out); }); } { auto ptr_op = getOperatorForLiteral( "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor"); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto operand = value_map[node->inputs()[0]->unique()]; auto th = value_map[node->inputs()[1]->unique()]; auto value = value_map[node->inputs()[2]->unique()]; auto out = threshold(operand, th, value); value_map.emplace(node->output()->unique(), out); }); } { auto ptr_op = getOperatorForLiteral( "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor"); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto operand = value_map[node->inputs()[0]->unique()]; // TODO: we need to get a proper lower bound per dtype in operand. auto low = value_map.count(node->inputs()[1]->unique()) != 0 ? value_map[node->inputs()[1]->unique()] : new Float(std::numeric_limits::min()); auto high = value_map.count(node->inputs()[2]->unique()) != 0 ? value_map[node->inputs()[2]->unique()] : new Float(std::numeric_limits::max()); auto out = clamp(operand, low, high); value_map.emplace(node->output()->unique(), out); }); } { auto ptr_op = getOperatorForLiteral( "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor"); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto condition = value_map[node->inputs()[0]->unique()]; auto x = value_map[node->inputs()[1]->unique()]; auto y = value_map[node->inputs()[2]->unique()]; auto out = where(condition, x, y); value_map.emplace(node->output()->unique(), out); }); } { std::array LerpOp = { "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor", "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"}; for (auto signature : LerpOp) { auto ptr_op = getOperatorForLiteral(signature); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto self = value_map[node->inputs()[0]->unique()]; auto end = value_map[node->inputs()[1]->unique()]; auto weight = value_map[node->inputs()[2]->unique()]; auto out = lerp(self, end, weight); value_map.emplace(node->output()->unique(), out); }); } } { auto ptr_op = getOperatorForLiteral( "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); registerParseRule( ptr_op, [](const Node* const node, std::unordered_map& value_map) -> void { auto self = value_map[node->inputs()[0]->unique()]; auto tensor1 = value_map[node->inputs()[1]->unique()]; auto tensor2 = value_map[node->inputs()[2]->unique()]; auto value = value_map[node->inputs()[3]->unique()]; auto out = addcmul(self, tensor1, tensor2, value); value_map.emplace(node->output()->unique(), out); }); } } void processJitNode(const JitOp* node) { if (node->kind() == prim::Constant) { // partition doesn't take constant node explicitly, but it does and copy // constant into subgraph. So we need to register constants in codegen IR; for (auto output : node->outputs()) { TORCH_CHECK(registerScalar(output)); } } else { auto iter = IrParser::jit_operator_registry_.find(node->kind()); // make sure we have a parser for the op; TORCH_CHECK( iter != IrParser::jit_operator_registry_.end(), "CudaFusionGroup Parser doesn't handle operator kind(): ", node->kind().toDisplayString()); for (auto& pair_op_func : iter->second) { if (node->matches(pair_op_func.first->schema())) { pair_op_func.second(node, value_map_); return; } } TORCH_CHECK( false, "CudaFusionGroup Parser doesn't recognize operator overload:", canonicalSchemaString(node->schema())); } } bool registerValue(const JitValue* val, int broadcast_dim = -1) { return registerTensor(val, broadcast_dim) || registerScalar(val); } bool registerScalar(const JitValue* val) { if (val->type()->isSubtypeOf(static_cast(FloatType::get()))) { CgValue cg_val; if (auto ival = constant_as(val)) { cg_val = new Float(ival.value()); } else { cg_val = new Float(); } value_map_.emplace(val->unique(), cg_val); return true; } else if (val->type()->isSubtypeOf( static_cast(IntType::get()))) { CgValue cg_val; if (auto ival = constant_as(val)) { cg_val = new Int(ival.value()); } else { cg_val = new Int(); } value_map_.emplace(val->unique(), cg_val); return true; } else if (val->type()->isSubtypeOf( static_cast(NoneType::get()))) { // TODO: should we consider adding support for NoneType; return true; } return false; } bool registerTensor(const JitValue* val, int broadcast_dim = -1) { CgValue cg_val; if (auto tensor_type = val->type()->cast()) { // TODO: make this a static function in Tensor class; // create tensor; if (broadcast_dim >= 0) { tensor_type = tensor_type->withDim(broadcast_dim); } // TODO: make this a static function in Tensor class; // create tensor; cg_val = new TensorView(tensor_type); value_map_.emplace(val->unique(), cg_val); return true; } return false; } std::shared_ptr graph_; CudaKernel* cuda_kernel_; // maps from JitValue::unique() to fusion Val; std::unordered_map value_map_; // parsing rule registry. static std::unordered_map< Symbol, std::vector, ParseFuncPtr>>> jit_operator_registry_; static bool init_registry_; }; std::unordered_map< Symbol, std::vector, ParseFuncPtr>>> IrParser::jit_operator_registry_; bool IrParser::init_registry_ = true; } // namespace bool isNodeParsible(const Node* const node) { return IrParser::canParseNode(node); } void parseJitIR(std::shared_ptr& graph, CudaKernel* cuda_kernel) { IrParser parser(graph, cuda_kernel); parser.parse(); } } // namespace cuda } // namespace fuser } // namespace jit } // namespace torch