mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ A few bigger updates: 1. Initial support of cp.async and cp.async.wait: https://github.com/csarofeen/pytorch/pull/1619 2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: https://github.com/csarofeen/pytorch/pull/1643 3. Extending the infrastructure to support mma operators on turing and ampere arch: https://github.com/csarofeen/pytorch/pull/1440 Commits that's actually in this PR from the csarofeen branch ``` * dd2325294e236c5082c642819a1103bcfe4561a3 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710) * b3d1c3f446355a2d276bac8272e7aa8b5bb6b1f0 Fix missing cooperative launch (#1726) * dc670a226cbe52be46cecef47001f38bf9a09433 Async gmem copy support on sm80+ (#1619) * 5e6a8dab5a71aefe0548bbfa15d1a93c556d23fe Add turing mma support and test (#1643) * d6d6b7d3f10dd91dafa4cdbd5e460bbb38173af4 Fix rFactor when there are indirect root domain(s), and refactor (#1723) * 7093e39150c6d80e0f9f767d56654714a2e8a927 Mma op integration on ampere (#1440) * fade8da55e60a118c5595378896d34b862b2fcc3 patch python test for bfloat16 (#1724) * 8fbd0b18743a72ac10478857c3d2351204375685 Fine-grained kernel profiling (#1720) * 77c1b4fa633f9e631d267923f4537336fa328939 Adding dry run mode to skip arch dependent checks (#1702) * 151d95b97bebefc94199bb4a53423ede32b55451 More precise concretization analysis (#1719) * f4d3630ed54d7069dd377a64be1f91013b285b66 Enable complex python tests (#1667) * 4ceeee509774cc2ce6c834a4dc1e313f71d94503 Minor bugfix in transform_rfactor.cpp (#1715) * 3675c70faf218e86d2c78dbd3874b175a3b0a203 Separate root domain and rfactor domain in TransformPrinter (#1716) * f68b830d5def65dadfe29d4edf52fc703369c84a Fix scheduling with polymorphic broadcast (#1714) * 4ab5ef7ae2cfd8fffad1e1d882ae7c50631211dc updating_ci_machine (#1718) * 56585c58b1ff338704cafb0cd6be2b3d536bed5a Merge pull request #1711 from 
csarofeen/upstream_master_bump_0517 * 174d453d3be0c11a5acb0fff3b3f36e19cfdaf81 Allow using nvFuser on CUDA extension (#1701) * 18bee67495454b9a79625799776e746bd5e81c4c Validate LOOP concrete IDs have complete IterDomains (#1676) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78244 Approved by: https://github.com/csarofeen, https://github.com/malfet
232 lines
7.0 KiB
C++
232 lines
7.0 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Export.h>
|
|
|
|
#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
|
#include <torch/csrc/jit/codegen/cuda/kernel.h>
|
|
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_fused_reduction.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_index_hoist.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_predicate_elimination.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_sync_information.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
|
|
#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>
|
|
#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
|
|
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>
|
|
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
|
#include <torch/csrc/jit/codegen/cuda/vectorization_info.h>
|
|
|
|
#include <memory>
#include <ostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// TODO: we frequently use pairwise root mapping from consumers to producers.
|
|
// This information is implicitly in the computeAtMaps, but there's no isolated
|
|
// container for this information that we can reuse. Would be nice to generate
|
|
// such a structure and propagate it through lowering.
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
|
|
class KernelIrMapper;
|
|
|
|
public:
|
|
GpuLower() = delete;
|
|
|
|
// GpuLower lowers the provided fusion into a kernel which can be translated
|
|
// into cuda code. index_type allows to compile the kernel based on int32
|
|
// indexing instead of int64 for additional performance.
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
explicit GpuLower(Fusion* fusion, DataType index_type = DataType::Int) {
|
|
lower(fusion, index_type);
|
|
}
|
|
|
|
kir::Kernel* kernel() const;
|
|
|
|
//! Returns the currently active lowering object.
|
|
//! It's an error if no lowering is in progress.
|
|
static GpuLower* current();
|
|
|
|
//! Query if lowering is in progress
|
|
static bool hasCurrent();
|
|
|
|
ConcretizedBroadcastDomains& concretizedBroadcastDomains() {
|
|
return concretized_broadcast_domains_;
|
|
}
|
|
|
|
const ThreadPredicateMap& threadPredMap() const {
|
|
return thread_pred_map_;
|
|
}
|
|
|
|
// Returns non-const reference. Necessary to reset a predicate flag
|
|
// when a broadcast expression is fused into a reduction.
|
|
ThreadPredicateMap& threadPredMap() {
|
|
return thread_pred_map_;
|
|
}
|
|
|
|
const std::unique_ptr<ComputeAtMap>& caMap() const {
|
|
return compute_at_map_;
|
|
}
|
|
|
|
const TrivialReductionInfo& trivialReductionInfo() const {
|
|
return trivial_reduction_info_;
|
|
}
|
|
|
|
const HaloInfo& haloInfo() const {
|
|
return halo_info_;
|
|
}
|
|
|
|
HaloInfo& haloInfo() {
|
|
return halo_info_;
|
|
}
|
|
|
|
const ParallelDimensionMap& parallelDimensionMap() const {
|
|
return parallel_dimension_map_;
|
|
}
|
|
|
|
ParallelDimensionMap& parallelDimensionMap() {
|
|
return parallel_dimension_map_;
|
|
}
|
|
|
|
PredicateElimination& predicateElimination() {
|
|
return pred_elimination_;
|
|
}
|
|
|
|
const PredicateElimination& predicateElimination() const {
|
|
return pred_elimination_;
|
|
}
|
|
|
|
LocalAllocationInfoMap& localAllocationInfoMap() {
|
|
return local_allocation_info_map_;
|
|
}
|
|
|
|
const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
|
|
return warp_pad_info_;
|
|
}
|
|
|
|
PartialSplitMap& partialSplitMap() {
|
|
return partial_split_map_;
|
|
}
|
|
|
|
const PartialSplitMap& partialSplitMap() const {
|
|
return partial_split_map_;
|
|
}
|
|
|
|
auto& nonDivisibleSplitInfo() {
|
|
return non_divisible_split_info_;
|
|
}
|
|
|
|
const auto& nonDivisibleSplitInfo() const {
|
|
return non_divisible_split_info_;
|
|
}
|
|
|
|
DoubleBufferInfo& doubleBufferInfo() {
|
|
return double_buffer_info_;
|
|
}
|
|
|
|
CommonIndexMap& commonIndexMap() {
|
|
return common_index_map_;
|
|
}
|
|
|
|
const auto& vectorizedAccesses() const {
|
|
return vectorized_accesses_;
|
|
}
|
|
|
|
auto& vectorizedAccesses() {
|
|
return vectorized_accesses_;
|
|
}
|
|
|
|
const auto& vectorizedSetInfo() const {
|
|
return vectorized_set_info_;
|
|
}
|
|
|
|
auto& vectorizedSetInfo() {
|
|
return vectorized_set_info_;
|
|
}
|
|
|
|
FusedReductionInfo& fusedReductionInfo() {
|
|
return fused_reduction_info_;
|
|
}
|
|
|
|
const SyncMap& syncMap() const {
|
|
return sync_map_;
|
|
}
|
|
|
|
kir::KernelPerformanceProfile& profile() {
|
|
return profile_;
|
|
}
|
|
|
|
// This is an interface to propagate information after expression
|
|
// replacement on the kernel IR. E.g.:
|
|
// for ...
|
|
// c = a + b (expr 0)
|
|
// after any pass that does replacement:
|
|
// for ...
|
|
// c1 = a1 + b1 (expr1)
|
|
// The previous analysis that was performed on expr0 might still
|
|
// be valid on expr1 but that info would be lost after replacement.
|
|
// This function provides an interface to manually update the info
|
|
// in any pass that performs replacement.
|
|
void propagateExprInfo(const Expr* old_expr, const Expr* new_expr);
|
|
|
|
private:
|
|
void lower(Fusion* fusion, DataType index_type);
|
|
|
|
// Goes through the parallelized iterdomains of the used TVs and find
|
|
// the parallel dimensions that need to be padded to a multiples of
|
|
// warp size.
|
|
void collectPaddedParallelDims();
|
|
|
|
private:
|
|
// Lowered Kernel IR
|
|
std::unique_ptr<kir::Kernel> kernel_;
|
|
|
|
// Some stateful information during lowering
|
|
// TODO: A lot of this information uses a define class then call build. It
|
|
// would be safer to wrap all of these in unique pointers and remove the build
|
|
// interface and default constructor. That way they couldn't be accessed
|
|
// without being initialized.
|
|
ConcretizedBroadcastDomains concretized_broadcast_domains_;
|
|
ThreadPredicateMap thread_pred_map_;
|
|
PredicateElimination pred_elimination_;
|
|
std::unique_ptr<ComputeAtMap> compute_at_map_;
|
|
TrivialReductionInfo trivial_reduction_info_;
|
|
HaloInfo halo_info_;
|
|
LocalAllocationInfoMap local_allocation_info_map_;
|
|
WarpPaddedParallelInfo warp_pad_info_;
|
|
ParallelDimensionMap parallel_dimension_map_;
|
|
PartialSplitMap partial_split_map_;
|
|
NonDivisibleSplitInfo non_divisible_split_info_;
|
|
DoubleBufferInfo double_buffer_info_;
|
|
CommonIndexMap common_index_map_;
|
|
FusedReductionInfo fused_reduction_info_;
|
|
SyncMap sync_map_;
|
|
kir::KernelPerformanceProfile profile_;
|
|
|
|
// Track which tensor views are inputs or outputs of a vectorized operation
|
|
// and their maximum vectorized access size
|
|
// std::unordered_map<TensorView*, VectorizationInfo> vectorized_accesses_;
|
|
std::unordered_map<TensorView*, int> vectorized_accesses_;
|
|
// Info on each vectorized set op
|
|
std::vector<VectorizedSetInfo> vectorized_set_info_;
|
|
|
|
Fusion* fusion_ = nullptr;
|
|
};
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|