#include <torch/csrc/jit/codegen/cuda/lower2device.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
#include <torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h>
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
#include <torch/csrc/jit/codegen/cuda/lower_instrument.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_replace_size.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>

#include <list>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

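// Descriptive comment (added): tracks the GpuLower instance currently running
// a lowering pass so that lowering utilities can reach it through
// GpuLower::current(). It is set and cleared by the LowerGuard defined inside
// GpuLower::lower() below.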
thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
namespace {

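//! Descriptive comment (added): removes kir::ForLoop and kir::IfThenElse
//! nodes whose bodies end up empty ("nop" nodes) from the lowered
//! expression list.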
class KIRCleaner : public OptOutDispatch {
 public:
  //! Remove nop IR nodes
  static std::vector<Expr*> cleanUp(const std::vector<Expr*>& loop_nests) {
    KIRCleaner cleaner;
    std::vector<Expr*> out_loop_nests;
    for (auto loop_nest : loop_nests) {
      cleaner.handle(loop_nest);
      // No need to keep the loop nest if it's determined to be nop
      if (!cleaner.is_nop_) {
        out_loop_nests.push_back(loop_nest);
      }
    }
    return out_loop_nests;
  }

 private:
  using OptOutDispatch::handle;
  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      OptOutDispatch::handle(expr);
    } else {
      // Any non-scoping expr is not considered nop
      is_nop_ = false;
    }
  }

  void handle(kir::ForLoop* fl) final {
    auto exprs = fl->body().exprs();
    fl->body().clear();
    for (auto expr : exprs) {
      handle(expr);
      // Add the expr to the loop body only when the expr is not nop
      if (!is_nop_) {
        fl->body().push_back(expr);
      }
    }
    // The loop is nop when no expr exists in the body
    is_nop_ = fl->body().empty();
  }

  void handle(kir::IfThenElse* ite) final {
    const auto conditional = ite->predicate()->value();

    // Visit the then block
    auto then_exprs = ite->thenBody().exprs();
    ite->thenBody().clear();
    if (!conditional->isConst() || conditional->value().value()) {
      for (auto expr : then_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->thenBody().push_back(expr);
        }
      }
    }

    const bool then_nop = ite->thenBody().empty();

    // Visit the else block
    auto else_exprs = ite->elseBody().exprs();
    ite->elseBody().clear();
    if (!conditional->isConst() || !conditional->value().value()) {
      for (auto expr : else_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->elseBody().push_back(expr);
        }
      }
    }

    const bool else_nop = ite->elseBody().empty();

    // If the then block is nop but the else is not, invert the
    // conditional and move the exprs in the else block to the then
    // block.
    if (then_nop && !else_nop) {
      Bool* pred = ite->predicate()->value();
      Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as<Bool>();
      ite->predicate()->setValue(not_pred);
      for (auto expr : ite->elseBody().exprs()) {
        ite->thenBody().push_back(expr);
      }
      ite->elseBody().clear();
    }

    // This IfThenElse is nop if both the then and else blocks are nop
    is_nop_ = then_nop && else_nop;
  }

 private:
  //! True if the last visited expr is nop
  bool is_nop_ = false;
};

} // namespace

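// Descriptive comment (added): scans all TensorViews used in the fusion,
// records whether any TIDx domain is padded to a multiple of the warp size,
// and determines whether every TIDx binding can be treated as a single warp
// (warp_pad_info_.is_tidx_single_warp).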
void GpuLower::collectPaddedParallelDims() {
  ExpressionEvaluator ee(fusion_);
  bool can_be_single_warp = true;

  auto warp_size = at::cuda::warp_size();

  auto used_vals = fusion_->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto id : tv->domain()->domain()) {
      if (tv->definition()) {
        // TODO: Support GroupedReductionOp
        if (auto reduction = dynamic_cast<ReductionOp*>(tv->definition())) {
          if (ir_utils::getMaybeWarpReductionDim(
                  reduction->out(), reduction->in())
                  .has_value()) {
            warp_pad_info_.has_warp_reduction = true;
          }
        }
      }

      // Check if TIDx is padded in this kernel
      if (id->hasPaddingToMultipleOfWarp()) {
        TORCH_INTERNAL_ASSERT(
            id->getParallelType() == ParallelType::TIDx,
            "Padded types supported only on TIDx");
        warp_pad_info_.is_tidx_padded = true;
      }

      // Check all possible bindings of TIDx to see
      // if TIDx will eventually be bound to a single warp.
      if (id->getParallelType() == ParallelType::TIDx) {
        auto eval_dim = ee.evaluate(id->extent());
        auto size_after_padding = id->getMaybeSizeAfterPadding();
        bool padding_to_single_warp = size_after_padding.has_value() &&
            size_after_padding.value() == warp_size;

        if ((!eval_dim.has_value() || eval_dim.value() > warp_size) &&
            !padding_to_single_warp) {
          // If we see any other TIDx binding that's larger than
          // a warp or unknown, we shouldn't lower warp reduce
          // to a single warp type.
          can_be_single_warp = false;
          warp_pad_info_.is_tidx_single_warp = false;
        } else if (can_be_single_warp) {
          if (padding_to_single_warp ||
              (eval_dim.has_value() && eval_dim.value() == warp_size)) {
            warp_pad_info_.is_tidx_single_warp = true;
          }
        }
      }
    }
  }
}

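// Descriptive comment (added): entry point of device lowering. Copies the
// given Fusion into a kir::Kernel and runs the sequence of analysis and
// transformation passes below, ending with kernel_->finalize(). A minimal
// usage sketch, assuming the GpuLower constructor declared in lower2device.h
// simply forwards to this function (the real call site lives in the fusion
// executor):
//
//   Fusion fusion;
//   // ... define and schedule the fusion ...
//   GpuLower gpu_lower(&fusion);
//   kir::Kernel* kernel = gpu_lower.kernel();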
void GpuLower::lower(Fusion* fusion, DataType index_type) {
  FUSER_PERF_SCOPE("GpuLower::lower");
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower == nullptr, "Nested lowering passes are not supported");

  struct LowerGuard {
    LowerGuard(GpuLower* gpu_lower) {
      active_gpu_lower = gpu_lower;
    }
    ~LowerGuard() {
      active_gpu_lower = nullptr;
    }
  } lower_guard(this);
  // Copy fusion into a new kernel for processing
  kernel_ = std::make_unique<kir::Kernel>(fusion, index_type);
  // Alias the fusion kernel carries around as a view of itself.
  fusion_ = kernel_.get();

  // Convert tensor views of DataType::Index type to either Int or Int32
  for (auto tv : ir_utils::allTvs(fusion_)) {
    if (tv->dtype() == DataType::Index) {
      tv->resolveIndexDtype();
    }
  }

  FusionGuard fg(fusion_);
  // prepare for lowering
  validateIr(fusion_);

  // Checks if any TIDx dim is marked as padded to a warp. Also checks if we can
  // determine the padding is explicitly a single warp.
  collectPaddedParallelDims();

  // Replaces integers that are tensor sizes with named scalars such as
  // "T0.size[0]"
  replaceSymbolicSizes(fusion_);

  // Traverse through reductions and determine if any iteration domains are
  // trivial reductions. Add these iteration domains to trivial_reduction_info_
  // which simply holds a map of which axes are trivial and which are not.
  trivial_reduction_info_.build(fusion_);
  // Replaces trivial reduction expressions (all id's being reduced are trivial)
  // with set unary op
  trivialReductionReplacement(fusion_, trivial_reduction_info_);

  // Build what's referred to as the compute at map. This map contains the
  // mappings of all iteration domains across the fusion. There are three types
  // of mappings: Permissive, Exact, and Loop; see compute_at_map.h/cpp for more
  // information.
  compute_at_map_ = std::make_unique<ComputeAtMap>(fusion_);

  if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
    std::cout << compute_at_map_->toString() << std::endl;
  }

  compute_at_map_->validateAndPropagatePType();

  // Used in parallel dimension map
  concretized_broadcast_domains_.build(fusion_);

  parallelDimensionMap().build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
    std::cout << "Parallel dimension map:" << std::endl;
    std::cout << parallel_dimension_map_.toString() << std::endl;
  }

  // Validate mma data format and compatibility if any on the fusion.
  validateMma(fusion_);

  // Validate swizzle usage on the fusion schedule.
  validateSwizzle(fusion_);

  // Compute thread predicates. Depends on parallel_dimension_map_
  thread_pred_map_.build(fusion_);

  // Fuse certain patterns of reductions, such as a grid reduction
  // followed by a grid broadcast. Only depends on parallelization and
  // thread predicate map.
  fuseReductionsAndBroadcasts(fusion_);

  // Scan the whole fusion and build mappings about halo extensions of
  // all IterDomains
  haloInfo().build(fusion_);

  // Want to run this after parallel map and halo info map are
  // created. vectorized_accesses_ and vectorized_set_info_ are filled.
  validateAndCollectVectorizeInfo(fusion_);

  // Depends on ComputeAtMap and HaloInfo.
  validateAndConvertIterDomainGrouping(fusion_);

  // Assumes all grouped reductions are converted to
  // GroupedReductionOp, which is done by
  // validateAndConvertIterDomainGrouping
  validateGroupedReductions(fusion_);

  // Depends on thread_pred_map_; validates parallelization and collects which
  // tensor views need WAR or RAW syncs
  sync_map_.build(fusion_);

  partialSplitMap().build(fusion_);

  validatePartialSplit(fusion_);

  nonDivisibleSplitInfo().build(fusion_);

  // Detects all expressions that don't need predicates. Depends on
  // nonDivisibleSplitInfo.
  predicateElimination().build(fusion_);

  doubleBufferInfo().build(fusion_);

  compute_at_map_->allocateIndexVariables();
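
  // Descriptive comment (added): each pass below consumes the expression list
  // produced by the previous pass and returns a new list (exprs_sorted ->
  // exprs_lowered -> ... -> exprs_instrumented), which is finally handed to
  // kernel_->finalize().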
  // Run our passes keeping the lowered expressions and forwarding
  // them

  // Reorder expressions for loop-nest generation respecting computeAt
  // relationships
  const auto exprs_sorted = reorderExprsForComputeAt();

  // Generate loop-nests and place each expression at its
  // corresponding loop
  const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted);

  // Replace trivial reductions, Transpose, Shift, Gather, and View ops with
  // unary ops since they're not separately processed in lowering.
  const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered);

  // Insert allocations
  const auto exprs_alloced = insertAllocations(exprs_unary_replaced);

  // Insert read after write smem syncs
  const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced);

  // Reuse memory locations
  const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync);

  // Insert SyncThreads at end of for-loop to avoid WAR race condition
  const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem);

  const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync);

  // This pass inserts predicates as well as branches in the code. Up until now
  // the code is explicitly single shot for loop based. Need to be careful in
  // later passes when doing any kind of insertions in loop nest structure as
  // insertions could be on if then or else instead of directly on a for loop.
  const auto exprs_unrolled_loops =
      UnrollPass::runPass(fusion_, exprs_double_buffered);

  const auto exprs_unrolled_mv_loops =
      processMisalignedVectorization(exprs_unrolled_loops);

  const auto exprs_indexed_loops =
      IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops);

  // TODO: It seems this type of optimization would be far easier to implement
  // on fusion ir than kernel ir. We should likely refactor this to at least run
  // before allocation insertion.
  const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops);

  const auto exprs_conditional_loops =
      generateConditionalFromPredicate(exprs_with_fused_broadcast);

  const auto exprs_common_index_allocated =
      allocateCommonIndices(exprs_conditional_loops);

  // Insert fake zero updates to make sure nvrtc doesn't blow out register use
  // on index and predicate reuse
  const auto exprs_register_adjusted =
      insertMagicZero(exprs_common_index_allocated);

  const auto exprs_cleaned_up_loops =
      KIRCleaner::cleanUp(exprs_register_adjusted);

  const auto exprs_instrumented = instrumentKernel(exprs_cleaned_up_loops);

  // We now have the lowered expressions, finalize the kernel IR. This function
  // will also copy over some relevant information for code generation from
  // GpuLower.
  kernel_->finalize(exprs_instrumented);
}

kir::Kernel* GpuLower::kernel() const {
  TORCH_CHECK(kernel_);
  return kernel_.get();
}

GpuLower* GpuLower::current() {
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower != nullptr, "No active GpuLower available");
  return active_gpu_lower;
}

bool GpuLower::hasCurrent() {
  return active_gpu_lower != nullptr;
}

void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
  pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch