Summary:
Things changed in this PR that require review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposes createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : adds gelu/tanh/erf to the registry
4. torch/jit/_script.py : throws an error when scripting encounters autocast used as a decorator, since that is not supported
nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address a perf regression from horizontal fusion
4. scalar CPU tensor promotion to support inter-device operations between a CPU scalar tensor and a CUDA tensor (a minimal sketch follows this list)
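A minimal sketch of the op pattern item 4 refers to, written against plain ATen; the function and variable names are illustrative only and do not come from this PR:

#include <ATen/ATen.h>

// Hypothetical example: a 0-dim CPU tensor used as a scalar operand to a
// CUDA op. ATen already promotes CPU scalar tensors in mixed-device binary
// ops; the nvfuser change extends the same behavior to fused kernels.
void cpu_scalar_promotion_example() {
  at::Tensor cuda_t = at::randn({4, 4}, at::device(at::kCUDA));
  at::Tensor cpu_scalar = at::scalar_tensor(2.0f); // 0-dim, lives on CPU
  at::Tensor out = cuda_t + cpu_scalar; // legal despite the device mismatch
}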
Things reverted from local changes:
aten::gelu with approximation (tracked in PR: https://github.com/pytorch/pytorch/pull/61439)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72127
Reviewed By: HamidShojanazeri
Differential Revision: D34113233
Pulled By: jbschlosser
fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e)
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>

#include <benchmark/benchmark.h>

#include <ATen/Operators.h>

#include <cuda_runtime.h>

#include "utils.h"

using namespace torch::jit::fuser::cuda;

//------------------------------------------------------------------------------
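// Builds the nvfuser fusion for batch norm backward: 4D contiguous input and
// grad_output tensors plus channel-sized fp32 stats/weight inputs, producing
// grad_input, grad_weight, and grad_bias as fusion outputs.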
static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  const bool kTraining = true;
  const float kMomentum = 0.1;
  const float kEps = 1e-5;

  // setup fusion
  auto input = makeContigTensor(4, dtype);
  auto grad_output = makeContigTensor(4, dtype);
  auto weight = makeContigTensor(1, DataType::Float);
  auto running_mean = makeContigTensor(1, DataType::Float);
  auto running_var = makeContigTensor(1, DataType::Float);
  auto save_mean = makeContigTensor(1, DataType::Float);
  auto save_var = makeContigTensor(1, DataType::Float);

  fusion->addInput(input);
  fusion->addInput(grad_output);
  fusion->addInput(weight);
  fusion->addInput(running_mean);
  fusion->addInput(running_var);
  fusion->addInput(save_mean);
  fusion->addInput(save_var);
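  // For the fp16 variant, compute runs in fp32: inputs are upcast here and
  // the gradients are cast back to half before being registered as outputs.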
  if (dtype == DataType::Half) {
    input = castOp(DataType::Float, input);
    grad_output = castOp(DataType::Float, grad_output);
  }

  auto eps_ptr = IrBuilder::create<Double>(kEps);

  auto result = batch_norm_backward(
      input,
      grad_output,
      weight,
      running_mean,
      running_var,
      save_mean,
      save_var,
      kTraining,
      eps_ptr,
      std::vector<bool>(3, true));

  auto grad_input = result.grad_input;
  auto grad_weight = result.grad_weight;
  auto grad_bias = result.grad_bias;

  if (dtype == DataType::Half) {
    grad_input = castOp(DataType::Half, grad_input);
    grad_weight = castOp(DataType::Half, grad_weight);
    grad_bias = castOp(DataType::Half, grad_bias);
  }

  fusion->addOutput(grad_input);
  fusion->addOutput(grad_weight);
  fusion->addOutput(grad_bias);
}
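// Benchmark driver for the fusion above: materializes ATen inputs of shape
// {N, C, H, W} = {range(0), range(1), range(2), range(2)} and runs them
// through the given FusionExecutorCache.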
static void NvFuserScheduler_BatchNorm_BWD(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  const bool kTraining = true;
  const float kEps = 1e-5;

  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(2),
      benchmark_state.range(2)};

  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn(input_shape, options);
  at::Tensor grad_out = at::randn(input_shape, options);
  at::Tensor weight = at::ones({input_shape[1]}, fp32_options);
  at::Tensor run_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor run_var = at::ones({input_shape[1]}, fp32_options);
  at::Tensor save_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor save_var = at::ones({input_shape[1]}, fp32_options);

  std::vector<c10::IValue> aten_inputs(
      {input, grad_out, weight, run_mean, run_var, save_mean, save_var});

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
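  // Bandwidth model: three input-sized tensors in dtype (input and grad_out
  // read, grad_input written) plus five channel-sized fp32 tensors (weight
  // and the running/saved stats).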
  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (((3 * input.numel()) * int64_t(dataTypeSize(dtype))) +
       (run_mean.numel() + run_var.numel() + save_mean.numel() +
        save_var.numel() + weight.numel()) *
           int64_t(dataTypeSize(DataType::Float))));
}
//------------------------------------------------------------------------------
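// cuDNN baseline with the same shapes and bandwidth model as the nvfuser
// benchmark, timed manually around at::_ops::cudnn_batch_norm_backward.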
static void Baseline_BatchNorm_BWD(
    benchmark::State& benchmark_state,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(2),
      benchmark_state.range(2)};

  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn(input_shape, options);
  at::Tensor grad_out = at::randn(input_shape, options);
  at::Tensor weight = at::ones({input_shape[1]}, fp32_options);
  at::Tensor bias = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor run_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor run_var = at::ones({input_shape[1]}, fp32_options);
  at::Tensor save_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor save_var = at::ones({input_shape[1]}, fp32_options);

  auto ato_weight = c10::optional<at::Tensor>(weight);
  auto ato_bias = c10::optional<at::Tensor>(bias);
  auto ato_run_mean = c10::optional<at::Tensor>(run_mean);
  auto ato_run_var = c10::optional<at::Tensor>(run_var);
  auto ato_save_mean = c10::optional<at::Tensor>(save_mean);
  auto ato_save_var = c10::optional<at::Tensor>(save_var);
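  // Run the forward pass once, outside the timed region: the backward call
  // below needs the reserve-space tensor it returns (std::get<3> of the
  // result tuple).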
  auto fwd_result = at::_ops::_batch_norm_impl_index::call(
      input,
      ato_weight,
      ato_bias,
      ato_run_mean,
      ato_run_var,
      true,
      kMomentum,
      kEps,
      true);
  cudaDeviceSynchronize();

  // Sync everything up before we start
  clearL2Cache();
  cudaDeviceSynchronize();
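  // Each iteration is timed with the CudaKernelTimer helper and reported via
  // SetIterationTime, matching UseManualTime() at registration; the L2 cache
  // is cleared between iterations so each one starts cold.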
  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    at::_ops::cudnn_batch_norm_backward::call(
        input,
        grad_out,
        weight,
        ato_run_mean,
        ato_run_var,
        save_mean,
        save_var,
        kEps,
        std::get<3>(fwd_result));

    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
    clearL2Cache();
    cudaDeviceSynchronize();
  }

  benchmark_state.SetBytesProcessed(
      int64_t(benchmark_state.iterations()) *
      (((3 * input.numel()) * int64_t(dataTypeSize(dtype))) +
       (run_mean.numel() + run_var.numel() + save_mean.numel() +
        save_var.numel() + weight.numel()) *
           int64_t(dataTypeSize(DataType::Float))));
}
//------------------------------------------------------------------------------

static void Baseline_BatchNorm_BWD_cuDNN_fp32(
    benchmark::State& benchmark_state) {
  Baseline_BatchNorm_BWD(benchmark_state, DataType::Float);
}

static void Baseline_BatchNorm_BWD_cuDNN_fp16(
    benchmark::State& benchmark_state) {
  Baseline_BatchNorm_BWD(benchmark_state, DataType::Half);
}

//------------------------------------------------------------------------------
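// Benchmark registration. The three Ranges entries map onto the NCHW input
// shape as {N, C, H} with W == H; the first run of each pair covers
// large-batch shapes, the second small-batch shapes.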
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_BatchNorm_BWD_fp32,
    setupBatchNorm_BWD,
    NvFuserScheduler_BatchNorm_BWD,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_BWD_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{64, 512}, {32, 128}, {2, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_BWD_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{2, 64}, {2, 32}, {2, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_BatchNorm_BWD_fp16,
    setupBatchNorm_BWD,
    NvFuserScheduler_BatchNorm_BWD,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_BWD_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{64, 512}, {32, 128}, {2, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_BWD_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{2, 64}, {2, 32}, {2, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

//------------------------------------------------------------------------------

BENCHMARK(Baseline_BatchNorm_BWD_cuDNN_fp32)
    // ->RangeMultiplier(2)
    // cuDNN didn't make it to 1024
    ->Ranges({{64, 512}, {32, 128}, {2, 64}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

BENCHMARK(Baseline_BatchNorm_BWD_cuDNN_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{2, 64}, {2, 32}, {2, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

BENCHMARK(Baseline_BatchNorm_BWD_cuDNN_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{64, 512}, {32, 128}, {2, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

BENCHMARK(Baseline_BatchNorm_BWD_cuDNN_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{2, 64}, {2, 32}, {2, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();