pytorch/benchmarks/cpp/nvfuser/utils.h
jjsjann123 df741c589f [NVFuser] Upstream push 0809 (#83067)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch : fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update
  1. added leaky_relu
- scheduler
  1. normalization scheduler clean up.
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
2022-08-10 21:02:56 +00:00

203 lines
7.2 KiB
C++

#pragma once
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <benchmark/benchmark.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <cuda_runtime.h>
using namespace torch::jit::fuser::cuda;
// Tensor factory helpers. Only the declarations are visible in this header;
// the definitions live elsewhere. In every helper `dtype` defaults to
// DataType::Float.
// NOTE(review): ownership of the returned TensorView* is not visible from
// these declarations - presumably it belongs to the currently active Fusion;
// confirm against the definitions.

// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float);
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes. Taken from test_gpu.cpp
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float);
// Make a non-contiguous tensor of compile-time known sizes.
// `shape` is taken by value and lists the extent of each dimension.
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float);
// Make a contiguous tensor of compile-time known sizes.
// `shape` is taken by value and lists the extent of each dimension.
TensorView* makeContigConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float);
// String conversion overloads for scheduler / launch parameter objects.
// Only the declarations are visible here; the exact output format is decided
// by the definitions elsewhere.
// The ReductionParams, PointwiseParams and LaunchParams overloads take their
// argument by value; the HeuristicParams overload takes a shared_ptr by const
// reference.
// NOTE(review): whether the shared_ptr overload tolerates a null pointer is
// not visible from this header - confirm against the definition.
std::string toString(ReductionParams rparams);
std::string toString(PointwiseParams params);
std::string toString(const std::shared_ptr<HeuristicParams>& params);
std::string toString(LaunchParams lparams);
// Run benchmark iterations with provided inputs. If not segmented, report
// kernel time from the runtime, as well as heuristic parameters. If segmented
// use timers. Make sure to clear L2 between iterations.
//
// Parameters:
//   benchmark_state       - Google Benchmark state for the running benchmark.
//   fusion_executor_cache - non-owning pointer to the executor cache to run.
//   aten_inputs           - inputs for the fusion, passed by non-const
//                           reference.
// NOTE(review): whether aten_inputs is modified, and whether a null
// fusion_executor_cache is tolerated, is not visible from this declaration.
void runBenchmarkIterations(
benchmark::State& benchmark_state,
FusionExecutorCache* fusion_executor_cache,
std::vector<c10::IValue>& aten_inputs);
// Declared here, defined elsewhere. Presumably flushes the GPU L2 cache so
// that successive timed iterations do not hit warm data - confirm against the
// definition.
void clearL2Cache();
class CudaKernelTimer {
public:
CudaKernelTimer() {
// Setup
cudaEventCreate(&start_event);
cudaEventCreate(&finish_event);
cudaEventRecord(start_event);
}
~CudaKernelTimer() {
cudaEventDestroy(start_event);
cudaEventDestroy(finish_event);
}
void restart() {
cudaEventRecord(start_event);
}
float elapsed() {
// Record
cudaEventRecord(finish_event);
cudaEventSynchronize(start_event);
cudaEventSynchronize(finish_event);
cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event);
return kernel_time_ms_;
}
private:
// Create
float kernel_time_ms_ = 0;
cudaEvent_t start_event = {};
cudaEvent_t finish_event = {};
};
// Types and accessor for the registry of FusionExecutorCache instances shared
// by the benchmark fixtures in this header.
namespace executorCache {
// Owning pointer to a single FusionExecutorCache.
using ExecutorPtr = std::unique_ptr<FusionExecutorCache>;
// Registry keyed by a string; BenchmarkGraph uses its graphName() as the key.
using ExecutorMap = std::unordered_map<std::string, ExecutorPtr>;
// Returns a mutable reference to the registry. Only the declaration is
// visible here; storage and lifetime are determined by the definition.
ExecutorMap& getGlobalMap();
} // namespace executorCache
//! Utility to manage FusionExecutorCache instances for
//! all defined benchmarks.
//!
//! Each concrete fixture supplies a graph name and a function that builds the
//! fusion. SetUp() builds the fusion and stores a FusionExecutorCache for it
//! in the global executor map under graphName(), unless an entry for that
//! name already exists; getExecutorCache() returns the stored instance.
class BenchmarkGraph : public benchmark::Fixture {
 public:
  using SetupFusionFunction = std::function<void(Fusion*)>;
  using SetupFusionMap = std::unordered_map<std::string, SetupFusionFunction>;

  //! Key under which this fixture's executor cache is stored.
  virtual std::string graphName() = 0;

  //! Returns the callable that populates a Fusion with this benchmark's graph.
  virtual SetupFusionFunction setupFusion() = 0;

  //! Returns the executor cache registered under graphName().
  //! Asserts (TORCH_INTERNAL_ASSERT) if no cache has been created yet, i.e.
  //! if SetUp() has not run for this graph.
  FusionExecutorCache* getExecutorCache() {
    auto& executor_ = getExecutorCacheMap()[graphName()];
    TORCH_INTERNAL_ASSERT(executor_);
    return executor_.get();
  }

  //! Builds the fusion and its executor cache the first time a graph with
  //! this name is set up; later calls find the existing entry and do nothing.
  void SetUp(const ::benchmark::State& state) {
    auto& executor_ = getExecutorCacheMap()[graphName()];
    // Makes sure same graph hasn't been compiled before
    if (!executor_) {
      auto fusion_ptr = std::make_unique<Fusion>();
      {
        // The guard must be a named object: an unnamed temporary would be
        // destroyed at the end of its own statement, leaving no guard active
        // while the fusion is being built.
        FusionGuard fg(fusion_ptr.get());
        setupFusion()(fusion_ptr.get());
      }
      // executor_ refers to the map slot for graphName(), so assigning
      // through it stores the new cache without another lookup.
      executor_ = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
    }
  }

  //! Nothing to release per run; executor caches stay in the global map.
  void TearDown(const ::benchmark::State& state) {}

 protected:
  //! Accessor for the registry shared by all BenchmarkGraph fixtures.
  static executorCache::ExecutorMap& getExecutorCacheMap() {
    return executorCache::getGlobalMap();
  }
};
// Stringification helpers producing a std::string.
// NVFUSER_TO_STRING_HELPER stringifies its argument exactly as written.
// NVFUSER_TO_STRING forwards through the helper so that an argument which is
// itself a macro is expanded before being stringified.
#define NVFUSER_TO_STRING_HELPER(token) std::string(#token)
#define NVFUSER_TO_STRING(token) NVFUSER_TO_STRING_HELPER(token)
//! NVFUSER_BENCHMARK_RUN utility usage:
//! This utility helps create and manage FusionExecutorCaches and tries to use
//! the caching
//! mechanism in NVFuser to avoid re-compilation.
//!
//! There are two macros in this utility: NVFUSER_BENCHMARK_DEFINE, and
//! NVFUSER_BENCHMARK_RUN,
//! and user needs to supply two functions SETUP_FUSION and RUN_FUSION, with
//! following signatures:
//!
//! SETUP_FUSION(Fusion* , args...);
//! RUN_FUSION(benchmark::State&, FusionExecutorCache* , args...);
//!
//! where args... are additional arguments, and they need to be the same for
//! SETUP_FUSION and RUN_FUSION.
//!
//! SETUP_FUSION is called once in each definition of benchmark to build the
//! fusionIR graph
//!
//! RUN_FUSION is just like the normal benchmark instance, except that a
//! FusionExecutorCache
//! will be provided for scheduling, running and timing the fusion runs. It is
//! called once in each benchmark instance. For example:
//! NVFUSER_BENCHMARK_RUN(my_benchmark)
//! ->RangeMultiplier(2)
//! ->Ranges({{1, 4}})
//! Calls RUN_FUSION 3 times.
//!
//! To register a benchmark, the API is:
//!
//! NVFUSER_BENCHMARK_DEFINE(my_benchmark,SETUP_FUSION,RUN_FUSION,args...);
//!
//! where my_benchmark is any unique name given for this benchmark,
//! SETUP_FUSION, RUN_FUSION as described above,
//! args... is the arg list supplied to both setup_fusion and run_fusion
//!
//! each NVFUSER_BENCHMARK_DEFINE registers a benchmark with a single
//! FusionExecutorCache, i.e. a single fusion graph, and multiple benchmark
//! data points can be registered like:
//!
//! NVFUSER_BENCHMARK_RUN(my_benchmark)
//! ->Ranges({{1,2}});
//!
//! NVFUSER_BENCHMARK_RUN(my_benchmark)
//! ->Ranges({{3,4}});
//!
//! All datapoints will use the same FusionExecutorCache so recompilation is
//! avoided as much as possible.
//! Expands to:
//!  - a fixture class named <BENCHMARK_NAME>___GRAPH deriving from
//!    BenchmarkGraph, whose graphName() returns that class name as a string
//!    and whose setupFusion() returns a lambda calling
//!    SETUP_FUSION(fusion, args...);
//!  - a BENCHMARK_DEFINE_F body that calls
//!    RUN_FUSION(benchmark_state, getExecutorCache(), args...).
//! At least one extra argument must be supplied, because __VA_ARGS__ is
//! pasted after a comma in both calls.
//! The generated member functions are marked `override` so that any drift
//! from the BenchmarkGraph virtual interface is a compile error.
#define NVFUSER_BENCHMARK_DEFINE(                                       \
    BENCHMARK_NAME, SETUP_FUSION, RUN_FUSION, ...)                      \
  class BENCHMARK_NAME##___GRAPH : public BenchmarkGraph {              \
   public:                                                              \
    std::string graphName() override {                                  \
      return NVFUSER_TO_STRING(BENCHMARK_NAME##___GRAPH);               \
    }                                                                   \
    SetupFusionFunction setupFusion() override {                        \
      return [](Fusion* fusion) { SETUP_FUSION(fusion, __VA_ARGS__); }; \
    }                                                                   \
  };                                                                    \
  BENCHMARK_DEFINE_F(BENCHMARK_NAME##___GRAPH, BENCHMARK_NAME)          \
  (benchmark::State & benchmark_state) {                                \
    RUN_FUSION(                                                         \
        benchmark_state,                                                \
        BENCHMARK_NAME##___GRAPH::getExecutorCache(),                   \
        __VA_ARGS__);                                                   \
  }
//! Registers the benchmark NAME with its <NAME>___GRAPH fixture through
//! BENCHMARK_REGISTER_F. The expansion is an expression, so configuration
//! calls such as ->Ranges(...) can be chained after the macro invocation.
#define NVFUSER_BENCHMARK_RUN(NAME) \
  BENCHMARK_REGISTER_F(NAME##___GRAPH, NAME)