Summary:
Things changed in this PR that require review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf to the registry
4. torch/jit/_script.py : throws when scripting a model that uses autocast as a decorator, since that is not supported
nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address a perf regression from horizontal fusion
4. scalar CPU tensor promotion to support inter-device operations between a CPU scalar tensor and a CUDA tensor (a minimal sketch follows this list)
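As an illustration of item 4, a minimal ATen sketch of the behavior being enabled (hypothetical snippet, not code from this PR):

    #include <ATen/ATen.h>

    // Hypothetical example, not from this PR: a 0-dim CPU scalar tensor
    // combined with a CUDA tensor. The CPU scalar is promoted so the
    // elementwise op can run on the CUDA device instead of erroring out.
    at::Tensor cpu_scalar_times_cuda() {
      at::Tensor scalar_cpu = at::scalar_tensor(2.0);  // 0-dim tensor on CPU
      at::Tensor x = at::randn({8, 8}, at::kCUDA);     // CUDA tensor
      return x * scalar_cpu;                           // result lives on CUDA
    }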
Things reverted from local changes:
aten::gelu with approximation (tracked in PR: https://github.com/pytorch/pytorch/pull/61439)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72127
Reviewed By: HamidShojanazeri
Differential Revision: D34113233
Pulled By: jbschlosser
fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e)
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>

#include <benchmark/benchmark.h>

#include <cuda_runtime.h>

#include "utils.h"

using namespace torch::jit::fuser::cuda;

static void setupInstanceNorm(Fusion* fusion, DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  FusionGuard fg(fusion);

  auto input = makeContigTensor(4, dtype);
  auto weight = makeContigTensor(1, dtype);
  auto bias = makeContigTensor(1, dtype);
  auto running_mean = makeContigTensor(1, DataType::Float);
  auto running_var = makeContigTensor(1, DataType::Float);

  fusion->addInput(input);
  fusion->addInput(weight);
  fusion->addInput(bias);
  fusion->addInput(running_mean);
  fusion->addInput(running_var);

  if (dtype == DataType::Half) {
    input = castOp(DataType::Float, input);
    weight = castOp(DataType::Float, weight);
    bias = castOp(DataType::Float, bias);
  }

  const bool kTraining = true;
  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  auto momentum_ptr = IrBuilder::create<Double>(kMomentum);
  auto eps_ptr = IrBuilder::create<Double>(kEps);

  auto norm = instance_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      kTraining,
      momentum_ptr,
      eps_ptr);

  auto output = unaryOp(UnaryOpType::Relu, norm.output);

  if (dtype == DataType::Half) {
    output = castOp(DataType::Half, output);
  }

  fusion->addOutput(output);
}
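// Note on the fusion above: for Half inputs, input/weight/bias are cast to
// Float before instance_norm and the ReLU output is cast back to Half, so
// the normalization math always runs in fp32 while global-memory traffic
// stays at fp16. The running stats are fp32 tensors in both cases.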
//------------------------------------------------------------------------------

static void NvFuserScheduler_InstanceNorm(
    benchmark::State& benchmark_state,
    FusionExecutorCache* fusion_executor_cache,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(2),
      benchmark_state.range(1),
      benchmark_state.range(1)};

  // inputs
  at::manual_seed(0);
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_weight = at::ones({input_shape[1]}, options);
  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
  at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);

  std::vector<c10::IValue> aten_inputs = {
      at_x, at_weight, at_bias, at_mean, at_var};
  std::vector<at::Tensor> outputs;

  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);

  const size_t kSize =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t kChannels = input_shape[1];

  // Read: x, weight, bias, running_mean, running_var
  // Write: y, running_mean, running_var
  benchmark_state.SetBytesProcessed(
      benchmark_state.iterations() *
      ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
}
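// Byte accounting for SetBytesProcessed above: x and y contribute kSize
// elements each in `dtype` (kSize * 2), weight and bias contribute kChannels
// each (kChannels * 2), and running_mean/running_var are each read and
// written in fp32 (kChannels * 2 tensors * 2 accesses).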
static void Baseline_InstanceNorm(
    benchmark::State& benchmark_state,
    DataType dtype) {
  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);

  std::vector<int64_t> input_shape{
      benchmark_state.range(0),
      benchmark_state.range(2),
      benchmark_state.range(1),
      benchmark_state.range(1)};
  const float kMomentum = 0.1;
  const float kEps = 1e-5;
  const auto aten_dtype = data_type_to_aten(dtype);

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
  auto fp32_options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor at_x = at::randn(input_shape, options);
  at::Tensor at_weight = at::ones({input_shape[1]}, options);
  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
  at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
  at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);

  auto ato_weight = c10::optional<at::Tensor>(at_weight);
  auto ato_bias = c10::optional<at::Tensor>(at_bias);
  auto ato_running_mean = c10::optional<at::Tensor>(at_mean);
  auto ato_running_var = c10::optional<at::Tensor>(at_var);

  clearL2Cache();
  cudaDeviceSynchronize();
  for (auto _ : benchmark_state) {
    CudaKernelTimer timer;

    auto norm = at::instance_norm(
        at_x,
        ato_weight,
        ato_bias,
        ato_running_mean,
        ato_running_var,
        true,
        kMomentum,
        kEps,
        false);
    auto output = at::relu(norm);

    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
    cudaDeviceSynchronize();
    clearL2Cache();
    cudaDeviceSynchronize();
  }

  const size_t kSize =
      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
  const size_t kChannels = input_shape[1];

  // Read: x, weight, bias, running_mean, running_var
  // Write: y, running_mean, running_var
  benchmark_state.SetBytesProcessed(
      benchmark_state.iterations() *
      ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
}
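// The baseline loop times each iteration manually (CudaKernelTimer ->
// SetIterationTime; the /1000.0 presumably converts the timer's milliseconds
// to the seconds Google Benchmark expects) and clears L2 between iterations
// so every iteration reads from DRAM, keeping results comparable to the
// nvFuser runs, which use the same UseManualTime() setting below.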
//------------------------------------------------------------------------------

static void Baseline_InstanceNorm_fp32(benchmark::State& benchmark_state) {
  Baseline_InstanceNorm(benchmark_state, DataType::Float);
}

static void Baseline_InstanceNorm_fp16(benchmark::State& benchmark_state) {
  Baseline_InstanceNorm(benchmark_state, DataType::Half);
}

//------------------------------------------------------------------------------
NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_InstanceNorm_fp32,
    setupInstanceNorm,
    NvFuserScheduler_InstanceNorm,
    DataType::Float);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

NVFUSER_BENCHMARK_DEFINE(
    NvFuserScheduler_InstanceNorm_fp16,
    setupInstanceNorm,
    NvFuserScheduler_InstanceNorm,
    DataType::Half);

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_InstanceNorm_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();
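// The Ranges arguments map onto the NCHW input as N = range(0),
// H = W = range(1), and C = range(2) (see input_shape above): these runs fix
// N = 8 and H = W = 640 and sweep the channel count (64-128 for fp32,
// 64-256 for fp16). The eager baselines below use the same shapes.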
//------------------------------------------------------------------------------

BENCHMARK(Baseline_InstanceNorm_fp32)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 128}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

BENCHMARK(Baseline_InstanceNorm_fp16)
    // ->RangeMultiplier(2)
    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

//------------------------------------------------------------------------------