pytorch/benchmarks/cpp/nvfuser/rms_norm.cpp
jjsjann123 873ced7cd0 Nvfuser code bump 030122 (#73627)
Summary:
Things changed in this PR that require review:

test/forward_backward_compatibility/check_forward_backward_compatibility.py

Our previous function overload extension names were wrong and have been updated in this PR, hence the update to the compatibility list.

nvfuser code updates with bug fixes for failures we encountered in OpInfo tests, as well as failures reported by the AOTAutograd team.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73627

Reviewed By: Chillee

Differential Revision: D34765458

Pulled By: davidberard98

fbshipit-source-id: c81f3d6a1b723fb3a8ba419b7f82227f70440ca7
(cherry picked from commit b6a2c362c37051e44fac31687b2fe272f776551e)
2022-03-31 08:18:22 +00:00


#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <benchmark/benchmark.h>
#include <cuda_runtime.h>
#include "utils.h"

using namespace torch::jit::fuser::cuda;

//------------------------------------------------------------------------------
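// Builds the RMSNorm forward fusion: a contiguous 3-D input is normalized over
// its innermost dimension and scaled by a 1-D weight. fp32, fp16, and bf16 I/O
// are supported; fp16 inputs are promoted to fp32 for the normalization math.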
static void setupRMSNorm(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(
      dtype == DataType::Float || dtype == DataType::Half ||
      dtype == DataType::BFloat16);

  FusionGuard fg(fusion);

  const int kReductionAxis = 2;
  const float kEps = 1e-6;

  Double* eps_ptr = IrBuilder::create<Double>(kEps);

  // setup fusion
  auto input = makeContigTensor(3, dtype);
  auto weight = makeContigTensor(1, dtype);

  fusion->addInput(input);
  fusion->addInput(weight);
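
  // For fp16 I/O, promote to fp32 so the normalization math runs in full
  // precision; the result is cast back to the I/O dtype below. (bf16 inputs
  // are passed through uncast here.)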
  if (dtype == DataType::Half) {
    input = castOp(DataType::Float, input);
    weight = castOp(DataType::Float, weight);
  }
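
  // RMSNorm over the last dimension (1 normalized dim):
  //   y = x / sqrt(mean(x^2) + eps) * weight
  // rms_norm() returns a result struct; only the normalized output is used.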
  auto rms_norm_results = rms_norm(input, 1, weight, eps_ptr);
  auto output = rms_norm_results.output;

  if (dtype != DataType::Float) {
    output = castOp(dtype, output);
  }

  fusion->addOutput(output);
}

static void NvFuserScheduler_RMSNorm(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(
      dtype == DataType::Float || dtype == DataType::Half ||
      dtype == DataType::BFloat16);

  std::vector<int64_t> input_shape{8, benchmark_state.range(0), 1024};
  const float kEps = 1e-6;

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor input = at::randn(input_shape, options);
  at::Tensor weight = at::randn({input_shape[2]}, options);

  std::vector<c10::IValue> aten_inputs({input, weight});

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
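
  // Bandwidth model: one read of the input, one write of the output
  // (2 * input.numel()), plus one read of the weight, at the I/O element size.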
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (2 * input.numel() + weight.numel()) * int64_t(dataTypeSize(dtype)));
}

//------------------------------------------------------------------------------
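// Each RUN below sweeps benchmark_state.range(0) (the middle dimension of the
// [8, N, 1024] input) over the listed bounds with a range multiplier of 2,
// reporting manually-timed results in microseconds.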
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_fp32,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_fp16,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_bf16,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::BFloat16);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();