pytorch/torch/csrc/jit/codegen/cuda/lower2device.cpp
jjsjann123 df741c589f [NVFuser] Upstream push 0809 (#83067)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary syncs via redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu
- scheduler
  1. normalization scheduler cleanup
  2. simplifies matmul scheduling with the new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to WAR github API
Commits that are actually in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
2022-08-10 21:02:56 +00:00


#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
#include <torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h>
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
#include <torch/csrc/jit/codegen/cuda/lower_instrument.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_replace_size.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
#include <list>
#include <unordered_map>
#include <unordered_set>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
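
// The GpuLower instance currently running on this thread, if any. Set for the
// duration of GpuLower::lower() and queried via GpuLower::current().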
thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
namespace {
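
//! Removes no-op expressions (empty for-loops and if-then-elses) from the
//! lowered kernel IR.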
class KIRCleaner : public OptOutDispatch {
 public:
  //! Remove nop IR nodes
  static std::vector<Expr*> cleanUp(const std::vector<Expr*>& loop_nests) {
    KIRCleaner cleaner;
    std::vector<Expr*> out_loop_nests;
    for (auto loop_nest : loop_nests) {
      cleaner.handle(loop_nest);
      // No need to keep the loop nest if it's determined to be nop
      if (!cleaner.is_nop_) {
        out_loop_nests.push_back(loop_nest);
      }
    }
    return out_loop_nests;
  }

 private:
  using OptOutDispatch::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      OptOutDispatch::handle(expr);
    } else {
      // Any non-scoping expr is not considered nop
      is_nop_ = false;
    }
  }

  void handle(kir::ForLoop* fl) final {
    auto exprs = fl->body().exprs();
    fl->body().clear();
    for (auto expr : exprs) {
      handle(expr);
      // Add the expr to the loop body only when the expr is not nop
      if (!is_nop_) {
        fl->body().push_back(expr);
      }
    }
    // The loop is nop when no expr exists in the body
    is_nop_ = fl->body().empty();
  }

  void handle(kir::IfThenElse* ite) final {
    const auto conditional = ite->predicate()->value();

    // Visit the then block
    auto then_exprs = ite->thenBody().exprs();
    ite->thenBody().clear();
    if (!conditional->isConst() || conditional->value().value()) {
      for (auto expr : then_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->thenBody().push_back(expr);
        }
      }
    }
    const bool then_nop = ite->thenBody().empty();

    // Visit the else block
    auto else_exprs = ite->elseBody().exprs();
    ite->elseBody().clear();
    if (!conditional->isConst() || !conditional->value().value()) {
      for (auto expr : else_exprs) {
        handle(expr);
        if (!is_nop_) {
          ite->elseBody().push_back(expr);
        }
      }
    }
    const bool else_nop = ite->elseBody().empty();

    // If the then block is nop but the else is not, invert the
    // conditional and move the exprs in the else block to the then
    // block.
    if (then_nop && !else_nop) {
      Bool* pred = ite->predicate()->value();
      Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as<Bool>();
      ite->predicate()->setValue(not_pred);
      for (auto expr : ite->elseBody().exprs()) {
        ite->thenBody().push_back(expr);
      }
      ite->elseBody().clear();
    }

    // This IfThenElse is nop if both the then and else blocks are nop
    is_nop_ = then_nop && else_nop;
  }

 private:
  //! True if the last visited expr is nop
  bool is_nop_ = false;
};
} // namespace
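
// Scans all TensorViews used in the fusion to record warp-padding information:
// whether a warp reduction is present, whether TIDx is padded to a multiple of
// the warp size, and whether every TIDx binding can be proven to fit in a
// single warp (all recorded in warp_pad_info_).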
void GpuLower::collectPaddedParallelDims() {
  ExpressionEvaluator ee(fusion_);
  bool can_be_single_warp = true;
  auto warp_size = at::cuda::warp_size();

  auto used_vals = fusion_->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    for (auto id : tv->domain()->domain()) {
      if (tv->definition()) {
        // TODO: Support GroupedReductionOp
        if (auto reduction = dynamic_cast<ReductionOp*>(tv->definition())) {
          if (ir_utils::getMaybeWarpReductionDim(
                  reduction->out(), reduction->in())
                  .has_value()) {
            warp_pad_info_.has_warp_reduction = true;
          }
        }
      }

      // Check if TIDx is padded in this kernel
      if (id->hasPaddingToMultipleOfWarp()) {
        TORCH_INTERNAL_ASSERT(
            id->getParallelType() == ParallelType::TIDx,
            "Padded types supported only on TIDx");
        warp_pad_info_.is_tidx_padded = true;
      }

      // Check all possible bindings of TIDx to see
      // if TIDx will eventually be bound to a single warp.
      if (id->getParallelType() == ParallelType::TIDx) {
        auto eval_dim = ee.evaluate(id->extent());
        auto size_after_padding = id->getMaybeSizeAfterPadding();
        bool padding_to_single_warp = size_after_padding.has_value() &&
            size_after_padding.value() == warp_size;

        if ((!eval_dim.has_value() || eval_dim.value() > warp_size) &&
            !padding_to_single_warp) {
          // If we see any other TIDx binding that's larger than
          // a warp or unknown, we shouldn't lower warp reduce
          // to a single warp type.
          can_be_single_warp = false;
          warp_pad_info_.is_tidx_single_warp = false;
        } else if (can_be_single_warp) {
          if (padding_to_single_warp ||
              (eval_dim.has_value() && eval_dim.value() == warp_size)) {
            warp_pad_info_.is_tidx_single_warp = true;
          }
        }
      }
    }
  }
}
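
// Main entry point of lowering: runs validation and analysis passes on the
// fusion, applies the sequence of kernel IR transformation passes, and
// finalizes the resulting kir::Kernel.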
void GpuLower::lower(Fusion* fusion, DataType index_type) {
  FUSER_PERF_SCOPE("GpuLower::lower");
  TORCH_INTERNAL_ASSERT(fusion != nullptr);
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower == nullptr, "Nested lowering passes are not supported");

  // Scope guard that marks this GpuLower as the active one for the duration
  // of lowering.
  struct LowerGuard {
    LowerGuard(GpuLower* gpu_lower) {
      active_gpu_lower = gpu_lower;
    }
    ~LowerGuard() {
      active_gpu_lower = nullptr;
    }
  } lower_guard(this);

  // Copy fusion into a new kernel for processing
  kernel_ = std::make_unique<kir::Kernel>(fusion, index_type);

  // Alias the fusion the kernel carries around as a view of itself.
  fusion_ = kernel_.get();

  // Convert tensor views of DataType::Index type to either Int or Int32
  for (auto tv : ir_utils::allTvs(fusion_)) {
    if (tv->dtype() == DataType::Index) {
      tv->resolveIndexDtype();
    }
  }
  FusionGuard fg(fusion_);

  // prepare for lowering
  validateIr(fusion_);

  // Checks if any TIDx dim is marked as padded to a warp. Also checks if we
  // can determine the padding is explicitly a single warp.
  collectPaddedParallelDims();

  // Replaces integers that are tensor sizes with named scalars such as
  // "T0.size[0]"
  replaceSymbolicSizes(fusion_);

  // Traverse through reductions and determine if any iteration domains are
  // trivial reductions. Add these iteration domains to trivial_reduction_info_
  // which simply holds a map of which axes are trivial and which are not.
  trivial_reduction_info_.build(fusion_);
  // Replaces trivial reduction expressions (all id's being reduced are
  // trivial) with a set unary op
  trivialReductionReplacement(fusion_, trivial_reduction_info_);

  // Build what's referred to as the compute at map. This map contains the
  // mappings of all iteration domains across the fusion. There are three
  // types of mappings: Permissive, Exact, and Loop; see compute_at_map.h/cpp
  // for more information.
  compute_at_map_ = std::make_unique<ComputeAtMap>(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
    std::cout << compute_at_map_->toString() << std::endl;
  }
  compute_at_map_->validateAndPropagatePType();

  // Used in parallel dimension map
  concretized_broadcast_domains_.build(fusion_);

  parallelDimensionMap().build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
    std::cout << "Parallel dimension map:" << std::endl;
    std::cout << parallel_dimension_map_.toString() << std::endl;
  }

  // Validate mma data format and compatibility if any on the fusion.
  validateMma(fusion_);

  // Validate swizzle usage on the fusion schedule.
  validateSwizzle(fusion_);

  // Compute thread predicates. Depends on parallel_dimension_map_
  thread_pred_map_.build(fusion_);

  // Fuse certain patterns of reductions, such as a grid reduction
  // followed by a grid broadcast. Only depends on parallelization and
  // thread predicate map.
  fuseReductionsAndBroadcasts(fusion_);

  // Scan the whole fusion and build mappings about halo extensions of
  // all IterDomains
  haloInfo().build(fusion_);

  // Want to run this after parallel map and halo info map are
  // created. vectorized_accesses_ and vectorized_set_info_ are filled.
  validateAndCollectVectorizeInfo(fusion_);

  // Depends on ComputeAtMap and HaloInfo.
  validateAndConvertIterDomainGrouping(fusion_);

  // Assumes all grouped reductions are converted to
  // GroupedReductionOp, which is done by
  // validateAndConvertIterDomainGrouping
  validateGroupedReductions(fusion_);

  // Depends on thread_pred_map_; validates parallelization and collects which
  // tensor views need WAR or RAW syncs
  sync_map_.build(fusion_);

  partialSplitMap().build(fusion_);
  validatePartialSplit(fusion_);

  nonDivisibleSplitInfo().build(fusion_);

  // Detects all expressions that don't need predicates. Depends on
  // nonDivisibleSplitInfo.
  predicateElimination().build(fusion_);

  doubleBufferInfo().build(fusion_);

  compute_at_map_->allocateIndexVariables();
  // Run our passes, keeping the lowered expressions and forwarding them from
  // one pass to the next.

  // Reorder expressions for loop-nest generation respecting computeAt
  // relationships
  const auto exprs_sorted = reorderExprsForComputeAt();

  // Generate loop-nests and place each expression at its
  // corresponding loop
  const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted);

  // Replace trivial reductions, Transpose, Shift, Gather, and View ops with
  // unary ops since they're not separately processed in lowering.
  const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered);

  // Insert allocations
  const auto exprs_alloced = insertAllocations(exprs_unary_replaced);

  // Insert read after write smem syncs
  const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced);

  // Reuse memory locations
  const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync);

  // Insert SyncThreads at end of for-loop to avoid WAR race condition
  const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem);

  const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync);

  // This pass inserts predicates as well as branches in the code. Up until
  // now the code is explicitly single-shot, for-loop based. Need to be
  // careful in later passes when doing any kind of insertion into the loop
  // nest structure, as insertions could land in the then or else branch of an
  // IfThenElse instead of directly in a for loop.
  const auto exprs_unrolled_loops =
      UnrollPass::runPass(fusion_, exprs_double_buffered);

  const auto exprs_unrolled_mv_loops =
      processMisalignedVectorization(exprs_unrolled_loops);

  const auto exprs_indexed_loops =
      IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops);

  // TODO: It seems this type of optimization would be far easier to implement
  // on fusion ir than kernel ir. We should likely refactor this to at least
  // run before allocation insertion.
  const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops);

  const auto exprs_conditional_loops =
      generateConditionalFromPredicate(exprs_with_fused_broadcast);

  const auto exprs_common_index_allocated =
      allocateCommonIndices(exprs_conditional_loops);

  // Insert fake zero updates to make sure nvrtc doesn't blow out register use
  // on index and predicate reuse
  const auto exprs_register_adjusted =
      insertMagicZero(exprs_common_index_allocated);

  const auto exprs_cleaned_up_loops =
      KIRCleaner::cleanUp(exprs_register_adjusted);

  const auto exprs_instrumented = instrumentKernel(exprs_cleaned_up_loops);

  // We now have the lowered expressions; finalize the kernel IR. This function
  // will also copy over some relevant information for code generation from
  // GpuLower.
  kernel_->finalize(exprs_instrumented);
}
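
// Returns the lowered kernel; asserts that it has been created by lower().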
kir::Kernel* GpuLower::kernel() const {
  TORCH_CHECK(kernel_);
  return kernel_.get();
}
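
// Returns the GpuLower instance active on this thread; asserts that a
// lowering pass is currently in progress.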
GpuLower* GpuLower::current() {
  TORCH_INTERNAL_ASSERT(
      active_gpu_lower != nullptr, "No active GpuLower available");
  return active_gpu_lower;
}
bool GpuLower::hasCurrent() {
  return active_gpu_lower != nullptr;
}
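
// Forwards per-expression analysis results (currently predicate elimination
// info) from an expression that is being replaced to its replacement.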
void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
  pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch