Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: https://github.com/csarofeen/pytorch/pull/1619
2. Emulate Ampere's mma 16816 with Turing's mma 1688, for a unified interface: https://github.com/csarofeen/pytorch/pull/1643
3. Extending the infrastructure to support mma operators on Turing and Ampere arch: https://github.com/csarofeen/pytorch/pull/1440

Commits from the csarofeen branch that are actually in this PR:
```
* dd2325294e236c5082c642819a1103bcfe4561a3 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f446355a2d276bac8272e7aa8b5bb6b1f0 Fix missing cooperative launch (#1726)
* dc670a226cbe52be46cecef47001f38bf9a09433 Async gmem copy support on sm80+ (#1619)
* 5e6a8dab5a71aefe0548bbfa15d1a93c556d23fe Add turing mma support and test (#1643)
* d6d6b7d3f10dd91dafa4cdbd5e460bbb38173af4 Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39150c6d80e0f9f767d56654714a2e8a927 Mma op integration on ampere (#1440)
* fade8da55e60a118c5595378896d34b862b2fcc3 patch python test for bfloat16 (#1724)
* 8fbd0b18743a72ac10478857c3d2351204375685 Fine-grained kernel profiling (#1720)
* 77c1b4fa633f9e631d267923f4537336fa328939 Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b97bebefc94199bb4a53423ede32b55451 More precise concretization analysis (#1719)
* f4d3630ed54d7069dd377a64be1f91013b285b66 Enable complex python tests (#1667)
* 4ceeee509774cc2ce6c834a4dc1e313f71d94503 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70faf218e86d2c78dbd3874b175a3b0a203 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830d5def65dadfe29d4edf52fc703369c84a Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7ae2cfd8fffad1e1d882ae7c50631211dc updating_ci_machine (#1718)
* 56585c58b1ff338704cafb0cd6be2b3d536bed5a Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453d3be0c11a5acb0fff3b3f36e19cfdaf81 Allow using nvFuser on CUDA extension (#1701)
* 18bee67495454b9a79625799776e746bd5e81c4c Validate LOOP concrete IDs have complete IterDomains (#1676)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

IrBuilderPasskey::IrBuilderPasskey(IrContainer* ir_container)
    : ir_container_(ir_container) {}

namespace kir {

namespace {

//! Scan all primary expressions in the Kernel IR and build
//! lists of specialized nodes and other interesting information
class KernelIrScanner : private IrVisitor {
 public:
  explicit KernelIrScanner(const Kernel* kernel) {
    IrVisitor::handle(kernel->topLevelExprs());
    const auto gpu_lower = GpuLower::current();
    for (auto split : gpu_lower->nonDivisibleSplitInfo().splitsToValidate()) {
      auto extent = split->in()->extent();
      auto factor = split->factor();
      summary_.splits_to_validate.emplace_back(extent, factor);
    }
  }

  const auto& summary() const {
    return summary_;
  }

 private:
  using IrVisitor::handle;
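  // Dispatch on the expression itself, then recurse into its inputs and
  // outputs so that vals such as TensorIndex are scanned as well.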
  void handle(Expr* expr) final {
    IrVisitor::handle(expr);
    for (auto inp : expr->inputs()) {
      handle(inp);
    }
    for (auto out : expr->outputs()) {
      handle(out);
    }
  }
  void handle(BlockSync* sync) final {
    // TODO: Move to a dedicated validation pass
    // which is not on the common execution/compilation path
    if (sync->isWarHazardSync()) {
      ++summary_.war_hazard_syncs_count;
    }
  }

  void handle(GridSync* sync) final {
    summary_.has_cooperative_grid_reduction = true;
  }

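  // Record allocations by memory space; local allocations whose size is
  // not a compile-time constant are tracked separately.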
  void handle(Allocate* allocate) final {
    switch (allocate->memoryType()) {
      case MemoryType::Global:
        summary_.global_allocations.push_back(allocate);
        break;
      case MemoryType::Shared:
        summary_.dynamic_smem_allocations.push_back(allocate);
        break;
      case MemoryType::Local:
        if (!ExpressionEvaluator::isConst(allocate->size())) {
          summary_.has_dynamic_local_memory_allocations = true;
          summary_.dynamic_lmem_allocations.emplace_back(allocate);
        }
        break;
    }
  }

  void handle(UnaryOp* unary_op) final {
    if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) {
      // This kernel is using random numbers
      summary_.is_stochastic = true;
    }
  }

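  // Inspect each tensor access to detect block reductions and to track the
  // widest data type that may be staged in shared memory.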
  void handle(TensorIndex* tensor_index) final {
    const auto tv = tensor_index->view();
    const auto domain = tv->domain();
    // Do we have any reductions?
    summary_.has_block_reductions =
        summary_.has_block_reductions || domain->hasBlockReduction();

    // Update the largest smem data type
    if (domain->hasBlockReduction() || domain->hasGridReduction() ||
        tv->getMemoryType() == MemoryType::Shared) {
      const auto data_type = tv->dtype();
      const size_t type_size = dataTypeSize(data_type);
      if (type_size > max_smem_type_size_) {
        max_smem_type_size_ = type_size;
        summary_.largest_smem_data_type = data_type;
      }
    }
  }

  void handle(WelfordOp* welford_op) final {
    summary_.has_welford = true;
    TORCH_INTERNAL_ASSERT(welford_op->outAvg()->isA<TensorIndex>());
    auto out_dom = welford_op->outAvg()->as<TensorIndex>()->view()->domain();
    summary_.has_block_welford =
        summary_.has_block_welford || out_dom->hasBlockReduction();
  }

  void handle(GridWelford* grid_welford) final {
    summary_.has_welford = true;
    summary_.has_grid_welford = true;
    summary_.has_grid_reductions = true;
    if (grid_welford->welford_op()->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GridReduction* grid_reduction) final {
    summary_.has_grid_reductions = true;
    if (grid_reduction->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GroupedGridReduction* grid_reduction) final {
    summary_.has_grid_reductions = true;
    const auto dom = ir_utils::getTvOutput(grid_reduction)->domain();
    updateGridReductionInLoop(dom);
    if (grid_reduction->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

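  // A grid broadcast relies on cross-block communication, so it is counted
  // toward the cooperative-launch requirement; also scan the wrapped
  // BroadcastOp.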
  void handle(GridBroadcast* grid_broadcast) final {
    summary_.has_cooperative_grid_reduction = true;
    handle(grid_broadcast->broadcast_op());
  }

  void handle(BroadcastOp* bop) final {
    const ParallelTypeBitmap parallel_types =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(
            bop->out()->as<TensorIndex>()->view());
    summary_.broadcast_parallel_types.emplace(bop, parallel_types);
    // Do we have block broadcasts?
    summary_.has_block_broadcasts =
        summary_.has_block_broadcasts || parallel_types.hasTID();
    // Do we have grid broadcasts?
    summary_.has_grid_broadcasts =
        summary_.has_grid_broadcasts || parallel_types.hasBID();
  }

 private:
  size_t max_smem_type_size_ = 0;
  KernelSummary summary_;

 private:
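  // Flag the kernel as requiring a cooperative launch when a grouped grid
  // reduction is nested inside a serial loop whose extent is not statically
  // one.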
  void updateGridReductionInLoop(TensorDomain* dom) {
    for (const auto i : c10::irange(dom->nDims())) {
      const auto id = GpuLower::current()->caMap()->getConcreteMappedID(
          dom->domain()[i], IdMappingMode::LOOP);

      summary_.has_cooperative_grid_reduction =
          summary_.has_cooperative_grid_reduction ||
          !(id->isThread() || id->extent()->isOneInt());
    }
  }
};

//! Make sure tensors have valid allocations even when parallelized
//! loops potentially have larger iteration counts than the number of
//! threads.
//!
//! When an IterDomain of a tensor is parallelized, the IterDomain
//! may not contribute to the allocation of the tensor. For example,
//! the allocation of a local-memory tensor is assumed not to need to
//! account for a parallelized IterDomain. This is true when it is
//! guaranteed that each thread only needs to execute the loop body
//! once. However, if not, the allocation is invalid as it only has
//! space for one value per thread.
//!
//! ValidateAllocation checks all tensor allocations and sees if any
//! tensor may have a parallelized loop whose iteration count may be
//! larger than the number of threads. If so, an error is thrown unless
//! the tensor is allocated in thread-shared memory (i.e.,
//! MemoryType::Shared or MemoryType::Global for tensors parallelized
//! with threadIdx, or MemoryType::Global for tensors parallelized with
//! blockIdx). For such allocations, it is assumed that the allocation
//! is properly extended for the iteration count.
class ValidateAllocation : private OptOutConstDispatch {
 public:
  static void validate(const Kernel* kernel) {
    ValidateAllocation validate_allocation(kernel);
  }

 private:
  explicit ValidateAllocation(const Kernel* kernel) {
    live_allocations_.emplace_back(std::vector<const Allocate*>());
    for (const auto& expr : kernel->topLevelExprs()) {
      OptOutConstDispatch::handle(expr);
    }
    live_allocations_.pop_back();
    TORCH_INTERNAL_ASSERT(live_allocations_.empty());
  }

  void handle(const Allocate* allocate) final {
    TORCH_INTERNAL_ASSERT(!live_allocations_.empty());
    live_allocations_.back().push_back(allocate);
  }

  // for_loop is parallelized and its stop value is not guaranteed to
  // be <= the number of threads, which breaks an assumption made
  // during allocation lowering if the tensor is thread-parallel and not
  // allocated in shared or global memory, or if it is block-parallel
  // and not allocated in global memory.
  void validate(const ForLoop* for_loop) {
    const auto loop_id = for_loop->iter_domain();
    for (const auto& allocations : live_allocations_) {
      for (const auto& allocate : allocations) {
        const auto tv = dynamic_cast<TensorView*>(allocate->buffer());
        if (tv == nullptr) {
          continue;
        }
        for (const auto& axis : tv->domain()->domain()) {
          if (!GpuLower::current()->caMap()->areMapped(
                  loop_id, axis, IdMappingMode::LOOP)) {
            continue;
          }
          if (isParallelTypeThreadDim(loop_id->getParallelType())) {
            TORCH_INTERNAL_ASSERT(
                tv->getMemoryType() == MemoryType::Shared ||
                    tv->getMemoryType() == MemoryType::Global,
                "Tensor t",
                tv->name(),
                " must be allocated on SMEM or GMEM.");
          } else if (isParallelTypeBlockDim(loop_id->getParallelType())) {
            TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);
          }
        }
      }
    }
  }

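  // Validate live allocations when a thread-parallel loop may not span its
  // full extent, then open a fresh allocation scope for the loop body.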
  void handle(const ForLoop* for_loop) final {
    if (for_loop->stop() != for_loop->iter_domain()->extent() &&
        isParallelTypeThread(for_loop->iter_domain()->getParallelType())) {
      validate(for_loop);
    }

    live_allocations_.emplace_back(std::vector<const Allocate*>());
    for (const auto& expr : for_loop->body().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
    live_allocations_.pop_back();
  }

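  // Recurse into both branches of a conditional; no new allocation scope is
  // opened here.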
  void handle(const IfThenElse* ite) final {
    for (const auto& expr : ite->thenBody().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
    for (const auto& expr : ite->elseBody().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
  }

 private:
  std::vector<std::vector<const Allocate*>> live_allocations_;
};

} // namespace

// TODO(kir): Kernel IR validation
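// Take ownership of the lowered top-level expressions, validate allocations,
// and cache summary information queried from GpuLower.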
void Kernel::finalize(std::vector<Expr*> top_level_exprs) {
  TORCH_INTERNAL_ASSERT(top_level_exprs_.empty());
  top_level_exprs_ = std::move(top_level_exprs);
  warp_padded_parallel_info_ = GpuLower::current()->getWarpPaddedParallelInfo();
  profile_ = GpuLower::current()->profile();
  ValidateAllocation::validate(this);
  analyze();
  // Make sure this is after analyze as it sets summary_
  summary_.vectorized_accesses = GpuLower::current()->vectorizedAccesses();
  summary_.vectorized_set_info = GpuLower::current()->vectorizedSetInfo();
  summary_.sync_map = GpuLower::current()->syncMap();
  summary_.parallel_dimension_map_ =
      GpuLower::current()->parallelDimensionMap();
}

void Kernel::analyze() {
  FUSER_PERF_SCOPE("Kernel::analyze");

  const KernelIrScanner ir_scanner(this);
  summary_ = ir_scanner.summary();
}

void Kernel::print() const {
  IrPrinter ir_printer(std::cout);
  ir_printer.handle(this);
}

//! Register the Val with this fusion
void Kernel::registerVal(Val* val) {
  if (inContainer(val)) {
    return;
  }
  if (val->kernel()) {
    TORCH_CHECK(
        val->kernel() == this,
        val->toString(),
        " was not found in the active kernel.");
  }

  Fusion::registerVal(val);
}

//! Register expr with this fusion.
//! When we register an expression, we want to update the dependency tracking
//! of Vals. We add expr to our general expr_set_.
void Kernel::registerExpr(Expr* expr) {
  if (inContainer(expr)) {
    return;
  }

  if (expr->kernel()) {
    TORCH_CHECK(
        expr->kernel() == this,
        expr->toString(),
        " was not found in the active kernel.");
  }

  for (Val* input : expr->inputs()) {
    TORCH_INTERNAL_ASSERT(
        inContainer(input),
        "Input\n",
        input->toString(),
        " to expr,\n",
        expr->toString(),
        ",\n is invalid because it is not in the same kernel.");
  }

  for (Val* output : expr->outputs()) {
    TORCH_INTERNAL_ASSERT(
        inContainer(output),
        "Output\n",
        output->toString(),
        " to expr,\n",
        expr->toString(),
        ",\n is invalid because it is not in the same kernel.");
  }

  // Register expr is explicitly non-SSA when coming from a kernel. This is
  // detected inside Fusion::registerExpr
  Fusion::registerExpr(expr);
}

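// Expose mutable access to the kernel's top-level expressions through the
// internal proxy.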
std::vector<Expr*>& KernelInternalProxy::topLevelExprs() {
  return kernel_->top_level_exprs_;
}

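// Assign each profiled expression a unique slot; registering the same
// expression more than once is a no-op.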
void KernelPerformanceProfile::registerExpr(const Expr* expr) {
  if (expr_entry_map_.find(expr) != expr_entry_map_.end()) {
    return;
  }

  auto slot = getNewIndex();
  expr_entry_map_.emplace(expr, slot);
}

int KernelPerformanceProfile::getNewIndex() {
  return num_profile_entries_++;
}

bool KernelPerformanceProfile::isProfiled(const Expr* expr) const {
  return expr_entry_map_.find(expr) != expr_entry_map_.end();
}

c10::optional<int> KernelPerformanceProfile::getIndex(const Expr* expr) const {
  auto it = expr_entry_map_.find(expr);
  if (it == expr_entry_map_.end()) {
    return c10::optional<int>();
  } else {
    return it->second;
  }
}

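// Each profiled expression owns two consecutive entries in the profile
// buffer: slot 2 * i accumulates cycles and slot 2 * i + 1 counts calls.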
std::array<int, 2> KernelPerformanceProfile::getIndicesInProfileBuffer(
    const Expr* expr) const {
  TORCH_INTERNAL_ASSERT(
      isProfiled(expr), "Not a profiled expression: ", expr->toString());

  int cycle_index = getIndex(expr).value() * 2;
  int count_index = cycle_index + 1;

  return {cycle_index, count_index};
}

std::string KernelPerformanceProfile::toString(const at::Tensor& buffer) const {
  std::stringstream ss;
  ss << "Kernel performance profile:\n";
  if (!buffer.defined()) {
    ss << "No profile found\n";
    return ss.str();
  }

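  // clockRate is reported in kHz, so cycles / kilo_freq yields milliseconds;
  // the factor of 1000 below converts the per-call time to microseconds.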
  double kilo_freq = at::cuda::getCurrentDeviceProperties()->clockRate;

  ss << std::setprecision(3) << std::fixed;

  for (const auto& kv : expr_entry_map_) {
    auto expr = kv.first;
    auto index = kv.second;
    auto out_tv = ir_utils::getTvOutput(expr);
    double cycles = static_cast<double>(buffer[index][0].item<int64_t>());
    auto count = buffer[index][1].item<int64_t>();
    auto cycles_per_call = count == 0 ? 0.0 : cycles / count;
    auto us_per_call = cycles_per_call / kilo_freq * 1000.0;
    ss << expr->getExprType().value() << ", T" << out_tv->name() << ", "
       << us_per_call << " us, " << count << "\n";
  }

  return ss.str();
}

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch