mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Code changes includes: - codegen improvements: 1. removes un-necessary sync from redundant thread compute analysis 2. symmetric API for BestEffortReplay 3. support merge on trivial reductions 4. Ampere async copy improvements - bug fixes: 1. vectorization bug fixes 2. type inference patch : fixes upstream #81725 3. segmenter bug fix with deterministic iteration ordering - parser update 1. added leaky_relu - scheduler 1. normalization scheduler clean up. 2. simplifies matmul scheduling with new transform propagator 3. merge all dimensions in PW scheduler 4. various gemm related improvements - debuggability 1. nsight compute support 2. debug dump for InlinePropagator 3. Add `UnaryOpType::Print` Squashed commits to WAR github API Commits that's actually in this PR from the devel branch: ``` dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD 16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884) 7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803 3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD 01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878) 0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881) 7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875) 501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. 
(#1826) d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827) e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824) 9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823) a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872) 172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868) a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871) 440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870) d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867) 1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864) 51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866) a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering. (#1865) 1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861) 1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator. 
(#1817) bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852) e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858) 9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857) 20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855) 405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822) 01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853) 5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846) 92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845) db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848) 102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847) b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687) 942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842) 0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840) 63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825) 9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839) 2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067 Approved by: https://github.com/davidberard98
216 lines
7.1 KiB
C++
216 lines
7.1 KiB
C++
#include <benchmarks/cpp/nvfuser/utils.h>
|
|
|
|
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
|
|
|
|
#include <sstream>
|
|
|
|
using namespace torch::jit::fuser::cuda;
|
|
|
|
// Builds a one-line, human-readable summary of a ReductionParams instance.
//
// The summary is emitted in this order:
//   1. leading flags (fastest/slow dim, persistent kernel, buffer projection),
//   2. an "Outer Reduction" section, only when `schedule_3D` is set,
//   3. an " // Iteration Domain: " section,
//   4. an " // Inner Reduction Domain: " section.
// Flags that are not set contribute no text at all.
std::string toString(ReductionParams rparams) {
  std::stringstream out;

  // Leading flags.
  if (rparams.fastest_dim) {
    out << "Red On Fastest Dim // ";
  } else {
    out << "Red On Slow Dim // ";
  }
  if (rparams.persistent_kernel) {
    out << "Persistent Kernel // ";
  }
  if (rparams.project_persistent_buffers) {
    out << "Project Persistent Buffers // ";
  }

  // Outer reduction details are only reported for 3D schedules.
  if (rparams.schedule_3D) {
    out << "3D Schedule // "
        << "Outer Reduction: ";
    if (rparams.cross_block_outer_reduction) {
      out << "cross block / ";
    }
    if (rparams.cross_grid_outer_reduction) {
      out << "cross grid / ";
    }
    if (rparams.split_grid_dim_outer_reduction) {
      out << "split grid dim / ";
    }
    // The batch count is printed whenever the kernel is persistent, even if
    // the count itself is not greater than one.
    const bool show_outer_batches =
        rparams.batches_per_block_outer_reduction > 1 ||
        rparams.persistent_kernel;
    if (show_outer_batches) {
      out << "persistent batch - " << rparams.batches_per_block_outer_reduction
          << " / ";
    }
  }

  // Iteration domain.
  out << " // Iteration Domain: ";
  if (rparams.multiple_reds_per_blk) {
    out << "multiple reductions per block / ";
  }
  if (rparams.split_grid_dim_iter_dom) {
    out << "split grid dimension / ";
  }
  if (rparams.vectorize_iter_dom) {
    out << "vectorize / ";
  }
  // "unroll" is only reported when the factor is not already described as a
  // vectorization.
  if (rparams.unroll_factor_iter_dom > 1 && !rparams.vectorize_iter_dom) {
    out << "unroll / ";
  }
  if (rparams.unroll_factor_iter_dom > 1 || rparams.vectorize_iter_dom) {
    out << "factor " << rparams.unroll_factor_iter_dom;
  }

  // Inner reduction domain.
  out << " // Inner Reduction Domain: ";
  if (rparams.cross_block_inner_reduction) {
    out << "cross block reduction / ";
  }
  if (rparams.pad_inner_reduction_to_warp) {
    out << "pad to warp / ";
  }
  if (rparams.cross_grid_inner_reduction) {
    out << "cross grid reduction / ";
  }

  const bool show_inner_batches =
      rparams.batches_per_block_inner_reduction > 1 ||
      rparams.persistent_kernel;
  if (show_inner_batches) {
    out << "persistent batch - " << rparams.batches_per_block_inner_reduction
        << " / ";
  }

  // The grid split is only mentioned together with a cross-grid reduction.
  if (rparams.cross_grid_inner_reduction &&
      rparams.split_grid_dim_inner_reduction) {
    out << "split grid dimension / ";
  }
  if (rparams.vectorize_inner_reduction) {
    out << "vectorize / ";
  }
  if (rparams.unroll_factor_inner_reduction > 1 &&
      !rparams.vectorize_inner_reduction) {
    out << "unroll / ";
  }
  if (rparams.unroll_factor_inner_reduction > 1 ||
      rparams.vectorize_inner_reduction) {
    out << "factor " << rparams.unroll_factor_inner_reduction;
  }

  return out.str();
}
|
|
|
|
// Builds a one-line, human-readable summary of a PointwiseParams instance.
//
// The text starts with either "1D/" or "2D Schedule at <break_point>/"
// (followed by any block / grid split notes), and ends with the
// vectorize-or-unroll factor when `unroll_factor` is greater than one.
std::string toString(PointwiseParams params) {
  std::stringstream out;

  if (!params.break_point) {
    // No break point: reported as a 1D schedule.
    out << "1D"
        << "/";
  } else {
    out << "2D Schedule at " << params.break_point << "/";
    if (params.split_block) {
      out << " Split block into y-dim/";
    }
    if (params.split_grid_y_dim) {
      out << " Split y grid dim/";
    }
  }

  // The same factor field is labelled differently depending on whether the
  // schedule vectorizes or merely unrolls.
  if (params.unroll_factor > 1) {
    out << (params.vectorize ? "Vectorize, Factor: " : "Unroll, Factor: ")
        << params.unroll_factor;
  }

  return out.str();
}
|
|
|
|
std::string toString(const std::shared_ptr<HeuristicParams>& params) {
|
|
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
|
|
if (rparams) {
|
|
return toString(*rparams);
|
|
}
|
|
auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params);
|
|
if (pparams) {
|
|
return toString(*pparams);
|
|
}
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"Unknown heuristic parameter type. Did you just added a new heuristic parameter type but forget to update here?");
|
|
}
|
|
|
|
// Builds a label fragment describing the launch configuration, in the form
//   /Launch_Parameters[block(z/y/x)/grid(z/y/x)/smem]
// where each value is read from the corresponding accessor on `lparams`
// (bdimz/bdimy/bdimx, gdimz/gdimy/gdimx, smem).
//
// The previous version also invoked `lparams.toString()` and threw the result
// away; that dead statement has been removed since nothing consumed it.
std::string toString(LaunchParams lparams) {
  std::stringstream ss;
  ss << "/Launch_Parameters["
     << "block(" << lparams.bdimz() << "/" << lparams.bdimy() << "/"
     << lparams.bdimx() << ")/grid(" << lparams.gdimz() << "/"
     << lparams.gdimy() << "/" << lparams.gdimx() << ")/" << lparams.smem()
     << "]";
  return ss.str();
}
|
|
|
|
void clearL2Cache() {
|
|
torch::NoGradGuard no_grad;
|
|
auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
|
|
auto options =
|
|
torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);
|
|
|
|
auto l2_elems = l2_cache_size / 4;
|
|
torch::Tensor t0 = torch::empty(l2_elems, options);
|
|
torch::Tensor t1 = torch::clone(t0);
|
|
};
|
|
|
|
// Creates a TensorView through TensorViewBuilder, specifying only the number
// of dimensions and the data type.
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype) {
  TensorView* const tv = TensorViewBuilder().ndims(ndims).dtype(dtype).build();
  return tv;
}
|
|
|
|
// Creates a TensorView through TensorViewBuilder with `ndims` dimensions of
// the given data type, passing a contiguity flag of `true` for every
// dimension.
TensorView* makeContigTensor(size_t ndims, DataType dtype) {
  const std::vector<bool> all_contiguous(ndims, true);
  return TensorViewBuilder()
      .ndims(ndims)
      .dtype(dtype)
      .contiguity(all_contiguous)
      .build();
}
|
|
|
|
// Creates a TensorView through TensorViewBuilder from an explicit shape and a
// data type.
TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
  TensorView* const tv = TensorViewBuilder().shape(shape).dtype(dtype).build();
  return tv;
}
|
|
|
|
// Creates a TensorView through TensorViewBuilder from an explicit shape and a
// data type, passing a contiguity flag of `true` for every entry of `shape`.
TensorView* makeContigConcreteTensor(
    std::vector<int64_t> shape,
    DataType dtype) {
  const std::vector<bool> all_contiguous(shape.size(), true);
  return TensorViewBuilder()
      .shape(shape)
      .dtype(dtype)
      .contiguity(all_contiguous)
      .build();
}
|
|
|
|
// Drives the benchmark loop for a fusion held by `fusion_executor_cache`,
// reporting a manually measured time for every iteration via
// `benchmark_state.SetIterationTime`.
//
// The fusion is run once up front; what happens next depends on whether the
// most recent kernel runtime is segmented into more than one group:
//   * more than one segment: each iteration is timed with a CudaKernelTimer
//     wrapped around the whole `runFusionWithInputs` call;
//   * otherwise: profiling is enabled, the benchmark label is set from the
//     heuristic and launch parameters, and each iteration reports the
//     executor's own `kernelTimeMs()`.
// In both paths clearL2Cache() runs at the start of every iteration and the
// measured value is divided by 1000.0 (presumably milliseconds to seconds)
// before being reported.
void runBenchmarkIterations(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    std::vector<c10::IValue>& aten_inputs) {
  // Initial run so that the most recent kernel runtime can be inspected.
  fusion_executor_cache->runFusionWithInputs(aten_inputs);

  const bool multi_segment =
      fusion_executor_cache->getMostRecentKernelRuntime()->isSegmented() &&
      fusion_executor_cache->getMostRecentKernelRuntime()
              ->fusionSegments()
              ->groups()
              .size() > 1;

  if (multi_segment) {
    // Segmented
    {
      // Compile/warmup; outputs are dropped before synchronizing.
      auto warmup_outputs =
          fusion_executor_cache->runFusionWithInputs(aten_inputs);
    }
    // Sync everything up before we start
    cudaDeviceSynchronize();

    CudaKernelTimer timer;
    for (auto _ : benchmark_state) {
      clearL2Cache();
      // Restart after the cache flush so the flush is not part of the
      // measured interval.
      timer.restart();
      auto outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
      benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    }
    // Sync everything up before we're finished, don't want to run ahead on
    // the cpu while benchmarking.
    cudaDeviceSynchronize();
    return;
  }

  // Single kernel: enable profiling and run again so that the most recent
  // executor info is available.
  fusion_executor_cache->profile(true);
  fusion_executor_cache->runFusionWithInputs(aten_inputs);

  auto executor_info = fusion_executor_cache->getMostRecentExecutorInfo();
  auto executor = executor_info.fusion_executor;

  // Label the benchmark with the heuristic parameters followed by the launch
  // parameters of the last launch.
  auto heuristic_label = toString(executor_info.params);
  auto launch_label = toString(executor->lastLaunchParams());
  benchmark_state.SetLabel(heuristic_label + launch_label);

  executor->setMeasureKernelTimeFlag(true);

  // Sync everything up before we start
  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    clearL2Cache();
    auto outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
    benchmark_state.SetIterationTime(executor->kernelTimeMs() / 1000.0);
  }
  // Sync everything up before we're finished, don't want to run ahead on the
  // cpu while benchmarking.
  cudaDeviceSynchronize();
}
|
|
|
|
namespace executorCache {
|
|
thread_local ExecutorMap executor_map_;
|
|
ExecutorMap& getGlobalMap() {
|
|
return executor_map_;
|
|
}
|
|
} // namespace executorCache
|