Summary:
Things changed in this PR that require review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf to the registry
4. torch/jit/_script.py : throws when scripting a model that uses autocast as a decorator, since that is not supported (see the sketch below)
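A minimal sketch of item 4 (illustrative only; the function name is made up and the exact error text is not reproduced here): scripting a function that applies autocast as a decorator is expected to raise rather than be silently accepted.

import torch

@torch.cuda.amp.autocast()
def fused_mm(x, y):
    return torch.mm(x, y)

try:
    # expected to throw: autocast used as a decorator is not supported in scripting
    torch.jit.script(fused_mm)
except RuntimeError as e:
    print("scripting rejected autocast-as-decorator:", e)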
nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar CPU tensor promotion to support inter-device operations between a CPU scalar tensor and a CUDA tensor (see the sketch below)
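A rough illustration of item 4 (names and shapes here are arbitrary): a zero-dim CPU tensor participating in an operation with a CUDA tensor inside a scripted function that nvfuser may fuse; the CpuScalarTensor<> kernel parameter emitted by genDeclaration() in the file below corresponds to this case.

import torch

cpu_scalar = torch.tensor(2.0)                 # 0-dim tensor resident on the CPU
cuda_input = torch.randn(8, 8, device="cuda")

def scale(x, s):
    # inter-device operation: CUDA tensor * CPU scalar tensor
    return x * s

scripted = torch.jit.script(scale)
out = scripted(cuda_input, cpu_scalar)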
Things reverted from local changes:
aten::gelu with approximation (tracked in PR: https://github.com/pytorch/pytorch/pull/61439)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72127
Reviewed By: HamidShojanazeri
Differential Revision: D34113233
Pulled By: jbschlosser
fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e)
1357 lines · 46 KiB · C++
#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <array>
#include <cmath>
#include <sstream>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace codegen {

namespace {

class CudaKernelGenerator : private OptOutConstDispatch {
  static constexpr const char* kTab = "  ";

 public:
  static std::string generateKernelDefinition(
      const kir::Kernel* kernel,
      const std::string& kernel_name) {
    CudaKernelGenerator codegen(kernel);
    codegen.genDeclaration(kernel_name);
    codegen.startBlock();
    codegen.genPrologue();
    codegen.genBody();
    codegen.endBlock();
    TORCH_CHECK(codegen.block_nest_level_ == 0);
    return codegen.code_.str();
  }

 private:
  explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {}

  // Generates the kernel function declaration
  void genDeclaration(const std::string& kernel_name) {
    const auto& kernel_summary = kernel_->summary();

    code_ << "__global__ void " << kernel_name << "(";

    std::vector<Val*> params;

    // Inputs & Outputs
    for (auto val : kernel_->inputs()) {
      params.push_back(val);
    }
    for (auto val : kernel_->outputs()) {
      params.push_back(val);
    }

    // Generate parameter declarations
    for (Val* val : params) {
      if (const auto tv = dynamic_cast<TensorView*>(val)) {
        if (tv->isCpuScalar()) {
          code_ << " CpuScalarTensor<" << val->dtype() << "> " << varName(tv);
        } else {
          code_
              << "Tensor<" << val->dtype() << ", "
              << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size()
              << "> " << varName(tv);
        }
      } else {
        TORCH_INTERNAL_ASSERT(val->isScalar()); // NOLINT (LLVM bug 48525)
        TORCH_INTERNAL_ASSERT(val->definition() == nullptr);
        code_ << val->dtype() << " " << gen(val);
      }

      if (val != params.back()) {
        code_ << ", ";
      }
    }

    // Global buffers
    for (auto allocate : kernel_summary.global_allocations) {
      TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<TensorView>());
      const auto tv = allocate->buffer()->as<TensorView>();
      const auto& maybe_rfactor_domain = tv->domain()->hasRFactor()
          ? tv->domain()->getRFactorDomain()
          : tv->domain()->getRootDomain();
      const auto nDims = std::count_if(
          maybe_rfactor_domain.begin(),
          maybe_rfactor_domain.end(),
          [](const IterDomain* id) {
            return !id->isReduction() &&
                id->getIterType() != IterType::BroadcastWithoutStride;
          });
      code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> "
            << varName(tv);
    }

    // Kernels generating random numbers take extra (seed, offset) arguments
    if (kernel_summary.is_stochastic) {
      code_ << ", at::PhiloxCudaState philox_args";
    }

    code_ << ") ";
  }

  // Generates setup code which is executed before the kernel body
  void genPrologue() {
    const auto& kernel_summary = kernel_->summary();

    // Random number generator (optional)
    if (kernel_summary.is_stochastic) {
      indent()
          << "const auto idx = ((((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * blockDim.z + threadIdx.z) * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;";
      indent() << "auto offset = philox_args.captured_ ?\n";
      indent()
          << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
      indent() << " philox_args.offset_.val;\n";
      indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
    }

    // Do we have any dynamic shared memory buffers?
    const bool has_dynamic_smem =
        !kernel_summary.dynamic_smem_allocations.empty();

    // Do we have any reductions?
    const bool has_reductions = kernel_summary.has_block_reductions ||
        kernel_summary.has_grid_reductions;
    const bool has_parallel_welford =
        kernel_summary.has_block_welford || kernel_summary.has_grid_welford;

    // Shared memory
    if (has_dynamic_smem || has_reductions || has_parallel_welford) {
      indent() << "alignas("
#ifndef __HIP_PLATFORM_HCC__
               << dataTypeSize(kernel_summary.largest_smem_data_type)
#else
               << 8 // for HIP, we want 8-aligned even for smaller datatypes
#endif
               << ") extern __shared__ char array[];\n";

      if (has_dynamic_smem) {
        indent() << "unsigned offset = 0;\n";
      }

      if (has_reductions || has_parallel_welford) {
        indent() << "void* shared_mem = array;\n";
        if (has_dynamic_smem) {
          if (has_parallel_welford) {
            indent() << "offset += "
                     << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
                     << kernel_summary.largest_smem_data_type << "));\n";
          } else {
            indent() << "offset += "
                     << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
                     << kernel_summary.largest_smem_data_type << "));\n";
          }
        }

        if (has_parallel_welford) {
          // Unpack shared mem pointer
          auto space_type = kernel_summary.largest_smem_data_type;
          indent()
              << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n";
          indent() << space_type << " *shared_mem_var = "
                   << "static_cast<" << space_type << "*>("
                   << "shared_mem);\n";
          indent() << space_type
                   << " *shared_mem_avg = shared_mem_var + block_size;\n";
          indent() << space_type
                   << " *shared_mem_n = shared_mem_avg + block_size;\n";
        }
      }
    }

    // Call the initialization function if using a custom block sync
    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
      indent() << "block_sync::init();\n";
    }
  }

  void genBody() {
    for (auto expr : kernel_->topLevelExprs()) {
      OptOutConstDispatch::handle(expr);
    }
  }

  void startBlock(bool continuation = false) {
    if (continuation) {
      code_ << "{\n";
    } else {
      indent() << "{\n";
    }
    ++block_nest_level_;
  }

  void endBlock(const char* sep = "\n") {
    --block_nest_level_;
    TORCH_CHECK(block_nest_level_ >= 0);
    indent() << "}" << sep;
  }

  std::ostream& indent() {
    for (const auto i : c10::irange(block_nest_level_)) {
      (void)i; // Suppress unused variable warning
      code_ << kTab;
    }
    return code_;
  }

  std::string gen(const Statement* stmt) {
    std::stringstream tmp_code;
    std::swap(tmp_code, code_);
    auto replacement = replacement_map_.find(stmt);
    if (replacement != replacement_map_.end()) {
      stmt = replacement->second;
    }
    OptOutConstDispatch::handle(stmt);
    std::swap(tmp_code, code_);
    return tmp_code.str();
  }

  std::string varName(const Val* val) {
    std::stringstream name;
    if (val->isA<TensorView>()) {
      name << "T";
    } else {
      name << typePrefix(val->dtype());
    }
    name << val->name();
    return name.str();
  }

  std::string genInline(const Statement* stmt) {
    const bool saved_inline = print_inline_;
    print_inline_ = true;
    auto result = gen(stmt);
    print_inline_ = saved_inline;
    // NOLINTNEXTLINE(performance-no-automatic-move)
    return result;
  }

  void handle(const kir::Predicate* pred) final {
    TORCH_INTERNAL_ASSERT(pred->hasValue());
    code_ << gen(pred->value());
  }

  void handle(const Bool* pred) final {
    const auto def = pred->definition();
    if (print_inline_ && def != nullptr) {
      code_ << "(" << gen(def) << ")";
    } else if (pred->isConst()) {
      code_ << (*pred->value() ? "true" : "false");
    } else {
      code_ << varName(pred);
    }
  }

  void handle(const Double* d) final {
    const auto def = d->definition();
    if (print_inline_ && def != nullptr) {
      code_ << "(" << gen(def) << ")";
    } else if (d->isConst()) {
      const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
      code_ << std::setprecision(digits) << *d->value();
    } else {
      code_ << varName(d);
    }
  }

  void handle(const Int* i) final {
    const auto def = i->definition();
    if (print_inline_ && def != nullptr) {
      code_ << "(" << gen(def) << ")";
    } else if (i->isConst()) {
      code_ << *i->value();
    } else {
      code_ << varName(i);
    }
  }

  void handle(const NamedScalar* ns) final {
    // dim3 components are unsigned int. Cast to signed integer to
    // support negative indexing
    if (ns->getParallelIndex().has_value() ||
        ns->getParallelDim().has_value()) {
      code_ << "((nvfuser_index_t)" << ns->name() << ")";
    } else {
      code_ << ns->name();
    }
  }

  void handle(const kir::TensorIndex* ti) final {
    code_ << varName(ti->view()) << "[";

    bool first = true;
    for (auto* ind : ti->indices()) {
      if (!ind->isZeroInt()) {
        if (!first) {
          code_ << " + ";
        }
        code_ << genInline(ind);
        first = false;
      }
    }

    if (first) {
      code_ << "0";
    }

    code_ << "]";
  }

  void handle(const IterDomain*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  void handle(const TensorDomain*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  void handle(const TensorView*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  void handle(const UnaryOp* uop) final {
    bool is_vector_op = false;
    size_t vector_word_size = 1;

    if (vectorize_scope_ && uop->out()->isA<kir::TensorIndex>()) {
      auto ti = uop->out()->as<kir::TensorIndex>();

      bool vectorize_op = false;
      bool misaligned_op = false;

      for (auto id : ti->view()->domain()->domain()) {
        if (!isParallelTypeVectorize(id->getParallelType())) {
          continue;
        }

        ExpressionEvaluator expr_eval(id->fusion());
        auto vector_size_optional = expr_eval.evaluate(id->extent());

        TORCH_INTERNAL_ASSERT(
            vector_size_optional.has_value(),
            "Could not evaluate constant value bound to vectorized dim.");

        vector_word_size = vector_size_optional.value();

        vectorize_op = id->getParallelType() == ParallelType::Vectorize;
        misaligned_op =
            id->getParallelType() == ParallelType::MisalignedVectorize;
        break;
      }

      if (vectorize_op) {
        TORCH_INTERNAL_ASSERT(
            uop->getUnaryOpType() == UnaryOpType::Set,
            "Cannot vectorize operations that are not sets. ",
            "Use cache_before and cache_after to store/load with vectorized reads into buffers.");
        is_vector_op = true;
      }

      if (misaligned_op) {
        is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set);
      }

      if (is_vector_op && !uop->in()->isScalar()) {
        TORCH_INTERNAL_ASSERT(
            uop->out()->dtype() == uop->in()->dtype(),
            "Vectorized store/load requires input and output datatypes match.");
      }
    }

    if (is_vector_op) {
      if (uop->in()->isScalar()) {
        indent() << "reinterpret_cast<"
                 << "Array<" << uop->out()->dtype() << ", " << vector_word_size
                 << ">*>"
                 << "(&" << gen(uop->out()) << ")->set(" << gen(uop->in())
                 << ");\n";
      } else {
        indent() << "*reinterpret_cast<"
                 << "Array<" << uop->out()->dtype() << ", " << vector_word_size
                 << ">*>"
                 << "(&" << gen(uop->out()) << ")"
                 << " = *reinterpret_cast<"
                 << "Array<" << uop->in()->dtype() << ", " << vector_word_size
                 << ">*>"
                 << "(&" << gen(uop->in()) << ");\n";
      }
      return;
    }

    if (uop->out()->isA<NamedScalar>()) {
      const auto op_type = uop->getUnaryOpType();
      if (auto op = inline_op_str(op_type)) {
        indent() << gen(uop->out()) << " = " << *op << genInline(uop->in())
                 << ";\n";
      }
      return;
    }

    if (!print_inline_) {
      indent() << gen(uop->out());
      if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
        code_ << "\n";
        indent() << kTab;
      }
      code_ << " = ";
    }

    const auto op_type = uop->getUnaryOpType();
    if (auto op = inline_op_str(op_type)) {
      if (alsoBooleanOperator(op_type) &&
          uop->out()->dtype() == DataType::Bool) {
        code_ << stringifyBooleanOp(op_type) << gen(uop->in());
      } else {
        code_ << *op << gen(uop->in());
      }
    } else {
      if (op_type == UnaryOpType::Cast) {
        const auto cast_str =
            cast_func_str({uop->in()->dtype(), uop->out()->dtype()});
        TORCH_INTERNAL_ASSERT(
            cast_str.has_value(),
            "Invalid cast. Input type: ",
            uop->in()->dtype(),
            ", output type: ",
            uop->out()->dtype());
        code_ << cast_str.value();
      } else {
        code_ << op_type;
        if (needFloatSuffix(op_type) &&
            uop->out()->dtype() == DataType::Float) {
          code_ << "f";
        }
      }

      code_ << "(";
      if (op_type == UnaryOpType::RandLike) {
        code_ << "rnd";
      } else {
        code_ << gen(uop->in());
      }
      code_ << ")";
    }

    if (!print_inline_) {
      code_ << ";\n";
    }
  }

  std::string genBinaryOp(
      BinaryOpType op_type,
      Val* out,
      const std::string& lhs,
      const std::string& rhs) {
    std::stringstream expr;
    if (auto op = inline_op_str(op_type)) {
      expr << lhs << " ";
      if (alsoBooleanOperator(op_type) && out->dtype() == DataType::Bool) {
        expr << stringifyBooleanOp(op_type);
      } else {
        expr << *op;
      }
      expr << " " << rhs;
    } else {
      if (integer_op_str(op_type) && isIntegralType(out->dtype())) {
        auto int_op = integer_op_str(op_type);
        expr << *int_op;
      } else {
        expr << op_type;
        if (needFloatSuffix(op_type) && out->dtype() == DataType::Float) {
          expr << "f";
        }
      }
      expr << "(" << lhs << ", " << rhs << ")";
    }
    return expr.str();
  }

  // If one argument is a tensorview and the other is a scalar, make sure we
  // cast the scalar to the tensorview type
  std::string scalarCast(Val* lhs, Val* rhs) {
    // If neither are scalars return
    if (!((lhs->isScalar() || rhs->isScalar()) &&
          (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) {
      return "";
    }

    // Looking for mixed tensorview scalar options where types don't match
    // but are either both floating or both int types. We should cast
    // scalar to tensorview type in these instances.
    auto lhs_t = lhs->dtype();
    auto rhs_t = rhs->dtype();

    // If same type, don't cast anything
    if (lhs_t == rhs_t) {
      return "";
    }

    // Don't do anything when dealing with bools
    if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) {
      return "";
    }

    // Mixing floating and int combination
    if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) ||
        (isIntegralType(lhs_t) != isIntegralType(rhs_t))) {
      return "";
    }

    std::stringstream cast;
    cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") ";
    return cast.str();
  }

  // If possible, replace pow with mul. Return true when successful.
  bool genPowerWithMul(const BinaryOp* bop) {
    if (bop->getBinaryOpType() != BinaryOpType::Pow) {
      return false;
    }

    auto rhs = bop->rhs();
    c10::optional<double> exponent;
    if (auto val_int = dynamic_cast<Int*>(rhs)) {
      if (val_int->isConst()) {
        exponent = val_int->value().value();
      }
    } else if (auto val_float = dynamic_cast<Double*>(rhs)) {
      if (val_float->isConst()) {
        auto fp_exp = val_float->value().value();
        double int_exp = 0;
        if (std::modf(fp_exp, &int_exp) == 0) {
          exponent = int_exp;
        }
      }
    }

    if (!exponent.has_value()) {
      return false;
    }

    // Only **2 and **3 are considered
    if (!(exponent.value() == 2 || exponent.value() == 3)) {
      return false;
    }

    auto lhs = gen(bop->lhs());

    if (print_inline_) {
      code_ << lhs << " * " << lhs;
      if (exponent.value() == 3) {
        code_ << " * " << lhs;
      }
    } else {
      indent() << gen(bop->out());
      if (bop->out()->isScalar()) {
        code_ << " = " << lhs << " * " << lhs;
        if (exponent.value() == 3) {
          code_ << " * " << lhs;
        }
      } else {
        code_ << "\n";
        indent() << kTab << "= " << lhs << "\n";
        indent() << kTab << "* " << lhs;
        if (exponent.value() == 3) {
          code_ << "\n";
          indent() << kTab << "* " << lhs;
        }
      }
    }

    code_ << ";\n";
    return true;
  }

  void handle(const BinaryOp* bop) final {
    // Try replacing pow with mul
    if (genPowerWithMul(bop)) {
      return;
    }

    const auto op_type = bop->getBinaryOpType();
    if (print_inline_) {
      // Inline expression: `lhs op rhs`
      code_ << genBinaryOp(
          op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs()));
    } else {
      indent() << gen(bop->out());
      if (bop->out()->isScalar()) {
        // Single line: `out = lhs op rhs;`
        code_ << " = "
              << genBinaryOp(
                     op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs()));
      } else {
        // Split TensorView expressions across multiple lines:
        //
        // out
        //   = lhs
        //   op rhs;
        //

        auto cast = scalarCast(bop->lhs(), bop->rhs());
        if (auto op = inline_op_str(op_type)) {
          code_ << "\n";
          indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "")
                   << gen(bop->lhs()) << "\n";
          indent() << kTab;
          if (alsoBooleanOperator(op_type) &&
              bop->out()->dtype() == DataType::Bool) {
            code_ << stringifyBooleanOp(op_type);
          } else {
            code_ << *op;
          }
          code_ << " " << (bop->rhs()->isScalar() ? cast : "")
                << gen(bop->rhs());
        } else {
          if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) {
            auto int_op = integer_op_str(op_type);
            code_ << " = " << *int_op << "(\n";
          } else {
            std::stringstream op_str;
            op_str << op_type;
            if (needFloatSuffix(op_type) &&
                bop->out()->dtype() == DataType::Float) {
              op_str << "f";
            }
            code_ << " = " << op_str.str() << "(\n";
          }
          indent() << kTab << (bop->lhs()->isScalar() ? cast : "")
                   << gen(bop->lhs()) << ",\n";
          indent() << kTab << (bop->rhs()->isScalar() ? cast : "")
                   << gen(bop->rhs()) << ")";
        }
      }
      code_ << ";\n";
    }
  }

  void handle(const TernaryOp* top) final {
    if (!print_inline_) {
      indent() << gen(top->out());
      if (!top->out()->isScalar()) {
        code_ << "\n";
        indent() << kTab;
      }
      code_ << " = ";
    }

    code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", ";

    // Make sure the two operands of where has the same
    // type. Note that compiling "where(0.0f, 0.0)" fails because of
    // the overloading ambiguity.
    if (top->getTernaryOpType() == TernaryOpType::Where) {
      auto cast = scalarCast(top->in2(), top->in3());
      code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", "
            << (top->in3()->isScalar() ? cast : "") << gen(top->in3()) << ")";
    } else {
      code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")";
    }

    if (!print_inline_) {
      code_ << ";\n";
    }
  }

  std::string genReductionOp(BinaryOpType op_type, Val* out) {
    std::stringstream lambda;
    DataType data_type = out->dtype();
    lambda << "[](" << data_type << " &a, " << data_type << " b) "
           << "{ a = " << genBinaryOp(op_type, out, "a", "b") << "; }";
    return lambda.str();
  }

  void handle(const BroadcastOp* stmt) final {
    TORCH_INTERNAL_ASSERT(stmt->out()->isA<kir::TensorIndex>());
    const auto tensor_index = stmt->out()->as<kir::TensorIndex>();

    const ParallelTypeBitmap parallel_types =
        kernel_->summary().broadcast_parallel_types.at(stmt);

    if (parallel_types.none()) {
      // Not parallelized
      indent() << gen(stmt->out()) << "\n";
      indent() << kTab << " = " << gen(stmt->in()) << ";\n";
      return;
    }

    TORCH_INTERNAL_ASSERT(
        !parallel_types.hasBID(),
        "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node");

    std::stringstream flags_str;
    for (const ParallelType pt : kParallelTypeTIDs) {
      const bool parallel_bcast = parallel_types.get(pt);
      if (pt != kParallelTypeTIDs[0]) {
        flags_str << ", ";
      }
      flags_str << (parallel_bcast ? "true" : "false");
    }

    const auto data_type = stmt->out()->dtype();
    indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n";
    indent() << kTab << gen(stmt->out()) << ",\n";
    indent() << kTab << gen(stmt->in()) << ",\n";
    indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(
        stmt->predicate() != nullptr && stmt->predicate()->hasValue());
    indent() << kTab << genInline(stmt->predicate()) << ");\n";
  }

  void genWarpReductionOp(
      const ReductionOp* rop,
      const IterDomain* reduction_id) {
    bool is_single_warp =
        kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp;

    indent() << "warp::warpReduceTIDX";
    if (is_single_warp) {
      code_ << "<true>(\n";
    } else {
      code_ << "<false>(\n";
    }
    indent() << kTab << gen(rop->out()) << ",\n";
    indent() << kTab << gen(rop->in()) << ",\n";
    indent() << kTab << genReductionOp(rop->getReductionOpType(), rop->out())
             << ",\n";
    indent() << kTab << "threadIdx,\n";
    indent() << kTab << "blockDim,\n";
    indent() << kTab << "static_cast<" << rop->out()->dtype()
             << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(
        rop->predicate() != nullptr && rop->predicate()->hasValue());
    indent() << kTab << genInline(rop->predicate()) << ",\n";
    indent() << kTab << rop->out()->dtype() << "(" << genInline(rop->init())
             << "));\n";
  }

  void handle(const ReductionOp* rop) final {
    TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());

    const auto out = rop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();

    const bool has_block_reduce = domain->hasBlockReduction();
    const bool has_grid_reduce = domain->hasGridReduction();

    if (!has_block_reduce && !has_grid_reduce) {
      const auto gen_out = gen(out);
      const auto op_type = rop->getReductionOpType();
      indent() << gen_out << " = "
               << genBinaryOp(op_type, out, gen_out, gen(rop->in())) << ";\n";
      return;
    }

    if (auto reduction_id = ir_utils::getMaybeWarpReductionDim(rop)) {
      genWarpReductionOp(rop, reduction_id.value());
      return;
    }

    const auto par_domains = ir_utils::getParallelDomains(rop->out());
    // Get parallel reduction domains
    const bool tidx =
        par_domains.find(ParallelType::TIDx) != par_domains.end() &&
        par_domains.at(ParallelType::TIDx)->isReduction();
    const bool tidy =
        par_domains.find(ParallelType::TIDy) != par_domains.end() &&
        par_domains.at(ParallelType::TIDy)->isReduction();
    const bool tidz =
        par_domains.find(ParallelType::TIDz) != par_domains.end() &&
        par_domains.at(ParallelType::TIDz)->isReduction();

    const auto data_type = rop->out()->dtype();
    const auto op_type = rop->getReductionOpType();

    if (has_block_reduce) {
      if (has_grid_reduce) {
        indent() << data_type << " "
                 << "block_result_" << block_reduce_name_ << "="
                 << gen(rop->init()) << ";\n";
      }
      indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
               << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
               << ">(\n";
      if (has_grid_reduce) {
        indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
      } else {
        indent() << kTab << gen(rop->out()) << ",\n";
      }
      indent() << kTab << gen(rop->in()) << ",\n";
      indent() << kTab << genReductionOp(op_type, rop->out()) << ",\n";
      indent() << kTab << "threadIdx,\n";
      indent() << kTab << "blockDim,\n";
      indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
      TORCH_INTERNAL_ASSERT(
          rop->predicate() != nullptr && rop->predicate()->hasValue());
      auto read_pred = genInline(rop->predicate());
      indent() << kTab << read_pred << ",\n";
      // Pass the write predicate if available and different from the
      // default predicate. The blockReduce runtime function uses the
      // default predicate for both read and write when only the
      // default one is given.
      if (rop->writePredicate() != nullptr) {
        TORCH_INTERNAL_ASSERT(rop->writePredicate()->hasValue());
        auto write_pred = genInline(rop->writePredicate());
        indent() << kTab << write_pred << ",\n";
      }
      indent() << kTab << data_type << "(" << genInline(rop->init()) << "));\n";
    }
  }

  void handle(const WelfordOp* wop) final {
    TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());

    const auto out = wop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();

    const auto out_var = wop->outVar();
    const auto out_avg = wop->outAvg();
    const auto out_N = wop->outN();

    const auto in_var = wop->inVar();
    const auto in_avg = wop->inAvg();
    const auto in_N = wop->inN();

    const bool has_block_reduce = domain->hasBlockReduction();
    const bool has_grid_reduce = domain->hasGridReduction();

    // Serial WelfordOp generation
    if (!has_block_reduce && !has_grid_reduce) {
      indent() << "welfordCombine ("
               << "\n";
      indent() << " " << gen(out_avg) << ",\n";
      indent() << " " << gen(out_var) << ",\n";
      indent() << " " << gen(out_N) << ",\n";
      indent() << " " << gen(in_avg) << ",\n";
      if (in_var) {
        indent() << " " << gen(in_var) << ",\n";
      } else {
        indent() << " (" << in_avg->dtype() << ") 0"
                 << ",\n";
      }
      indent() << " (" << out_N->dtype() << ")" << gen(in_N) << ");\n";
      return;
    }

    const auto par_domains = ir_utils::getParallelDomains(wop->out());
    // Get parallel reduction domains
    const bool tidx =
        par_domains.find(ParallelType::TIDx) != par_domains.end() &&
        par_domains.at(ParallelType::TIDx)->isReduction();
    const bool tidy =
        par_domains.find(ParallelType::TIDy) != par_domains.end() &&
        par_domains.at(ParallelType::TIDy)->isReduction();
    const bool tidz =
        par_domains.find(ParallelType::TIDz) != par_domains.end() &&
        par_domains.at(ParallelType::TIDz)->isReduction();

    const auto data_type = wop->out()->dtype();

    if (has_block_reduce) {
      if (has_grid_reduce) {
        // allocate block result
        indent() << data_type << " "
                 << "block_result_avg_" << block_reduce_name_ << " = "
                 << gen(wop->initAvg()) << ";\n";
        indent() << data_type << " "
                 << "block_result_var_" << block_reduce_name_ << " = "
                 << gen(wop->initVar()) << ";\n";
        indent() << DataType::Int << " "
                 << "block_result_n_" << block_reduce_name_ << " = "
                 << gen(wop->initN()) << ";\n";
      }
      indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
               << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
               << ">(\n";
      if (has_grid_reduce) {
        indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
                 << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
                 << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
      } else {
        indent() << kTab << gen(wop->outAvg()) << ",\n";
        indent() << kTab << gen(wop->outVar()) << ",\n";
        indent() << kTab << gen(wop->outN()) << ",\n";
      }
      indent() << " " << gen(in_avg) << ",\n";
      if (in_var) {
        indent() << " " << gen(in_var) << ",\n";
      } else {
        indent() << " (" << in_avg->dtype() << ") 0"
                 << ",\n";
      }
      indent() << out_N->dtype() << "(" << gen(in_N) << "),\n";
      indent() << kTab << "threadIdx,\n";
      indent() << kTab << "blockDim,\n";
      indent() << kTab << "reinterpret_cast<" << data_type
               << "*>(shared_mem_avg),\n";
      indent() << kTab << "reinterpret_cast<" << data_type
               << "*>(shared_mem_var),\n";
      indent() << kTab << "reinterpret_cast<" << DataType::Int
               << "*>(shared_mem_n),\n";
      TORCH_INTERNAL_ASSERT(wop->predicate() != nullptr);
      TORCH_INTERNAL_ASSERT(
          wop->predicate() != nullptr && wop->predicate()->hasValue());
      auto read_pred = genInline(wop->predicate());
      indent() << kTab << read_pred << ",\n";
      if (wop->writePredicate() != nullptr) {
        TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue());
        auto write_pred = genInline(wop->writePredicate());
        indent() << kTab << write_pred << ",\n";
      }
      indent() << kTab << data_type << "(0));\n";
    }
  }

  // Support ReductionOp and WelfordOp
  template <typename REDUCTION_OP>
  std::string generateGridReduceTemplateFlags(
      const REDUCTION_OP* rop,
      const ParallelTypeBitmap& thread_pred) {
    const auto par_domains = ir_utils::getParallelDomains(rop->outputs()[0]);
    std::stringstream flags;
    for (const ParallelType pt : kParallelTypeThreads) {
      const bool parallel_reduction =
          par_domains.find(pt) != par_domains.end() &&
          par_domains.at(pt)->isReduction();
      const bool pred = thread_pred.get(pt);
      TORCH_INTERNAL_ASSERT(
          !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
      bool flag = false;
      // Currently assumed that no dimensions parallelized with blocks
      // are predicated. This assumption may be lifted, but
      // gridReduction would need some changes.
      if (isParallelTypeBlockDim(pt)) {
        TORCH_INTERNAL_ASSERT(
            !pred, "Predication on block dimensions not allowed: ", pt);
        flag = parallel_reduction;
      } else {
        flag = !pred && !parallel_reduction;
      }
      if (pt != kParallelTypeThreads[0]) {
        flags << ", ";
      }
      flags << (flag ? "true" : "false");
    }
    return flags.str();
  }

  void handle(const kir::GridReduction* grop) final {
    const auto rop = grop->reduction_op();
    TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());

    const auto out = rop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();
    TORCH_INTERNAL_ASSERT(domain->hasGridReduction());

    const auto data_type = rop->out()->dtype();
    const auto op_type = rop->getReductionOpType();

    TORCH_INTERNAL_ASSERT(
        grop->reduction_buffer()->buffer()->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
    const auto work_buffer =
        grop->reduction_buffer()->buffer()->as<TensorView>();
    const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();

    const std::string flags_str =
        generateGridReduceTemplateFlags(rop, grop->threadPredicate());

    const bool persistent_sync =
        kernel_->summary().has_cooperative_grid_reduction;

    // Since block-level reduction is already done, those dimensions
    // with tidx/y/z being true do not participate in the grid reduction.
    indent() << "reduction::gridReduce<" << flags_str << ", "
             << (persistent_sync ? "true" : "false") << ">(\n";
    indent() << kTab << gen(rop->out()) << ",\n";
    if (domain->hasBlockReduction()) {
      indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
      block_reduce_name_++;
    } else {
      indent() << kTab << gen(rop->in()) << ",\n";
    }
    indent() << kTab << genReductionOp(op_type, out) << ",\n";
    indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
    indent() << kTab << varName(sync_buffer) << ",\n";
    indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(
        grop->predicate() != nullptr && grop->predicate()->hasValue());
    auto read_pred = genInline(grop->predicate());
    indent() << kTab << read_pred << ",\n";
    if (grop->writePredicate() != nullptr) {
      TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue());
      auto write_pred = genInline(grop->writePredicate());
      indent() << kTab << write_pred << ",\n";
    } else {
      indent() << kTab << read_pred << ",\n";
    }
    indent() << kTab << data_type << "("
             << genInline(grop->reduction_op()->init()) << "));\n";
  }

  void handle(const kir::GridBroadcast* grop) final {
    const auto bop = grop->broadcast_op();
    TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());

    const ParallelTypeBitmap parallel_types =
        kernel_->summary().broadcast_parallel_types.at(bop);

    TORCH_INTERNAL_ASSERT(
        parallel_types.hasBID(),
        "GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types");

    const auto out = bop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();

    const auto data_type = bop->out()->dtype();

    TORCH_INTERNAL_ASSERT(
        grop->broadcast_buffer()->buffer()->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
    const auto work_buffer =
        grop->broadcast_buffer()->buffer()->as<TensorView>();
    const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();

    std::stringstream flags_str;
    for (const ParallelType pt : kParallelTypeThreads) {
      const bool parallel_bcast = parallel_types.get(pt);
      if (pt != kParallelTypeThreads[0]) {
        flags_str << ", ";
      }
      flags_str << (parallel_bcast ? "true" : "false");
    }

    // Since block-level broadcast has not necessarily been performed before
    // this function call, so grid broadcast may be broadcasting across both
    // the grid and the block level.
    indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n";
    indent() << kTab << gen(bop->out()) << ",\n";
    indent() << kTab << gen(bop->in()) << ",\n";
    indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
    indent() << kTab << varName(sync_buffer) << ",\n";
    TORCH_INTERNAL_ASSERT(
        grop->predicate() != nullptr && grop->predicate()->hasValue());
    indent() << kTab << genInline(grop->predicate()) << ");\n";
  }

  void handle(const kir::GridWelford* gwop) final {
    const auto wop = gwop->welford_op();
    TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>());

    const auto out = wop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();
    TORCH_INTERNAL_ASSERT(domain->hasGridReduction());

    const auto data_type = out->dtype();

    TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA<TensorView>());

    const auto avg_buffer = gwop->avg_buffer()->buffer()->as<TensorView>();
    const auto var_buffer = gwop->var_buffer()->buffer()->as<TensorView>();
    const auto n_buffer = gwop->N_buffer()->buffer()->as<TensorView>();
    const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>();

    const bool persistent_sync =
        kernel_->summary().has_cooperative_grid_reduction;

    const std::string flags_str =
        generateGridReduceTemplateFlags(wop, gwop->threadPredicate());

    // Since block-level reduction is already done, those dimensions
    // with tidx/y/z being true do not participate in the grid reduction.
    indent() << "welford::gridWelford<" << flags_str << ", "
             << (persistent_sync ? "true" : "false") << ">(\n";
    indent() << kTab << gen(wop->outAvg()) << ",\n"
             << kTab << gen(wop->outVar()) << ",\n"
             << kTab << gen(wop->outN()) << ",\n";
    if (domain->hasBlockReduction()) {
      indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
               << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
               << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
      block_reduce_name_++;
    } else {
      indent() << kTab << gen(wop->inAvg()) << ",\n";
      if (wop->inVar() == nullptr) {
        indent() << kTab << "(" << data_type << ") 0,\n";
      } else {
        indent() << kTab << gen(wop->inVar()) << ",\n";
      }
      indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN())
               << ",\n";
    }
    indent() << kTab << "&" << varName(avg_buffer) << "[0],\n";
    indent() << kTab << "&" << varName(var_buffer) << "[0],\n";
    indent() << kTab << "&" << varName(n_buffer) << "[0],\n";
    indent() << kTab << varName(sync_buffer) << ",\n";
    indent() << kTab << "reinterpret_cast<" << data_type
             << "*>(shared_mem_avg),\n";
    indent() << kTab << "reinterpret_cast<" << data_type
             << "*>(shared_mem_var),\n";
    indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype()
             << "*>(shared_mem_n),\n";
    TORCH_INTERNAL_ASSERT(
        gwop->predicate() != nullptr && gwop->predicate()->hasValue());
    auto read_pred = genInline(gwop->predicate());
    indent() << kTab << read_pred << ",\n";
    if (gwop->writePredicate() != nullptr) {
      TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue());
      auto write_pred = genInline(gwop->writePredicate());
      indent() << kTab << write_pred << ",\n";
    } else {
      indent() << kTab << read_pred << ",\n";
    }
    // TODO : init value support or remove.
    indent() << kTab << data_type << "(0));\n";
  }

  void handleScope(const kir::Scope& scope) {
    for (auto expr : scope.exprs()) {
      OptOutConstDispatch::handle(expr);
    }
  }

  void handle(const kir::ForLoop* loop) final {
    if (loop->iter_domain()->isBroadcast()) {
      handleScope(loop->body());
      return;
    } else if (loop->vectorize()) {
      vectorize_scope_ = loop->vectorize();
      handleScope(loop->body());
      vectorize_scope_ = false;
      return;
    } else if (loop->iter_domain()->isStride()) {
      // A stride domain only executes the loop body with the loop
      // index being zero.
      indent() << "constexpr "
               << "nvfuser_index_t"
               << " " << gen(loop->index()) << " = 0;\n";
      handleScope(loop->body());
      return;
    }

    // By default, a parallelized loop would look like:
    //
    //   for (int x = threadIdx.x; x < stop; x += blockDim.x) {
    //     do_some_comp(x);
    //   }
    //
    // When stop is guaranteed to be smaller or equal to the number of
    // threads, the for-loop is not necessary. In the above case, we
    // would just generate the loop body without the for clause but
    // references to the loop index replaced by the loop start value.
    //
    // When the loop end is the same as the IterDomain extent, the
    // assumption can be safely made. This is more conservative than
    // necessary since the loop stop value just needs to be <= the
    // IterDomain extent. However, at this point, this conservative
    // analysis seems sufficient.
    if (loop->stop() == loop->iter_domain()->extent() &&
        loop->iter_domain()->isThread()) {
      // Register a replacement of references to the loop index with
      // the loop start value.
      replacement_map_.insert({loop->index(), loop->start()});
      handleScope(loop->body());
      replacement_map_.erase(loop->index());
      return;
    }

    if (loop->start()->isZeroInt() && loop->stop()->isOneInt()) {
      indent() << "constexpr "
               << "nvfuser_index_t"
               << " " << gen(loop->index()) << " = 0;\n";
      handleScope(loop->body());
      return;
    } else if (
        // Special case handling for a pattern where start == end - 1.
        loop->start()->definition() != nullptr &&
        loop->start()->definition()->isA<BinaryOp>() &&
        loop->start()->definition()->as<BinaryOp>()->getBinaryOpType() ==
            BinaryOpType::Sub &&
        loop->start()->definition()->as<BinaryOp>()->lhs() == loop->stop() &&
        loop->start()->definition()->as<BinaryOp>()->rhs()->isOneInt()) {
      indent() << "const "
               << "nvfuser_index_t"
               << " " << gen(loop->index()) << " = " << genInline(loop->start())
               << ";\n";
      handleScope(loop->body());
      return;
    }

    const auto gen_index = gen(loop->index());
    const auto gen_start = genInline(loop->start());
    const auto gen_stop = genInline(loop->stop());
    const auto gen_step = genInline(loop->step());

    std::stringstream step_code;
    if (loop->step()->isOneInt()) {
      step_code << "++" << gen_index;
    } else {
      step_code << gen_index << " += " << gen_step;
    }
    if (loop->isUnrolled()) {
      indent() << "#pragma unroll\n";
    } else {
      indent() << "#pragma unroll 1\n";
    }

    indent() << "for(nvfuser_index_t " << gen_index;
    if (loop->iter_domain()->isParallelized()) {
      code_ << " = " << gen_start << "; ";
    } else {
      // Do not start at the start of the ID when not parallelized. Instead,
      // start at 0. Predicates will protect buffers between 0 and ID->start(),
      // however if we started at ID->start and extent == ID->start, we could
      // have a "degenerate" loop (loop with no iterations). It may not be an
      // issue to have a 0-sized loop, but all potential consequences haven't
      // been covered. One example is WAR analysis which could incorrectly think
      // a barrier inside a 0-sized loop actually provides protection.
      code_ << " = 0; ";
    }
    code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") ";
    startBlock(true);
    handleScope(loop->body());
    endBlock();
  }

  void handle(const kir::IfThenElse* ite) final {
    auto conditional = ite->predicate()->value();
    if (conditional->isConst()) {
      // If the conditional is a constant, then the IfThenElse is not required
      if (conditional->value().value()) {
        handleScope(ite->thenBody());
      } else {
        handleScope(ite->elseBody());
      }
      return;
    }

    indent() << "if (" << genInline(conditional) << ") ";

    // "then" block
    startBlock(true);
    handleScope(ite->thenBody());

    // "else" block (optional)
    if (ite->hasElse()) {
      endBlock(" else ");
      startBlock(true);
      handleScope(ite->elseBody());
    }

    endBlock();
  }

  void handle(const kir::Allocate* alloc) final {
    const auto buffer_dtype = alloc->buffer()->dtype();

    if (!alloc->buffer()->isA<TensorView>()) {
      indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n";
      return;
    }

    const auto tv = alloc->buffer()->as<TensorView>();

    const auto size = alloc->size();
    TORCH_INTERNAL_ASSERT(size != nullptr);

    if (alloc->alias() != nullptr) {
      // Allocate alias another Allocate stmt
      const auto alias_tv = alloc->alias()->buffer()->as<TensorView>();
      indent() << "// Alias Allocation - " << alloc->memoryType() << "\n";
      indent() << buffer_dtype << "* " << varName(tv) << " = "
               << varName(alias_tv) << ";\n";
    } else {
      // Standard Memory Allocation
      switch (tv->getMemoryType()) {
        case MemoryType::Global:
          indent() << "// Allocate global tensor " << varName(tv) << "\n";
          break;
        case MemoryType::Shared:
          if (kir::ExpressionEvaluator::isConst(size)) {
            // Static shared memory
            indent() << "__shared__ " << buffer_dtype << " " << varName(tv)
                     << "[" << genInline(size) << "];\n";
          } else {
            // Align Offset Position
            indent() << "offset = alignBufferSize(offset,"
                     << dataTypeSize(buffer_dtype) << ");\n";
            // Shared Memory Pointer
            indent() << buffer_dtype << "* " << varName(tv)
                     << " = reinterpret_cast<" << buffer_dtype << "*>"
                     << "(array + offset);\n";
            // Increment Offset Position
            indent() << "offset += (" << genInline(size) << " * sizeof("
                     << buffer_dtype << "));\n";
          }
          break;
        case MemoryType::Local:
          indent() << buffer_dtype << " " << varName(tv) << "["
                   << genInline(size) << "];\n";
          break;
        default:
          TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
      }
    }
  }

  void handle(const kir::Sync*) final {
    // Use a custom synchronization method if enabled
    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
      indent() << "block_sync::sync();\n";
    } else {
      indent() << "__barrier_sync(0);\n";
    }
  }

  void handle(const kir::InitMagicZero*) final {
    indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
  }

  void handle(const kir::UpdateMagicZero*) final {
    indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
  }

 private:
  std::stringstream code_;
  const kir::Kernel* kernel_;
  int block_nest_level_ = 0;
  int block_reduce_name_ = 0;
  bool print_inline_ = false;

  // Mark when we are inside of a vectorized for-loop
  bool vectorize_scope_ = false;

  //! Holds active replacement mappings during codegen
  std::unordered_map<const Statement*, const Statement*> replacement_map_;
};

} // namespace

std::string generateCudaKernel(
    const kir::Kernel* kernel,
    const std::string& kernel_name) {
  FUSER_PERF_SCOPE("generateCudaKernel");
  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
}

} // namespace codegen
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch