// Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
// File: 168 lines, 4.8 KiB, C++
//
// Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69964
// Things added in this PR that require review:
//   1. cuLaunchCooperativeKernel driver API added in
//      aten/src/ATen/cuda/detail/LazyNVRTC.cpp and
//      aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
// nvfuser code update:
//   1. Perf tuning on the codegen scheduler that improves performance.
//   2. Permutation support has been extended beyond contiguous/channels-last.
//      (The improvements can be observed on the PW benchmark.)
// Things reverted from local changes:
//   1. aten::gelu with approximation
//   2. Local changes that were upstreamed in
//      https://github.com/pytorch/pytorch/issues/68804
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/69428
// Reviewed By: ngimel
// Differential Revision: D33073817
// Pulled By: wconstab
// fbshipit-source-id: e77d32e81d037d7370822b040456fd4c3bd68edb
#pragma once

#include <torch/csrc/Export.h>

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>
#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <memory>
#include <ostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.

//! Lowers a Fusion IR graph into its Kernel IR (kir::) equivalent.
//!
//! Constructing a GpuLower from a Fusion runs the full lowering pass
//! immediately (see the converting constructor below); the resulting
//! kir::Kernel is then available via kernel(). While lowering is in
//! progress, the active instance can be retrieved through current().
//! The various accessors expose per-lowering analysis results
//! (predicate maps, compute-at maps, halo info, etc.) that downstream
//! lowering passes consume.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API GpuLower {
  // Internal helper that performs Fusion IR -> Kernel IR node translation;
  // defined in the implementation file.
  class KernelIrMapper;

 public:
  GpuLower() = default;

  //! Eagerly lowers \p fusion; the result is available via kernel().
  //! Does not take ownership of \p fusion (raw, non-owning pointer).
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit GpuLower(Fusion* fusion) : fusion_(fusion) {
    lower();
  }

  //! Returns the lowered kernel (kernel_).
  kir::Kernel* kernel() const;

  //! Converts a Fusion IR value into the Kernel IR equivalent
  kir::Val* lowerValue(const Val* val);

  //! Converts a Fusion IR expression into the Kernel IR equivalent
  kir::Expr* lowerExpr(const Expr* expr);

  //! Returns the currently active lowering object
  //! (or nullptr if no lowering is in progress)
  static GpuLower* current();

  //! Map of thread/block predicates computed during lowering.
  const ThreadPredicateMap& threadPredMap() const {
    return thread_pred_map_;
  }

  //! Compute-at map used for loop-structure mapping
  //! (see compute_at_map.h for the map semantics).
  const ComputeAtMap& caLoopMap() const {
    return ca_loop_map_;
  }

  //! Compute-at map used for indexing.
  const ComputeAtMap& caIndexMap() const {
    return ca_index_map_;
  }

  //! Compute-at map used for parallelization mapping.
  const ComputeAtMap& caParallelMap() const {
    return ca_parallel_map_;
  }

  //! Analysis result identifying trivial reductions
  //! (see lower_trivial_reductions.h).
  const auto& trivialReductionInfo() const {
    return trivial_reduction_info_;
  }

  //! Halo (shift/padding) information; const and mutable overloads
  //! (see lower_shift.h).
  const HaloInfo& haloInfo() const {
    return halo_info_;
  }

  HaloInfo& haloInfo() {
    return halo_info_;
  }

  //! Mapping of parallel dimensions; const and mutable overloads
  //! (see parallel_dimension_map.h).
  const ParallelDimensionMap& parallelDimensionMap() const {
    return parallel_dimension_map_;
  }

  ParallelDimensionMap& parallelDimensionMap() {
    return parallel_dimension_map_;
  }

  //! Predicate-elimination analysis; mutable and const overloads
  //! (see lower_predicate.h).
  PredicateElimination& predicateElimination() {
    return pred_elimination_;
  }

  const PredicateElimination& predicateElimination() const {
    return pred_elimination_;
  }

  //! Allocation info for local (per-thread) buffers
  //! (see lower_allocation.h).
  LocalAllocationInfoMap& localAllocationInfoMap() {
    return local_allocation_info_map_;
  }

  //! Info on parallel dimensions padded to warp-size multiples;
  //! populated by collectPaddedParallelDims() (see lower_warp_reduce.h).
  const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
    return warp_pad_info_;
  }

  //! Partial-split analysis; mutable and const overloads
  //! (see partial_split_map.h).
  PartialSplitMap& partialSplitMap() {
    return partial_split_map_;
  }

  const PartialSplitMap& partialSplitMap() const {
    return partial_split_map_;
  }

  //! Non-divisible split analysis; mutable and const overloads
  //! (see non_divisible_split.h).
  auto& nonDivisibleSplitInfo() {
    return non_divisible_split_info_;
  }

  const auto& nonDivisibleSplitInfo() const {
    return non_divisible_split_info_;
  }

 private:
  //! Runs the full lowering pipeline on fusion_, producing kernel_ and
  //! populating the analysis members below. Invoked from the converting
  //! constructor.
  void lower();

  // TensorViews are all based on symbolic sizes. When we first initialize them
  // we don't know if they're inputs or outputs which would mean that they have
  // runtime shapes. Intermediate tensors (those not going to global memory) do
  // not have this information. Since we need to have the correct information in
  // the kernel being fetched for shapes, we want to replace input and output
  // tensors to reference the runtime structure containing sizes.
  void replaceSymbolicSizes();

  // Goes through the parallelized iterdomains of the used TVs and find
  // the parallel dimensions that need to be padded to a multiples of
  // warp size.
  void collectPaddedParallelDims();

 private:
  // Lowered Kernel IR
  std::unique_ptr<kir::Kernel> kernel_;

  // Fusion IR node to Kernel IR node mapping
  std::unordered_map<const Val*, kir::Val*> kir_val_map_;
  std::unordered_map<const Expr*, kir::Expr*> kir_expr_map_;

  // Some stateful information during lowering
  ThreadPredicateMap thread_pred_map_;
  PredicateElimination pred_elimination_;
  ComputeAtMap ca_loop_map_;
  ComputeAtMap ca_index_map_;
  ComputeAtMap ca_parallel_map_;
  TrivialReductionInfo trivial_reduction_info_;
  HaloInfo halo_info_;
  LocalAllocationInfoMap local_allocation_info_map_;
  WarpPaddedParallelInfo warp_pad_info_;
  ParallelDimensionMap parallel_dimension_map_;
  PartialSplitMap partial_split_map_;
  NonDivisibleSplitInfo non_divisible_split_info_;

  // Input fusion being lowered; non-owning (set by the converting ctor).
  Fusion* fusion_ = nullptr;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|