Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update:
  1. added leaky_relu
- scheduler:
  1. normalization scheduler clean up
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability:
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to WAR github API. Commits that are actually in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering. (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator. (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>

#include <benchmark/benchmark.h>

#include <cuda_runtime.h>

#include <benchmarks/cpp/nvfuser/utils.h>

using namespace torch::jit::fuser::cuda;
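
// Builds the scale-bias-relu (SBR) fusion: relu(x * scale + bias), where x is
// a contiguous 4-D (N, H, W, C) tensor and scale/bias are broadcast along all
// but the channel (innermost) dimension. Half-precision inputs are upcast to
// float for the arithmetic and the result is cast back to half.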
static void setupSBR(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  const size_t kNumberOfDims = 4;

  std::vector<int64_t> bcast_shape(kNumberOfDims, 1);
  bcast_shape[bcast_shape.size() - 1] = -1;

  std::vector<bool> bcast_contig(kNumberOfDims, false);
  bcast_contig[bcast_contig.size() - 1] = true;

  auto x = makeContigTensor(kNumberOfDims, dtype);

  auto scale = TensorViewBuilder()
                   .contiguity(bcast_contig)
                   .shape(bcast_shape)
                   .dtype(dtype)
                   .build();

  auto bias = TensorViewBuilder()
                  .contiguity(bcast_contig)
                  .shape(bcast_shape)
                  .dtype(dtype)
                  .build();

  fusion->addInput(x);
  fusion->addInput(scale);
  fusion->addInput(bias);

  if (dtype == DataType::Half) {
    x = castOp(DataType::Float, x);
    scale = castOp(DataType::Float, scale);
    bias = castOp(DataType::Float, bias);
  }

  auto scale_bias = add(mul(x, scale), bias);
  auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);

  if (dtype == DataType::Half) {
    scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
  }
  fusion->addOutput(scale_bias_relu);
}
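
// Builds the normalization variant of the SBR fusion: a per-channel scale and
// bias are computed from 1-D weight, bias, mean, and var tensors
// (scale = weight * rsqrt(var), shifted_bias = (bias - mean) * scale),
// broadcast across the outer (N, H, W) dimensions, and applied as
// relu(x * scale + shifted_bias).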
static void setupSBRNorm(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
  FusionGuard fg(fusion);

  const size_t kNumberOfDims = 4;

  auto x = makeContigTensor(kNumberOfDims, dtype);
  auto weight = makeContigTensor(1, dtype);
  auto bias = makeContigTensor(1, dtype);
  auto mean = makeContigTensor(1, dtype);
  auto var = makeContigTensor(1, dtype);

  fusion->addInput(x);
  fusion->addInput(weight);
  fusion->addInput(bias);
  fusion->addInput(mean);
  fusion->addInput(var);

  std::vector<bool> broadcast_mask(kNumberOfDims, true);
  broadcast_mask[broadcast_mask.size() - 1] = false;

  if (dtype == DataType::Half) {
    x = castOp(DataType::Float, x);
    weight = castOp(DataType::Float, weight);
    bias = castOp(DataType::Float, bias);
    mean = castOp(DataType::Float, mean);
    var = castOp(DataType::Float, var);
  }

  auto rsqrt = unaryOp(UnaryOpType::Rsqrt, var);
  auto this_scale = mul(weight, rsqrt);
  auto this_bias = mul(sub(bias, mean), this_scale);

  auto bcast_scale = broadcast(this_scale, broadcast_mask);
  auto bcast_bias = broadcast(this_bias, broadcast_mask);

  auto scale_bias = add(mul(x, bcast_scale), bcast_bias);
  auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);

  if (dtype == DataType::Half) {
    scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
  }

  fusion->addOutput(scale_bias_relu);
}

//------------------------------------------------------------------------------
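
// Benchmarks the nvFuser-compiled SBR fusion. A warm-up run with profiling
// enabled compiles the fusion (so compilation stays outside the timed loop)
// and records the scheduler heuristics and launch parameters for the label.
// The timed loop reports per-iteration kernel time, clearing the L2 cache
// before each run so every iteration reads from DRAM. Bytes processed counts
// one read and one write of the full (N, H, W, C) tensor plus one read each
// of the channel-sized scale and bias.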
static void NvFuserScheduler_SBR(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  // N, H, W, C format
  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(1),
      benchmark_state.range(2)};
  std::vector<int64_t> bcast_shape{1, 1, 1, -1};

  // inputs
  at::manual_seed(0);
  std::vector<int64_t> static_bcast_shape{1, 1, 1, benchmark_state.range(2)};
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_scale = at::ones(static_bcast_shape, options);
  at::Tensor at_bias = at::zeros(static_bcast_shape, options);

  // inputs
  std::vector<c10::IValue> aten_inputs = {at_x, at_scale, at_bias};

  fusion_executor_cache->profile(true);
  fusion_executor_cache->runFusionWithInputs(aten_inputs);

  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
  auto executor_instance = compile_log.fusion_executor;
  auto params = toString(compile_log.params);
  auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

  benchmark_state.SetLabel(params + lparams);

  fusion_executor_cache->profile(false);
  executor_instance->setMeasureKernelTimeFlag(true);
  // Sync everything up before we start
  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    clearL2Cache();
    auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
    benchmark_state.SetIterationTime(
        executor_instance->kernelTimeMs() / 1000.0);
  }
  // Sync everything up before we're finished, don't want to run ahead on the
  // cpu while benchmarking.
  cudaDeviceSynchronize();

  const size_t size =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t channels = input_shape[3];
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
      int64_t(dataTypeSize(dtype)));
}
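
// Unfused ATen baseline for the same computation: mul, add, and relu run as
// three separate eager-mode kernels, timed with CudaKernelTimer and using the
// same bytes-processed accounting as the fused version.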
static void Baseline_SBR(benchmark::State& benchmark_state, DataType dtype) {
  // N, H, W, C format
  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(1),
      benchmark_state.range(2)};
  std::vector<int64_t> bcast_shape{benchmark_state.range(2)};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_y = at::randn(input_shape, options);
  at::Tensor at_scale = at::ones(bcast_shape, options);
  at::Tensor at_bias = at::zeros(bcast_shape, options);

  clearL2Cache();
  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    auto scale = at::mul(at_x, at_scale);
    auto bias = at::add(scale, at_bias);
    auto output = at::relu(bias);

    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
    clearL2Cache();
    cudaDeviceSynchronize();
  }

  const size_t size =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t channels = input_shape[3];
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
      int64_t(dataTypeSize(dtype)));
}

//------------------------------------------------------------------------------
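
// Benchmarks the nvFuser-compiled SBR-norm fusion. Same structure as
// NvFuserScheduler_SBR above, but with five inputs (x, weight, bias, mean,
// var); bytes processed counts four channel-sized reads plus one full-tensor
// read and one full-tensor write.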
static void NvFuserScheduler_SBR_Norm(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  // N, H, W, C format
  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(1),
      benchmark_state.range(2)};
  std::vector<int64_t> bcast_shape{benchmark_state.range(2)};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_weight = at::ones(bcast_shape, options);
  at::Tensor at_bias = at::zeros(bcast_shape, options);
  at::Tensor at_mean = at::zeros(bcast_shape, options);
  at::Tensor at_var = at::ones(bcast_shape, options);

  // inputs
  std::vector<c10::IValue> aten_inputs = {
      at_x, at_weight, at_bias, at_mean, at_var};

  fusion_executor_cache->profile(true);
  fusion_executor_cache->runFusionWithInputs(aten_inputs);

  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
  auto executor_instance = compile_log.fusion_executor;
  auto params = toString(compile_log.params);
  auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

  benchmark_state.SetLabel(params + lparams);

  fusion_executor_cache->profile(false);
  executor_instance->setMeasureKernelTimeFlag(true);
  // Sync everything up before we start
  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    clearL2Cache();
    auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
    benchmark_state.SetIterationTime(
        executor_instance->kernelTimeMs() / 1000.0);
  }

  // Sync everything up before we're finished, don't want to run ahead on the
  // cpu while benchmarking.
  cudaDeviceSynchronize();

  const size_t size =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t channels = input_shape[3];
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
      int64_t(dataTypeSize(dtype)));
}
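
// Unfused ATen baseline for the normalization variant: the per-channel
// scale/bias computation and the pointwise ops run as separate eager-mode
// kernels in every iteration.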
static void Baseline_SBR_Norm(
    benchmark::State& benchmark_state,
    DataType dtype) {
  // N, H, W, C format
  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(1),
      benchmark_state.range(2)};
  std::vector<int64_t> bcast_shape{1, 1, 1, benchmark_state.range(2)};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_weight = at::ones(bcast_shape, options);
  at::Tensor at_bias = at::zeros(bcast_shape, options);
  at::Tensor at_mean = at::zeros(bcast_shape, options);
  at::Tensor at_var = at::ones(bcast_shape, options);

  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    auto this_scale = at::mul(at_weight, at::rsqrt(at_var));
    auto this_bias = at::mul(at::sub(at_bias, at_mean), this_scale);

    auto scale = at::mul(at_x, this_scale);
    auto bias = at::add(scale, this_bias);
    auto output = at::relu(bias);

    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
  }

  const size_t size =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t channels = input_shape[3];
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
      int64_t(dataTypeSize(dtype)));
}

//------------------------------------------------------------------------------
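
// Benchmark registrations. The Ranges arguments map to
// benchmark_state.range(0..2): N = 8, H = W = 640, and C sweeping from 64 to
// 128, i.e. each benchmark runs at (8, 640, 640, C). UseManualTime() tells
// Google Benchmark to use the times reported via SetIterationTime instead of
// wall-clock time.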
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_SBR_fp32,
    setupSBR,
    NvFuserScheduler_SBR,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_SBR_fp16,
    setupSBR,
    NvFuserScheduler_SBR,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

//------------------------------------------------------------------------------
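
// SBR-norm fusion benchmarks (setupSBRNorm + NvFuserScheduler_SBR_Norm).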
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_SBR_Norm_fp32,
    setupSBRNorm,
    NvFuserScheduler_SBR_Norm,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_SBR_Norm_fp16,
    setupSBRNorm,
    NvFuserScheduler_SBR_Norm,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

//------------------------------------------------------------------------------
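
// Eager-mode baselines registered directly with Google Benchmark. Each
// wrapper fixes the data type and forwards to the shared Baseline_SBR /
// Baseline_SBR_Norm bodies, using the same shape ranges as the fused runs.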
static void Baseline_SBR_fp32(benchmark::State& benchmark_state) {
  Baseline_SBR(benchmark_state, DataType::Float);
}

BENCHMARK(Baseline_SBR_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

static void Baseline_SBR_fp16(benchmark::State& benchmark_state) {
  Baseline_SBR(benchmark_state, DataType::Half);
}

BENCHMARK(Baseline_SBR_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

//------------------------------------------------------------------------------

static void Baseline_SBR_Norm_fp32(benchmark::State& benchmark_state) {
  Baseline_SBR_Norm(benchmark_state, DataType::Float);
}

BENCHMARK(Baseline_SBR_Norm_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

static void Baseline_SBR_Norm_fp16(benchmark::State& benchmark_state) {
  Baseline_SBR_Norm(benchmark_state, DataType::Half);
}

BENCHMARK(Baseline_SBR_Norm_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();