Summary: This update contains a lot of changes; some highlights:

- Added Doxygen config file
- Split the fusion IR (higher-level, TE-like IR) from the kernel IR (lower-level, CUDA-like IR)
- Improved latency of the dynamic shape handling in the fusion logic
- Prevented recompilation for pointwise + reduction fusions when not needed
- Improved inner dimension reduction performance
- Added an input -> (kernel + kernel launch parameters) cache with an eviction policy
- Added reduction fusions with multiple outputs (still a single reduction stage)
- Fixed code generation bugs in the symbolic tiled GEMM example
- Added thread predicates to prevent shared memory from being loaded multiple times
- Improved syncthreads placement with shared memory and removed a read-before-write race
- Fixed FP16 reduction fusions where the output would come back as FP32

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45218
Reviewed By: ezyang
Differential Revision: D23905183
Pulled By: soumith
fbshipit-source-id: 12f5ad4cbe03e9a25043bccb89e372f8579e2a79
158 lines
4.5 KiB
C++
#include <torch/csrc/jit/codegen/cuda/kernel.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {

namespace {

//! Scan all primary expressions in the Kernel IR and build a
//! list of specialized nodes
//!
//! \note primary expressions are expressions which are not subexpressions
//! in a larger expression (things like ForLoop or IfThenElse are not
//! real expressions)
//!
class KernelIrScanner : private OptOutDispatch {
 public:
  // Use expression count to uniquely identify each expression
  size_t all_expression_count = 0;

  // Map expression id to WAR (write-after-read) hazard sync
  std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;

  std::vector<kir::Allocate*> global_allocations;
  std::vector<kir::Allocate*> dynamic_allocations;
  std::vector<kir::Allocate*> static_allocations;
  std::unordered_set<Expr*> primary_expressions;

  explicit KernelIrScanner(const std::vector<Expr*>& exprs) {
    TORCH_INTERNAL_ASSERT(!exprs.empty());
    for (auto expr : exprs) {
      handle(expr);
    }
  }

 private:
  void handle(Expr* expr) final {
    // Each expression must be visited exactly once
    TORCH_CHECK(primary_expressions.insert(expr).second);
    ++all_expression_count;
    OptOutDispatch::handle(expr);
  }

  void handle(kir::Sync* sync) final {
    // TODO: Move to a dedicated validation pass
    // which is not on the common execution/compilation path
    if (sync->isWarHazardSync()) {
      war_hazard_syncs[all_expression_count] = sync;
    }
  }

  void handle(kir::ForLoop* fl) final {
    for (auto expr : fl->body().exprs()) {
      handle(expr);
    }
  }

  void handle(kir::IfThenElse* ite) final {
    for (auto expr : ite->thenBody().exprs()) {
      handle(expr);
    }
    for (auto expr : ite->elseBody().exprs()) {
      handle(expr);
    }
  }

  void handle(kir::Allocate* a) final {
    switch (a->getMemoryType()) {
      case MemoryType::Global:
        global_allocations.push_back(a);
        break;
      case MemoryType::Shared:
        // Buffers with a compile-time constant size can be allocated
        // statically; all others go through the dynamic smem segment
        if (a->size()->isConstScalar()) {
          static_allocations.push_back(a);
        } else {
          dynamic_allocations.push_back(a);
        }
        break;
      case MemoryType::Local:
        break;
    }
  }
};
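
// A minimal sketch of what the scanner sees, for illustration only (the
// tree below is hypothetical, not real Kernel IR syntax):
//
//   kir::ForLoop            <- visited via handle(Expr*), then body recursed
//     kir::IfThenElse       <- visited, then both branches recursed
//       kir::UnaryOp        <- visited, dispatched to a specialized handler
//       kir::Allocate       <- visited, bucketed by memory type
//
// Every node reached through handle(Expr*) is recorded exactly once in
// primary_expressions, so all_expression_count doubles as a unique id for
// the most recently visited expression (used to key war_hazard_syncs).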

} // namespace

// TODO(kir): Kernel IR validation
void Kernel::finalize(
    std::vector<Expr*> top_level_exprs,
    ThreadPredicateMap predicate_map) {
  TORCH_CHECK(top_level_exprs_.empty());
  TORCH_CHECK(!predicate_map_);
  top_level_exprs_ = std::move(top_level_exprs);
  predicate_map_ =
      std::make_unique<ThreadPredicateMap>(std::move(predicate_map));
  analyze();
}
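
// A sketch of the intended call sequence (the names below are hypothetical,
// not part of this file):
//
//   auto exprs = lowerFusionToKernelIr(fusion);  // top-level Kernel IR exprs
//   kernel->finalize(std::move(exprs), std::move(thread_pred_map));
//
// finalize() is one-shot: the TORCH_CHECKs above reject a second call, and
// analyze() runs immediately so the summary is ready by code generation time.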

void Kernel::analyze() {
  FUSER_PERF_SCOPE("Kernel::analyze");

  const KernelIrScanner ir_scanner(top_level_exprs_);

  // Cache the list of buffers used within the kernel
  summary_.war_hazard_syncs = ir_scanner.war_hazard_syncs;
  summary_.global_allocations = ir_scanner.global_allocations;
  summary_.dynamic_smem_allocations = ir_scanner.dynamic_allocations;
  summary_.static_smem_allocations = ir_scanner.static_allocations;

  // Figure out if the kernel uses random numbers
  for (auto expr : ir_scanner.primary_expressions) {
    if (expr->getExprType() == ExprType::KirUnaryOp) {
      if (expr->as<kir::UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike) {
        summary_.is_stochastic = true;
        break;
      }
    }
  }

  // Look for reductions and shared memory buffers
  size_t max_smem_type_size = 0;
  for (auto expr : ir_scanner.primary_expressions) {
    for (auto out : expr->outputs()) {
      if (out->getValType() == ValType::TensorIndex) {
        const auto tv = out->as<kir::TensorIndex>()->view();
        const auto domain = tv->domain();

        // Do we have any reductions?
        summary_.has_block_reductions |= domain->hasBlockReduction();
        summary_.has_grid_reductions |= domain->hasGridReduction();

        // Do we have block broadcasts?
        summary_.has_block_broadcasts |= domain->hasBlockBroadcast();

        // Update the largest smem data type
        if (domain->hasBlockReduction() || domain->hasGridReduction() ||
            tv->memoryType() == MemoryType::Shared) {
          const auto data_type = tv->getDataType().value();
          const size_t type_size = dataTypeSize(data_type);
          if (type_size > max_smem_type_size) {
            max_smem_type_size = type_size;
            summary_.largest_smem_data_type = data_type;
          }
        }
      }
    }
  }
}
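
// A sketch of how a caller might consume the summary at launch time (the
// summary() accessor, evaluate(), and elementSize() names are hypothetical):
//
//   size_t dynamic_smem_bytes = 0;
//   for (auto alloc : kernel->summary().dynamic_smem_allocations) {
//     // alloc->size() is symbolic here; evaluate() stands in for binding
//     // it against the runtime input shapes
//     dynamic_smem_bytes += evaluate(alloc->size()) * elementSize(alloc);
//   }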

} // namespace fuser
} // namespace jit
} // namespace torch