pytorch/torch/csrc/jit/codegen/cuda/lower2device.h
jjsjann123 9e52ad28c9 [nvfuser_upstream_push] nvfuser code base bump 052422 (#78244)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: https://github.com/csarofeen/pytorch/pull/1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: https://github.com/csarofeen/pytorch/pull/1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: https://github.com/csarofeen/pytorch/pull/1440

Commits actually included in this PR from the csarofeen branch
```
* dd2325294e236c5082c642819a1103bcfe4561a3 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f446355a2d276bac8272e7aa8b5bb6b1f0 Fix missing cooperative launch (#1726)
* dc670a226cbe52be46cecef47001f38bf9a09433 Async gmem copy support on sm80+ (#1619)
* 5e6a8dab5a71aefe0548bbfa15d1a93c556d23fe Add turing mma support and test (#1643)
* d6d6b7d3f10dd91dafa4cdbd5e460bbb38173af4 Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39150c6d80e0f9f767d56654714a2e8a927 Mma op integration on ampere (#1440)
* fade8da55e60a118c5595378896d34b862b2fcc3 patch python test for bfloat16 (#1724)
* 8fbd0b18743a72ac10478857c3d2351204375685 Fine-grained kernel profiling (#1720)
* 77c1b4fa633f9e631d267923f4537336fa328939 Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b97bebefc94199bb4a53423ede32b55451 More precise concretization analysis (#1719)
* f4d3630ed54d7069dd377a64be1f91013b285b66 Enable complex python tests (#1667)
* 4ceeee509774cc2ce6c834a4dc1e313f71d94503 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70faf218e86d2c78dbd3874b175a3b0a203 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830d5def65dadfe29d4edf52fc703369c84a Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7ae2cfd8fffad1e1d882ae7c50631211dc updating_ci_machine (#1718)
* 56585c58b1ff338704cafb0cd6be2b3d536bed5a Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453d3be0c11a5acb0fff3b3f36e19cfdaf81 Allow using nvFuser on CUDA extension (#1701)
* 18bee67495454b9a79625799776e746bd5e81c4c Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
2022-06-07 17:30:51 -07:00

232 lines
7.0 KiB
C++

#pragma once
#include <c10/macros/Export.h>
#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_fused_reduction.h>
#include <torch/csrc/jit/codegen/cuda/lower_index_hoist.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate_elimination.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_sync_information.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>
#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/vectorization_info.h>
#include <memory>
#include <ostream>
#include <unordered_map>
#include <unordered_set>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
//! GpuLower: entry point for lowering a Fusion into a kir::Kernel.
//!
//! Constructing a GpuLower runs the full lowering pipeline (lower()),
//! after which the lowered kernel and the per-lowering analysis results
//! (thread predicates, compute-at map, halo info, parallel dimension
//! map, etc.) are owned by this object and exposed via the accessors
//! below. While lowering is in progress the active instance is
//! reachable through the static current() hook.
class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
  // Internal helper class; declaration only — its definition is not
  // visible in this header.
  class KernelIrMapper;

 public:
  // A GpuLower is only meaningful when bound to a Fusion.
  GpuLower() = delete;

  // GpuLower lowers the provided fusion into a kernel which can be translated
  // into cuda code. index_type allows to compile the kernel based on int32
  // indexing instead of int64 for additional performance.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit GpuLower(Fusion* fusion, DataType index_type = DataType::Int) {
    lower(fusion, index_type);
  }

  //! Returns the lowered Kernel IR (kernel_) produced by lowering.
  kir::Kernel* kernel() const;

  //! Returns the currently active lowering object.
  //! It's an error if no lowering is in progress.
  static GpuLower* current();

  //! Query if lowering is in progress
  static bool hasCurrent();

  //! Broadcast-domain concretization analysis (non-const accessor).
  ConcretizedBroadcastDomains& concretizedBroadcastDomains() {
    return concretized_broadcast_domains_;
  }

  //! Map of thread/block predicates computed during lowering.
  const ThreadPredicateMap& threadPredMap() const {
    return thread_pred_map_;
  }

  // Returns non-const reference. Necessary to reset a predicate flag
  // when a broadcast expression is fused into a reduction.
  ThreadPredicateMap& threadPredMap() {
    return thread_pred_map_;
  }

  //! Compute-at map built for this lowering (owned via unique_ptr).
  const std::unique_ptr<ComputeAtMap>& caMap() const {
    return compute_at_map_;
  }

  //! Info on reductions detected as trivial (see lower_trivial_reductions.h).
  const TrivialReductionInfo& trivialReductionInfo() const {
    return trivial_reduction_info_;
  }

  //! Halo extension info (see lower_shift.h); const accessor.
  const HaloInfo& haloInfo() const {
    return halo_info_;
  }

  //! Non-const accessor — lowering passes may update halo info in place.
  HaloInfo& haloInfo() {
    return halo_info_;
  }

  //! Mapping of parallel types to their dimension extents; const accessor.
  const ParallelDimensionMap& parallelDimensionMap() const {
    return parallel_dimension_map_;
  }

  //! Non-const accessor to the parallel dimension map.
  ParallelDimensionMap& parallelDimensionMap() {
    return parallel_dimension_map_;
  }

  //! Predicate-elimination analysis; non-const accessor.
  PredicateElimination& predicateElimination() {
    return pred_elimination_;
  }

  //! Predicate-elimination analysis; const accessor.
  const PredicateElimination& predicateElimination() const {
    return pred_elimination_;
  }

  //! Info on local (per-thread) buffer allocations made during lowering.
  LocalAllocationInfoMap& localAllocationInfoMap() {
    return local_allocation_info_map_;
  }

  //! Parallel dims padded to warp multiples (see collectPaddedParallelDims()).
  const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
    return warp_pad_info_;
  }

  //! Map for partial (start/stop offset) splits; non-const accessor.
  PartialSplitMap& partialSplitMap() {
    return partial_split_map_;
  }

  //! Map for partial (start/stop offset) splits; const accessor.
  const PartialSplitMap& partialSplitMap() const {
    return partial_split_map_;
  }

  //! Info on splits whose input extent may not divide evenly; non-const.
  auto& nonDivisibleSplitInfo() {
    return non_divisible_split_info_;
  }

  //! Info on splits whose input extent may not divide evenly; const.
  const auto& nonDivisibleSplitInfo() const {
    return non_divisible_split_info_;
  }

  //! Double-buffering analysis (see lower_double_buffer.h).
  DoubleBufferInfo& doubleBufferInfo() {
    return double_buffer_info_;
  }

  //! Map used for index hoisting / common index reuse
  //! (see lower_index_hoist.h).
  CommonIndexMap& commonIndexMap() {
    return common_index_map_;
  }

  //! TensorView -> max vectorized access size; see vectorized_accesses_.
  const auto& vectorizedAccesses() const {
    return vectorized_accesses_;
  }

  //! Non-const accessor so passes can record vectorized accesses.
  auto& vectorizedAccesses() {
    return vectorized_accesses_;
  }

  //! Per-op info for vectorized set operations; const accessor.
  const auto& vectorizedSetInfo() const {
    return vectorized_set_info_;
  }

  //! Per-op info for vectorized set operations; non-const accessor.
  auto& vectorizedSetInfo() {
    return vectorized_set_info_;
  }

  //! Info on fused (grid/welford-style) reductions
  //! (see lower_fused_reduction.h).
  FusedReductionInfo& fusedReductionInfo() {
    return fused_reduction_info_;
  }

  //! Synchronization requirements between expressions
  //! (see lower_sync_information.h).
  const SyncMap& syncMap() const {
    return sync_map_;
  }

  //! Kernel performance profiling state; non-const so passes can
  //! register profiled expressions.
  kir::KernelPerformanceProfile& profile() {
    return profile_;
  }

  // This is an interface to propagate information after expression
  // replacement on the kernel IR. E.g.:
  //   for ...
  //     c = a + b   (expr 0)
  // after any pass that does replacement:
  //   for ...
  //     c1 = a1 + b1  (expr1)
  // The previous analysis that was performed on expr0 might still
  // be valid on expr1 but that info would be lost after replacement.
  // This function provides an interface to manually update the info
  // in any pass that performs replacement.
  void propagateExprInfo(const Expr* old_expr, const Expr* new_expr);

 private:
  //! Runs the lowering pipeline; called from the constructor.
  void lower(Fusion* fusion, DataType index_type);

  // Goes through the parallelized iterdomains of the used TVs and find
  // the parallel dimensions that need to be padded to a multiple of
  // warp size.
  void collectPaddedParallelDims();

 private:
  // Lowered Kernel IR
  std::unique_ptr<kir::Kernel> kernel_;

  // Some stateful information during lowering
  // TODO: A lot of this information uses a define class then call build. It
  // would be safer to wrap all of these in unique pointers and remove the build
  // interface and default constructor. That way they couldn't be accessed
  // without being initialized.
  ConcretizedBroadcastDomains concretized_broadcast_domains_;
  ThreadPredicateMap thread_pred_map_;
  PredicateElimination pred_elimination_;
  std::unique_ptr<ComputeAtMap> compute_at_map_;
  TrivialReductionInfo trivial_reduction_info_;
  HaloInfo halo_info_;
  LocalAllocationInfoMap local_allocation_info_map_;
  WarpPaddedParallelInfo warp_pad_info_;
  ParallelDimensionMap parallel_dimension_map_;
  PartialSplitMap partial_split_map_;
  NonDivisibleSplitInfo non_divisible_split_info_;
  DoubleBufferInfo double_buffer_info_;
  CommonIndexMap common_index_map_;
  FusedReductionInfo fused_reduction_info_;
  SyncMap sync_map_;
  kir::KernelPerformanceProfile profile_;

  // Track which tensor views are inputs or outputs of a vectorized operation
  // and their maximum vectorized access size
  // std::unordered_map<TensorView*, VectorizationInfo> vectorized_accesses_;
  std::unordered_map<TensorView*, int> vectorized_accesses_;

  // Info on each vectorized set op
  std::vector<VectorizedSetInfo> vectorized_set_info_;

  // The fusion being lowered; set during lower(). Non-owning.
  Fusion* fusion_ = nullptr;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch