pytorch/torch/csrc/jit/codegen/cuda/codegen.cpp
Thomas Viehmann d3d8da7a8e Enable CUDA Fuser for ROCm (#45965)
Summary:
This enables the CUDA fuser on ROCm and enables the corresponding tests.

Part of this patch is based on work of Rohith Nallamaddi, thank you.
Errors are my own, of course.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45965

Reviewed By: seemethere

Differential Revision: D24170457

Pulled By: walterddr

fbshipit-source-id: 3dd25b3501a41d2f00acba3ce8642ce51c49c9a6
2020-10-08 10:41:56 -07:00


#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <iomanip>
#include <limits>
#include <sstream>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace codegen {
namespace {
class CudaKernelGenerator : private OptInConstDispatch {
static constexpr const char* kTab = "  ";
public:
static std::string generateKernelDefinition(
const Kernel* kernel,
const std::string& kernel_name) {
CudaKernelGenerator codegen(kernel);
codegen.genDeclaration(kernel_name);
codegen.startBlock();
codegen.genPrologue();
codegen.genBody();
codegen.endBlock();
TORCH_CHECK(codegen.block_nest_level_ == 0);
return codegen.code_.str();
}
private:
explicit CudaKernelGenerator(const Kernel* kernel) : kernel_(kernel) {}
// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
const auto& kernel_summary = kernel_->summary();
code_ << "__global__ void " << kernel_name << "(";
std::vector<Val*> params;
// Inputs
for (auto val : kernel_->inputs()) {
params.push_back(val);
}
// Outputs
for (auto val : kernel_->outputs()) {
params.push_back(val);
}
// Global buffers
for (auto allocate : kernel_summary.global_allocations) {
params.push_back(allocate->buffer());
}
// Generate parameter declarations
for (Val* val : params) {
switch (val->getValType().value()) {
case ValType::KirTensorView: {
// TODO(kir): review this
const auto tv = val->as<kir::TensorView>();
code_ << "Tensor<" << val->getDataType().value() << ", "
<< TensorDomain::noReductions(
tv->fuserTv()->getMaybeRFactorDomain())
.size()
<< "> " << gen(tv);
break;
}
case ValType::KirScalar:
code_ << val->getDataType().value() << " " << gen(val);
break;
default:
TORCH_CHECK(!"Unexpected parameter type");
}
if (val != params.back()) {
code_ << ", ";
}
}
// Kernels generating random numbers take extra (seed, offset) arguments
if (kernel_summary.is_stochastic) {
code_ << ", unsigned long long seed, unsigned long long offset";
}
code_ << ") ";
}
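// As a hedged illustration (not emitted by the generator itself): for a
// hypothetical kernel with one 2D float input T0, one 2D float output T1
// and random numbers enabled, genDeclaration() would produce roughly
//
//   __global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1,
//       unsigned long long seed, unsigned long long offset)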
// Generates setup code which is executed before the kernel body
void genPrologue() {
const auto& kernel_summary = kernel_->summary();
// Random number generator (optional)
if (kernel_summary.is_stochastic) {
indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
indent() << "Philox rnd(seed, idx, offset);\n";
}
// Do we have any dynamic shared memory buffers?
const bool has_dynamic_smem =
!kernel_summary.dynamic_smem_allocations.empty();
// Do we have any reductions?
const bool has_reductions = kernel_summary.has_block_reductions ||
kernel_summary.has_grid_reductions;
// Shared memory
if (has_dynamic_smem || has_reductions) {
indent() << "alignas("
#ifndef __HIP_PLATFORM_HCC__
<< dataTypeSize(kernel_summary.largest_smem_data_type)
#else
<< 8 // for HIP, we want 8-aligned even for smaller datatypes
#endif
<< ") extern __shared__ char array[];\n";
if (has_dynamic_smem) {
indent() << "unsigned offset = 0;\n";
}
if (has_reductions) {
indent() << "void* shared_mem = array;\n";
if (has_dynamic_smem) {
indent() << "offset += "
<< "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
<< kernel_summary.largest_smem_data_type << "));\n";
}
}
}
}
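// Sketch of the prologue for a hypothetical kernel that uses RNG, dynamic
// shared memory and a block reduction over float (the alignas value is 4
// from dataTypeSize(float) on CUDA, forced to 8 on HIP):
//
//   const int idx = blockIdx.x*blockDim.x + threadIdx.x;
//   Philox rnd(seed, idx, offset);
//   alignas(4) extern __shared__ char array[];
//   unsigned offset = 0;
//   void* shared_mem = array;
//   offset += ((blockDim.x * blockDim.y * blockDim.z) * sizeof(float));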
void genBody() {
for (auto expr : kernel_->topLevelExprs()) {
OptInConstDispatch::handle(expr);
}
}
void startBlock(bool continuation = false) {
if (continuation) {
code_ << "{\n";
} else {
indent() << "{\n";
}
++block_nest_level_;
}
void endBlock(const char* sep = "\n") {
--block_nest_level_;
TORCH_CHECK(block_nest_level_ >= 0);
indent() << "}" << sep;
}
std::ostream& indent() {
for (int i = 0; i < block_nest_level_; ++i) {
code_ << kTab;
}
return code_;
}
std::string gen(const Statement* stmt) {
std::stringstream tmp_code;
std::swap(tmp_code, code_);
handle(stmt);
std::swap(tmp_code, code_);
return tmp_code.str();
}
std::string gen(const kir::TensorView* tv) {
std::stringstream tv_name;
tv_name << "T" << tv->name();
return tv_name.str();
}
std::string genInline(const Statement* stmt) {
const bool saved_inline = print_inline_;
print_inline_ = true;
const auto result = gen(stmt);
print_inline_ = saved_inline;
return result;
}
void handle(const Statement* node) final {
OptInConstDispatch::handle(node);
}
void handle(const Expr* node) final {
OptInConstDispatch::handle(node);
}
void handle(const Val* node) final {
OptInConstDispatch::handle(node);
}
void handle(const kir::Bool* node) final {
const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
} else if (node->isSymbolic()) {
code_ << "b" << node->name();
} else {
code_ << *node->value();
}
}
void handle(const kir::Float* node) final {
const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
} else if (node->isSymbolic()) {
code_ << "f" << node->name();
} else {
const int digits = std::numeric_limits<Float::ScalarType>::max_digits10;
code_ << "float(" << std::setprecision(digits) << *node->value() << ")";
}
}
void handle(const kir::Half* node) final {
const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
} else if (node->isSymbolic()) {
code_ << "h" << node->name();
} else {
code_ << "__float2half(" << *node->value() << ")";
}
}
void handle(const kir::Int* node) final {
const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
} else if (node->isSymbolic()) {
code_ << "i" << node->name();
} else {
code_ << *node->value();
}
}
void handle(const kir::NamedScalar* node) final {
code_ << node->name();
}
void handle(const kir::TensorIndex* node) final {
code_ << gen(node->view()) << "[";
bool first = true;
for (auto* ind : node->indices()) {
if (!ind->isZeroInt()) {
if (!first) {
code_ << " + ";
}
code_ << genInline(ind);
first = false;
}
}
if (first) {
code_ << "0";
}
code_ << "]";
}
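// e.g. an index list {i1, 0, i2} on a hypothetical T2 prints "T2[i1 + i2]";
// zero indices are skipped, and an all-zero index list collapses to "T2[0]".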
void handle(const kir::IterDomain* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
void handle(const kir::TensorDomain* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
void handle(const kir::TensorView* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
void handle(const kir::UnaryOp* node) final {
if (!print_inline_) {
indent() << gen(node->out());
if (!node->out()->isScalar() && !node->in()->isScalar()) {
code_ << "\n";
indent() << kTab;
}
code_ << " = ";
}
if (auto op = inline_op_str(node->getUnaryOpType())) {
code_ << *op << gen(node->in());
} else {
if (node->getUnaryOpType() == UnaryOpType::Cast) {
const auto cast_str =
cast_func_str({node->in()->getDataType().value(),
node->out()->getDataType().value()});
code_ << cast_str.value();
} else {
code_ << node->getUnaryOpType();
}
code_ << "(";
if (node->getUnaryOpType() == UnaryOpType::RandLike) {
code_ << "rnd";
} else {
code_ << gen(node->in());
}
code_ << ")";
}
if (!print_inline_) {
code_ << ";\n";
}
}
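// Hedged example (tensor and index names are made up): a Cast from float to
// half is split across two lines because both sides are tensors,
//
//   T1[i1]
//      = __float2half(T0[i1]);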
std::string genBinaryOp(
BinaryOpType op_type,
const std::string& lhs,
const std::string& rhs) {
std::stringstream expr;
if (auto op = inline_op_str(op_type)) {
expr << lhs << " " << *op << " " << rhs;
} else {
expr << op_type << "(" << lhs << ", " << rhs << ")";
}
return expr.str();
}
void handle(const kir::BinaryOp* node) final {
const auto op_type = node->getBinaryOpType();
if (print_inline_) {
// Inline expression: `lhs op rhs`
code_ << genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
} else {
indent() << gen(node->out());
if (node->out()->isScalar()) {
// Single line: `out = lhs op rhs;`
code_ << " = "
<< genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
} else {
// Split TensorView expressions across multiple lines:
//
// out
// = lhs
// op rhs;
//
if (auto op = inline_op_str(op_type)) {
code_ << "\n";
indent() << kTab << "= " << gen(node->lhs()) << "\n";
indent() << kTab << *op << " " << gen(node->rhs());
} else {
code_ << " = " << op_type << "(\n";
indent() << kTab << gen(node->lhs()) << ",\n";
indent() << kTab << gen(node->rhs()) << ")";
}
}
code_ << ";\n";
}
}
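// Hedged example of the split tensor form for an inline op such as Add
// (names illustrative):
//
//   T3[i0]
//      = T1[i0]
//      + T2[i0];
//
// Non-inline ops instead print "T3[i0] = opName(" followed by the two
// operands on their own indented lines.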
void handle(const kir::TernaryOp* node) final {
if (!print_inline_) {
indent() << gen(node->out());
if (!node->out()->isScalar()) {
code_ << "\n";
indent() << kTab;
}
code_ << " = ";
}
code_ << node->getTernaryOpType() << "(" << gen(node->in1()) << ", "
<< gen(node->in2()) << ", " << gen(node->in3()) << ")";
if (!print_inline_) {
code_ << ";\n";
}
}
std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
std::stringstream lambda;
lambda << "[](" << data_type << " &a, " << data_type << " b) "
<< "{ a = " << genBinaryOp(op_type, "a", "b") << "; }";
return lambda.str();
}
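// e.g. genReductionOp(BinaryOpType::Add, DataType::Float) returns the
// string "[](float &a, float b) { a = a + b; }".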
void handle(const kir::BroadcastOp* node) final {
const ir_utils::ParallelTypeBitmap domains =
ir_utils::getParallelBroadcastDomains(
node->out(), kernel_->predicateMap());
const bool thread_x = domains.get(ParallelType::TIDx);
const bool thread_y = domains.get(ParallelType::TIDy);
const bool thread_z = domains.get(ParallelType::TIDz);
const bool block_x = domains.get(ParallelType::BIDx);
const bool block_y = domains.get(ParallelType::BIDy);
const bool block_z = domains.get(ParallelType::BIDz);
const bool grid_broadcast_needed = block_x || block_y || block_z;
const bool block_broadcast_needed = thread_x || thread_y || thread_z;
TORCH_INTERNAL_ASSERT(
!grid_broadcast_needed,
"Parallel broadcast across blocks not supported");
if (block_broadcast_needed) {
const auto data_type = node->out()->getDataType().value();
indent() << "broadcast::blockBroadcast<" << (thread_x ? "true" : "false")
<< ", " << (thread_y ? "true" : "false") << ", "
<< (thread_z ? "true" : "false") << ">(\n";
indent() << kTab << gen(node->out()) << ",\n";
indent() << kTab << gen(node->in()) << ",\n";
indent() << kTab << "static_cast<" << data_type << "*>(shared_mem));\n";
} else {
indent() << gen(node->out()) << "\n";
indent() << kTab << " = " << gen(node->in()) << ";\n";
}
}
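// Illustrative output for a block broadcast along TIDx only (names are
// hypothetical):
//
//   broadcast::blockBroadcast<true, false, false>(
//      T2[i0],
//      T1[0],
//      static_cast<float*>(shared_mem));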
void handle(const kir::ReductionOp* node) final {
TORCH_CHECK(node->out()->getValType() == ValType::TensorIndex);
const auto out = node->out()->as<kir::TensorIndex>();
const auto domain = out->view()->domain();
const bool has_block_reduce = domain->hasBlockReduction();
const bool has_grid_reduce = domain->hasGridReduction();
if (!has_block_reduce && !has_grid_reduce) {
const auto gen_out = gen(out);
const auto op_type = node->getReductionOpType();
indent() << gen_out << " = "
<< genBinaryOp(op_type, gen_out, gen(node->in())) << ";\n";
return;
}
const auto par_domains = node->getParallelReductionDomains();
const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
const auto data_type = node->out()->getDataType().value();
const auto op_type = node->getReductionOpType();
if (has_block_reduce) {
if (has_grid_reduce) {
indent() << data_type << " block_result;\n";
}
indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
<< (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
<< ">(\n";
if (has_grid_reduce) {
indent() << kTab << "block_result"
<< ",\n";
} else {
indent() << kTab << gen(node->out()) << ",\n";
}
indent() << kTab << gen(node->in()) << ",\n";
indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
indent() << kTab << "threadIdx,\n";
indent() << kTab << "blockDim,\n";
indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
if (node->pred() == nullptr) {
indent() << kTab << "true,\n";
} else {
indent() << kTab << genInline(node->pred()) << ",\n";
}
indent() << kTab << genInline(node->init()) << ");\n";
}
}
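// Hedged sketch of a block-only float sum reduction along TIDx with no
// predicate (names hypothetical):
//
//   blockReduce<true, false, false>(
//      T3[i0],
//      T2[i0],
//      [](float &a, float b) { a = a + b; },
//      threadIdx,
//      blockDim,
//      static_cast<float*>(shared_mem),
//      true,
//      float(0));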
void handle(const kir::GridReduction* node) final {
const auto rop = node->reduction_op();
TORCH_INTERNAL_ASSERT(rop->out()->getValType() == ValType::TensorIndex);
const auto out = rop->out()->as<kir::TensorIndex>();
const auto domain = out->view()->domain();
TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
const auto par_domains = rop->getParallelReductionDomains();
const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
const bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end();
const bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end();
const bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end();
const auto data_type = rop->out()->getDataType().value();
const auto op_type = rop->getReductionOpType();
TORCH_INTERNAL_ASSERT(
node->reduction_buffer()->buffer()->getValType().value() ==
ValType::KirTensorView);
TORCH_INTERNAL_ASSERT(
node->sync_buffer()->buffer()->getValType().value() ==
ValType::KirTensorView);
const auto work_buffer =
node->reduction_buffer()->buffer()->as<kir::TensorView>();
const auto sync_buffer =
node->sync_buffer()->buffer()->as<kir::TensorView>();
// Since block-level reduction is already done, those dimensions
// with tidx/y/z being true do not participate in the grid reduction.
indent() << kir::GridReduction::getPredicateFlagName(out->view()) << " = "
<< "reduction::gridReduce<" << (bidx ? "true" : "false") << ", "
<< (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false")
<< ", " << (!tidx ? "true" : "false") << ", "
<< (!tidy ? "true" : "false") << ", " << (!tidz ? "true" : "false")
<< ">(\n";
indent() << kTab << gen(rop->out()) << ",\n";
if (domain->hasBlockReduction()) {
indent() << kTab << "block_result"
<< ",\n";
} else {
indent() << kTab << gen(rop->in()) << ",\n";
}
indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
indent() << kTab << "&" << gen(work_buffer) << "[0],\n";
indent() << kTab << gen(sync_buffer) << ",\n";
indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
if (node->pred() == nullptr) {
indent() << kTab << "true,\n";
} else {
indent() << kTab << genInline(node->pred()) << ",\n";
}
indent() << kTab << genInline(node->reduction_op()->init()) << ");\n";
}
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Woverloaded-virtual"
// TODO(kir): fix me
void handle(const kir::Scope& scope) {
for (auto expr : scope.exprs()) {
handle(expr);
}
}
#pragma clang diagnostic pop
void handle(const kir::ForLoop* node) final {
// TODO(kir): handle this during lowering
if (node->iter_domain()->isThread() || node->iter_domain()->isBroadcast()) {
handle(node->body());
return;
}
const auto gen_index = gen(node->index());
const auto gen_start = genInline(node->iter_domain()->start());
const auto gen_extent = genInline(node->iter_domain()->extent());
indent() << "for(size_t " << gen_index << " = " << gen_start << "; "
<< gen_index << " < " << gen_extent << "; ++" << gen_index << ") ";
startBlock(true);
handle(node->body());
endBlock();
}
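// e.g. a serial loop over a domain with symbolic extent i3 and loop index
// i42 (hypothetical names) is emitted as
//
//   for(size_t i42 = 0; i42 < i3; ++i42) {
//     ...
//   }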
void handle(const kir::IfThenElse* node) final {
indent() << "if (" << genInline(node->cond()) << ") ";
// "then" block
startBlock(true);
handle(node->thenBody());
// "else" block (optional)
if (node->hasElse()) {
endBlock(" else ");
startBlock(true);
handle(node->elseBody());
}
endBlock();
}
// TODO(kir): fold initialization into Allocate
void handle(const kir::Allocate* node) final {
if (node->buffer()->getValType().value() != ValType::KirTensorView) {
indent() << node->buffer_type() << " " << gen(node->buffer()) << ";\n";
return;
}
const auto tv = node->buffer()->as<kir::TensorView>();
TORCH_INTERNAL_ASSERT(tv->domain()->nDims() > 0);
TORCH_INTERNAL_ASSERT(node->size() != nullptr);
switch (tv->memoryType()) {
case MemoryType::Global:
indent() << "// Allocate global tensor " << gen(tv) << "\n";
break;
case MemoryType::Shared:
if (node->size()->isConstScalar()) {
// Static shared memory
indent() << "__shared__ " << node->buffer_type() << " " << gen(tv)
<< "[" << genInline(node->size()) << "];\n";
} else {
// Align Offset Position
indent() << "offset = alignBufferSize(offset,"
<< dataTypeSize(node->buffer_type()) << ");\n";
// Shared Memory Pointer
indent() << node->buffer_type() << "* " << gen(tv)
<< " = reinterpret_cast<" << node->buffer_type() << "*>"
<< "(array + offset);\n";
// Increment Offset Position
indent() << "offset += (" << genInline(node->size()) << " * sizeof("
<< node->buffer_type() << "));\n";
}
break;
case MemoryType::Local:
indent() << node->buffer_type() << " " << gen(tv) << "["
<< genInline(node->size()) << "];\n";
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
}
}
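// Hedged example of the dynamic shared memory branch for a float buffer of
// symbolic size i7 (names hypothetical):
//
//   offset = alignBufferSize(offset,4);
//   float* T5 = reinterpret_cast<float*>(array + offset);
//   offset += (i7 * sizeof(float));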
void handle(const kir::Sync* node) final {
indent() << "__syncthreads();\n";
}
private:
std::stringstream code_;
const Kernel* kernel_;
int block_nest_level_ = 0;
// TODO(kir): replace with explicit assignment statements
bool print_inline_ = false;
};
} // namespace
std::string generateCudaKernel(
const Kernel* kernel,
const std::string& kernel_name) {
FUSER_PERF_SCOPE("generateCudaKernel");
return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
}
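// Typical use (a sketch; caller-side names are illustrative): the fusion
// executor calls
//   const auto code = codegen::generateCudaKernel(kernel, "kernel_1");
// and hands the resulting string to NVRTC / hipRTC for compilation.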
} // namespace codegen
} // namespace fuser
} // namespace jit
} // namespace torch