pytorch/torch/csrc/jit/codegen/cuda/utils.h
jjsjann123 df741c589f [NVFuser] Upstream push 0809 (#83067)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes un-necessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch : fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu
- scheduler
  1. normalization scheduler clean up.
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
2022-08-10 21:02:56 +00:00

169 lines
5.4 KiB
C++

#pragma once
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Debug helper taking a tensor type. Judging by the name it prints the type
//! for inspection; the definition is not in this header, so confirm the exact
//! output there.
void debugPrint(const c10::TensorTypePtr& type);
//! Predicates on a tensor type.
//! NOTE(review): presumably "zero dim" means a rank-0 tensor and "zero sized"
//! means a tensor holding no elements -- the definitions are not visible here,
//! confirm before relying on either reading.
bool is_zero_dim_tensor(const std::shared_ptr<c10::TensorType>& tensor_type);
bool is_zero_sized_tensor(const std::shared_ptr<c10::TensorType>& tensor_type);
//! Two overloads of the same query: one for a concrete tensor, one for a
//! tensor type.
//! NOTE(review): presumably true for a scalar tensor that lives on the CPU --
//! verify against the implementation.
bool is_cpu_scalar(const at::Tensor& tensor);
bool is_cpu_scalar(const c10::TensorType& tensor_type);
//! Types of debug print-outs
//!
//! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable
//!
//! Every enumerator is described by a trailing `//!<` comment so that the
//! description attaches to the enumerator it follows.
enum class DebugDumpOption {
FusionIr, //!< Dump the Fusion IR before lowering
FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR
KernelIr, //!< Dump the compiler Kernel IR
ComputeAtMap, //!< Dump the computeAt map
CudaKernel, //!< Dump the generated CUDA C++ kernel code
CudaFull, //!< Dump the complete CUDA C++ code
CudaToFile, //!< Dump CUDA Strings to File
DebugInfo, //!< Embed line info and debug info to compiled kernel, and dump
//!< the full CUDA C++ code
LaunchParam, //!< Dump the Launch parameters of kernel
FusionSegments, //!< Dump Segmented Fusion Graph
FusionSegmenterLog, //!< Dump Detailed Segmenter Logging
FusionArgs, //!< Print the runtime fusion arguments
KernelArgs, //!< Print the runtime kernel arguments when launching kernels
EffectiveBandwidth, //!< Measure kernel performance and print effective
//!< bandwidth
FusionSegmentsDrawing, //!< Dump Segmented Fusion Graph
//!< NOTE(review): same text as FusionSegments;
//!< presumably this variant emits a drawing of the
//!< graph -- confirm
PrintPtxasLog, //!< Print the ptxas verbose log including register usage
BufferReuseInfo, //!< Dump the analysis details of local/shared buffer re-use
SchedulerDebug, //!< Dump scheduler heuristic parameters
ParallelDimensions, //!< Dump known parallel dimensions
Halo, //!< Halo information of tensors
PerfDebugVerbose, //!< When running kernels, print verbose information
//!< associated with what's running
TransformPropagator, //!< When running TransformPropagator, print propagation
//!< path and replay result
InlinePropagator, //!< When running InlinePropagator, print propagation
//!< path and inlining result
};
//! Reports whether the given debug dump `option` is enabled.
//! NOTE(review): presumably driven by the `PYTORCH_NVFUSER_DUMP` variable
//! mentioned above; the definition is not in this header.
TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
//! Types of features to disable
//!
//! These can be set through the `PYTORCH_NVFUSER_DISABLE` environment variable
//!
//! Every enumerator is described by a trailing `//!<` comment so that the
//! description attaches to the enumerator it follows.
enum class DisableOption {
ArchCheck, //!< Disable hardware-specific checks to enable cross arch debug
Fallback, //!< Disable fallback
Fma, //!< Disable FMA instructions
IndexHoist, //!< Disable index hoisting
Nvtx, //!< Disable NVTX instrumentation
PredicateElimination, //!< Disable predicate elimination
UnrollWithRng //!< Disable unrolling for kernels with RNG in them
};
//! Reports whether the feature named by `option` has been disabled.
//! NOTE(review): presumably driven by the `PYTORCH_NVFUSER_DISABLE` variable
//! mentioned above; the definition is not in this header.
TORCH_CUDA_CU_API bool isDisabled(DisableOption option);
//! Types of features to enable
//!
//! These can be set through the `PYTORCH_NVFUSER_ENABLE` environment variable
//!
//! Every enumerator is described by a trailing `//!<` comment so that the
//! description attaches to the enumerator it follows.
enum class EnableOption {
Complex, //!< Enable complex support on python
KernelProfile, //!< Enable intra-kernel performance profiling
LinearDecomposition, //!< Enable linear-bias decomposition
ConvDecomposition //!< Enable conv-bias decomposition
};
//! Reports whether the feature named by `option` has been enabled.
//! NOTE(review): presumably driven by the `PYTORCH_NVFUSER_ENABLE` variable
//! mentioned above; the definition is not in this header.
TORCH_CUDA_CU_API bool isEnabled(EnableOption option);
// Check whether the fallback path should be used. The fallback path dispatches
// to eager mode if any errors are encountered. Helpful for debugging.
// NOTE(review): how this relates to DisableOption::Fallback is not visible in
// this header -- confirm in the definition.
bool useFallback();
//! Ceil integer division: the smallest integer that is not less than a / b.
//!
//! The result is built from the truncated quotient and the remainder instead
//! of the textbook (a + b - 1) / b, so no intermediate sum can overflow and
//! the value is a true ceiling even when an operand is negative. For positive
//! operands whose sum does not overflow, both forms agree.
//!
//! \param a dividend
//! \param b divisor; must be non-zero
//! \return ceil(a / b)
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  // Integer division truncates toward zero, which already is the ceiling when
  // the exact quotient is negative. Round up by one only when a remainder
  // exists and both operands share a sign (i.e. the exact quotient is
  // positive).
  return a / b + ((a % b != 0 && (a < 0) == (b < 0)) ? 1 : 0);
}
//! Simple mixin for suppressing copy & move operations, ex:
//!
//!  class Foo : public NonCopyable {
//!   ...
//!  };
//!
//! All four copy/move special members are deleted explicitly, so the contract
//! stated above is visible in the class itself rather than relying on the
//! rule that user-declared copy operations inhibit the implicit move ones.
class NonCopyable {
 public:
  NonCopyable() = default;

  // No copy semantics
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;

  // No move semantics
  NonCopyable(NonCopyable&&) = delete;
  NonCopyable& operator=(NonCopyable&&) = delete;
};
//! Common root for hierarchies of polymorphic classes:
//!  - gives every derived class a virtual destructor
//!  - supplies the base->as<Derived>() and node->isA<T>() notation
class PolymorphicBase {
 public:
  virtual ~PolymorphicBase() = default;

  //! Downcast helper: ptr->as<T>() stands in for static_cast<T*>(ptr).
  //!
  //! When NDEBUG is not defined the cast is performed with dynamic_cast and a
  //! failed cast trips TORCH_INTERNAL_ASSERT; with NDEBUG defined it is an
  //! unchecked static_cast.
  template <class T>
  T* as() {
#ifndef NDEBUG
    T* downcast_ptr = dynamic_cast<T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<T*>(this);
#endif
  }

  //! Const flavour of the downcast helper; follows the same checking rules.
  template <class T>
  const T* as() const {
#ifndef NDEBUG
    const T* downcast_ptr = dynamic_cast<const T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<const T*>(this);
#endif
  }

  //! Check if the runtime type is T (or derived from T)
  //!
  //! \note Don't use this for conditional casts. Prefer:
  //!
  //!  if (auto t = dynamic_cast<T*>(p)) { ... }
  //!
  //! over:
  //!
  //!  if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //!
  template <class T>
  bool isA() const {
    const T* const probe = dynamic_cast<const T*>(this);
    return probe != nullptr;
  }
};
//! Packs two values of the same enum type into one unsigned int key: t1 is
//! shifted left by 16 bits and t2 is added on top. Judging by the name, the
//! key is meant to let a pair of enums be matched by a single `switch`.
//!
//! Both values are converted to unsigned int first. The pairing is only
//! unambiguous while t2 fits in the low 16 bits.
//!
//! \param t1 enum value placed in the upper part of the key
//! \param t2 enum value placed in the lower part of the key
//! \return (t1 << 16) + t2, computed on unsigned int
template <class T, std::enable_if_t<std::is_enum<T>::value, bool> = true>
constexpr unsigned int switch_pair(T t1, T t2) {
  // Deliberately not spelled `_WORD_SHIFT`: identifiers that begin with an
  // underscore followed by an uppercase letter are reserved.
  constexpr unsigned int kWordShift = 16;
  return (static_cast<unsigned int>(t1) << kWordShift) +
      static_cast<unsigned int>(t2);
}
//! Returns sizes obtained from `tensor_type` as a vector of int64_t.
//! NOTE(review): the definition is not in this header -- presumably one entry
//! per dimension; confirm how sizes that are not concrete are handled.
std::vector<int64_t> getTensorSizes(TensorTypePtr const& tensor_type);
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch