mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Things changed in this PR that require review: test/forward_backward_compatibility/check_forward_backward_compatibility.py Our previous function overload extension names were wrong and have been updated in this PR; hence the compatibility list was updated. nvfuser code updates with bug fixes towards failures we encountered in OpInfoTests as well as failures reported by the AOTAutograd team. Pull Request resolved: https://github.com/pytorch/pytorch/pull/73627 Reviewed By: Chillee Differential Revision: D34765458 Pulled By: davidberard98 fbshipit-source-id: c81f3d6a1b723fb3a8ba419b7f82227f70440ca7 (cherry picked from commit b6a2c362c37051e44fac31687b2fe272f776551e)
172 lines
4.7 KiB
C++
172 lines
4.7 KiB
C++
#include <torch/csrc/jit/codegen/cuda/executor.h>
|
|
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
|
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
|
|
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
|
|
|
|
#include <benchmark/benchmark.h>
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
#include "utils.h"
|
|
|
|
using namespace torch::jit::fuser::cuda;
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// Builds the RMSNorm fusion graph: a contiguous 3-D input and a 1-D weight
// are normalized over the innermost dimension (norm_shape of size 1) with
// eps = 1e-6. Reduced-precision dtypes are upcast to Float for the math and
// the result is cast back to the requested dtype.
//
// fusion: fusion to populate (inputs/outputs registered on it).
// dtype:  element type of the benchmark tensors; must be Float, Half, or
//         BFloat16.
static void setupRMSNorm(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(
      dtype == DataType::Float || dtype == DataType::Half ||
      dtype == DataType::BFloat16);

  FusionGuard fg(fusion);

  const float kEps = 1e-6;

  Double* eps_ptr = IrBuilder::create<Double>(kEps);

  // setup fusion
  auto input = makeContigTensor(3, dtype);
  auto weight = makeContigTensor(1, dtype);

  fusion->addInput(input);
  fusion->addInput(weight);

  // FIX: both reduced-precision types must be upcast to Float before the
  // normalization math. The original only upcast Half, even though the
  // assert above admits BFloat16 and the output cast below restores any
  // non-Float dtype.
  if (dtype == DataType::Half || dtype == DataType::BFloat16) {
    input = castOp(DataType::Float, input);
    weight = castOp(DataType::Float, weight);
  }

  // Normalize over the last 1 dimension(s) of the 3-D input.
  auto rms_norm_results = rms_norm(input, 1, weight, eps_ptr);

  auto output = rms_norm_results.output;

  // Cast the Float result back to the requested storage dtype.
  if (dtype != DataType::Float) {
    output = castOp(dtype, output);
  }

  fusion->addOutput(output);
}
|
|
|
|
// Benchmark driver: runs the cached RMSNorm fusion over aten inputs of shape
// {8, benchmark_state.range(0), 1024} and reports achieved bandwidth.
//
// benchmark_state:       google-benchmark state (range(0) sweeps the middle
//                        dimension; iteration count and counters live here).
// fusion_executor_cache: compiled-fusion cache built from setupRMSNorm.
// dtype:                 element type; must be Float, Half, or BFloat16.
static void NvFuserScheduler_RMSNorm(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(
      dtype == DataType::Float || dtype == DataType::Half ||
      dtype == DataType::BFloat16);

  std::vector<int64_t> input_shape{8, benchmark_state.range(0), 1024};
  // NOTE: the epsilon constant lives in setupRMSNorm; the unused local
  // `kEps` that used to shadow it here has been removed.

  // inputs
  at::manual_seed(0); // deterministic inputs across runs
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor input = at::randn(input_shape, options);
  at::Tensor weight = at::randn({input_shape[2]}, options);

  std::vector<c10::IValue> aten_inputs({input, weight});

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);

  // Traffic model: input read + output write (2 * input.numel()) plus one
  // read of the weight vector, all at the storage dtype's element size.
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (2 * input.numel() + weight.numel()) * int64_t(dataTypeSize(dtype)));
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// fp32 RMSNorm benchmark: register the fusion and sweep the middle input
// dimension over several power-of-two-stepped ranges (manual CUDA timing).
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_fp32,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp32)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();
|
|
// fp16 RMSNorm benchmark: same size sweeps as fp32, with Half storage
// (inputs are upcast to Float inside the fusion and the output cast back).
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_fp16,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_fp16)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();
|
|
|
|
// bf16 RMSNorm benchmark: same size sweeps as fp32, with BFloat16 storage.
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_RMSNorm_bf16,
    setupRMSNorm,
    NvFuserScheduler_RMSNorm,
    DataType::BFloat16);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{16, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{18, 56}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{22, 44}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_RMSNorm_bf16)
    ->RangeMultiplier(2)
    ->Ranges({{24, 48}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();
|