pytorch/torch/csrc/jit/codegen/cuda/lower2device.h
jiej 76d282d447 Nvfuser code bump 12 5 (#69964)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69964

Things added in this PR that require review:
1. cuLaunchCooperativeKernel driver API added (a usage sketch follows the file list):
aten/src/ATen/cuda/detail/LazyNVRTC.cpp
aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
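
For context, a minimal sketch of what a cooperative launch looks like through
the driver API (the helper name and variables below are illustrative, not the
code added by this PR):

    #include <cuda.h>

    // Cooperative kernels may synchronize across the whole grid; the device
    // must report CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH support.
    CUresult launchCooperative(
        CUfunction func,
        unsigned int gdimx, unsigned int gdimy, unsigned int gdimz,
        unsigned int bdimx, unsigned int bdimy, unsigned int bdimz,
        unsigned int smem_bytes,
        CUstream stream,
        void** args) {
      return cuLaunchCooperativeKernel(
          func,
          gdimx, gdimy, gdimz,
          bdimx, bdimy, bdimz,
          smem_bytes,
          stream,
          args);
    }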

nvfuser code update:
1. perf tuning of the codegen scheduler that improves performance.
2. permutation support has been extended beyond contiguous/channels-last layouts (the improvements can be observed on the PW benchmark).

Things reverted from local changes:
1. aten::gelu with approximation
2. local changes that were already upstreamed in PR https://github.com/pytorch/pytorch/issues/68804

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69428

Reviewed By: ngimel

Differential Revision: D33073817

Pulled By: wconstab

fbshipit-source-id: e77d32e81d037d7370822b040456fd4c3bd68edb
2021-12-16 08:28:54 -08:00

#pragma once

#include <torch/csrc/Export.h>

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_warp_reduce.h>
#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>
#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

#include <memory>
#include <ostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// TODO: we frequently use pairwise root mapping from consumers to producers.
// This information is implicitly in the computeAtMaps, but there's no isolated
// container for this information that we can reuse. Would be nice to generate
// such a structure and propagate it through lowering.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API GpuLower {
  class KernelIrMapper;

 public:
  GpuLower() = default;

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit GpuLower(Fusion* fusion) : fusion_(fusion) {
    lower();
  }

  kir::Kernel* kernel() const;

  //! Converts a Fusion IR value into the Kernel IR equivalent
  kir::Val* lowerValue(const Val* val);

  //! Converts a Fusion IR expression into the Kernel IR equivalent
  kir::Expr* lowerExpr(const Expr* expr);

  //! Returns the currently active lowering object
  //! (or nullptr if no lowering is in progress)
  static GpuLower* current();
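
  // A minimal usage sketch (illustrative only; `fusion` is assumed to be a
  // fully built Fusion and `val` one of its values):
  //
  //   GpuLower gpulw(&fusion);
  //   kir::Val* kir_val = gpulw.lowerValue(val);
  //   kir::Kernel* kernel = gpulw.kernel();
  //
  // Inside lowering passes, the active instance is reachable through
  // GpuLower::current().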
  const ThreadPredicateMap& threadPredMap() const {
    return thread_pred_map_;
  }

  const ComputeAtMap& caLoopMap() const {
    return ca_loop_map_;
  }

  const ComputeAtMap& caIndexMap() const {
    return ca_index_map_;
  }

  const ComputeAtMap& caParallelMap() const {
    return ca_parallel_map_;
  }

  const auto& trivialReductionInfo() const {
    return trivial_reduction_info_;
  }

  const HaloInfo& haloInfo() const {
    return halo_info_;
  }

  HaloInfo& haloInfo() {
    return halo_info_;
  }

  const ParallelDimensionMap& parallelDimensionMap() const {
    return parallel_dimension_map_;
  }

  ParallelDimensionMap& parallelDimensionMap() {
    return parallel_dimension_map_;
  }

  PredicateElimination& predicateElimination() {
    return pred_elimination_;
  }

  const PredicateElimination& predicateElimination() const {
    return pred_elimination_;
  }

  LocalAllocationInfoMap& localAllocationInfoMap() {
    return local_allocation_info_map_;
  }

  const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
    return warp_pad_info_;
  }

  PartialSplitMap& partialSplitMap() {
    return partial_split_map_;
  }

  const PartialSplitMap& partialSplitMap() const {
    return partial_split_map_;
  }

  auto& nonDivisibleSplitInfo() {
    return non_divisible_split_info_;
  }

  const auto& nonDivisibleSplitInfo() const {
    return non_divisible_split_info_;
  }

 private:
  void lower();

  // TensorViews are all based on symbolic sizes. When we first initialize them
  // we don't know whether they are inputs or outputs, which would mean they
  // have runtime shapes. Intermediate tensors (those not going to global
  // memory) do not have this information. Since the kernel needs the correct
  // shape information at runtime, we replace input and output tensors so they
  // reference the runtime structure containing their sizes.
  void replaceSymbolicSizes();
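
  // For example (an illustrative sketch of the outcome, not verbatim codegen
  // output): a symbolic root extent such as i0 on a global-memory input T0
  // would be rebound to the runtime size T0.size[0] carried in the kernel's
  // tensor arguments.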

  // Goes through the parallelized iterdomains of the used TVs and finds
  // the parallel dimensions that need to be padded to a multiple of
  // warp size.
  void collectPaddedParallelDims();
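
  // E.g. (an illustrative assumption, not taken from this header): an
  // iterdomain bound to TIDx with an extent of 100 would be recorded for
  // padding up to 128, the next multiple of the warp size (32), so that
  // warp-level primitives operate on full warps.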

 private:
  // Lowered Kernel IR
  std::unique_ptr<kir::Kernel> kernel_;

  // Fusion IR node to Kernel IR node mapping
  std::unordered_map<const Val*, kir::Val*> kir_val_map_;
  std::unordered_map<const Expr*, kir::Expr*> kir_expr_map_;

  // Some stateful information during lowering
  ThreadPredicateMap thread_pred_map_;
  PredicateElimination pred_elimination_;
  ComputeAtMap ca_loop_map_;
  ComputeAtMap ca_index_map_;
  ComputeAtMap ca_parallel_map_;
  TrivialReductionInfo trivial_reduction_info_;
  HaloInfo halo_info_;
  LocalAllocationInfoMap local_allocation_info_map_;
  WarpPaddedParallelInfo warp_pad_info_;
  ParallelDimensionMap parallel_dimension_map_;
  PartialSplitMap partial_split_map_;
  NonDivisibleSplitInfo non_divisible_split_info_;

  Fusion* fusion_ = nullptr;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch