Revert D28142447: Improve BatchNorm1d performance (CUDA)

Test Plan: revert-hammer

Differential Revision: D28142447 (b2936ad8fa)

Original commit changeset: c70109780e20

fbshipit-source-id: e93f6d00d644697b106f5ea8ab79872f353b51c6
Sam Estep 2021-05-06 15:00:22 -07:00 committed by Facebook GitHub Bot
parent 3948ce2fd9
commit 2992ff3fb8
11 changed files with 348 additions and 514 deletions

View File

@ -1,20 +0,0 @@
#include <ATen/AccumulateType.h>
namespace at {
c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
switch (type) {
#define DEFINE_CASE(scalar_t, TypeNum) \
case ScalarType::TypeNum: \
return is_cuda ? \
CppTypeToScalarType<at::acc_type<scalar_t, true>>::value : \
CppTypeToScalarType<at::acc_type<scalar_t, false>>::value;
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CASE)
#undef DEFINE_CASE
default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
}
}
}
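For context, a minimal sketch (not part of the diff; the static_asserts below are illustrative only) of how the runtime mapping deleted above lines up with the compile-time at::acc_type trait that remains in AccumulateType.h:

    #include <ATen/AccumulateType.h>
    #include <type_traits>

    // Half and BFloat16 accumulate in float when is_cuda == true, per the
    // AccumulateType<Half, true> / AccumulateType<BFloat16, true> specializations.
    static_assert(std::is_same<at::acc_type<at::Half, /*is_cuda=*/true>, float>::value, "");
    static_assert(std::is_same<at::acc_type<at::BFloat16, /*is_cuda=*/true>, float>::value, "");
    // The deleted toAccumulateType() expressed the same mapping on runtime ScalarType
    // values, e.g. toAccumulateType(at::kHalf, /*is_cuda=*/true) returned at::kFloat,
    // which Normalization.cu used to pick the dtype of save_mean / save_invstd.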

View File

@ -2,7 +2,6 @@
#include <ATen/Config.h> #include <ATen/Config.h>
#include <c10/util/Half.h> #include <c10/util/Half.h>
#include <c10/util/BFloat16.h> #include <c10/util/BFloat16.h>
#include <c10/core/ScalarType.h>
// Defines the accumulation type for a scalar type. // Defines the accumulation type for a scalar type.
// Example: // Example:
@ -23,8 +22,8 @@ struct AccumulateType { };
#if defined(__CUDACC__) || defined(__HIPCC__) #if defined(__CUDACC__) || defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; }; template <> struct AccumulateType<half, true> { using type = float; };
#endif
template <> struct AccumulateType<BFloat16, true> {using type = float; }; template <> struct AccumulateType<BFloat16, true> {using type = float; };
#endif
template <> struct AccumulateType<Half, true> { using type = float; }; template <> struct AccumulateType<Half, true> { using type = float; };
template <> struct AccumulateType<float, true> { using type = float; }; template <> struct AccumulateType<float, true> { using type = float; };
template <> struct AccumulateType<double, true> { using type = double; }; template <> struct AccumulateType<double, true> { using type = double; };
@ -35,7 +34,6 @@ template <> struct AccumulateType<int16_t, true> { using type = int64_t; };
template <> struct AccumulateType<int32_t, true> { using type = int64_t; }; template <> struct AccumulateType<int32_t, true> { using type = int64_t; };
template <> struct AccumulateType<int64_t, true> { using type = int64_t; }; template <> struct AccumulateType<int64_t, true> { using type = int64_t; };
template <> struct AccumulateType<bool, true> {using type = bool; }; template <> struct AccumulateType<bool, true> {using type = bool; };
template <> struct AccumulateType<Half, false> { using type = float; };
template <> struct AccumulateType<BFloat16, false> { using type = float; }; template <> struct AccumulateType<BFloat16, false> { using type = float; };
template <> struct AccumulateType<c10::complex<float>, false> { using type = c10::complex<double>; }; template <> struct AccumulateType<c10::complex<float>, false> { using type = c10::complex<double>; };
template <> struct AccumulateType<c10::complex<double>, false> { using type = c10::complex<double>; }; template <> struct AccumulateType<c10::complex<double>, false> { using type = c10::complex<double>; };
@ -49,11 +47,8 @@ template <> struct AccumulateType<char, false> { using type = int64_t; };
template <> struct AccumulateType<int16_t, false> { using type = int64_t; }; template <> struct AccumulateType<int16_t, false> { using type = int64_t; };
template <> struct AccumulateType<int32_t, false> { using type = int64_t; }; template <> struct AccumulateType<int32_t, false> { using type = int64_t; };
template <> struct AccumulateType<int64_t, false> { using type = int64_t; }; template <> struct AccumulateType<int64_t, false> { using type = int64_t; };
template <> struct AccumulateType<bool, false> {using type = bool; };
template<typename T, bool is_cuda> template<typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type; using acc_type = typename AccumulateType<T, is_cuda>::type;
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
} // namespace at } // namespace at

View File

@@ -423,8 +423,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", num_features, bias.numel());
   }
-  const bool use_cudnn = (
-      input.is_cuda()
+  bool use_cudnn = false;
+  use_cudnn = (input.is_cuda()
       && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
       && (input.scalar_type() != at::kHalf
         || weight.scalar_type() == at::kFloat)
@@ -432,29 +432,26 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      && (input.dim() >= 3)
+      && ((input.dim() == 2 && input.size(0) <= 131070 && training) // per-activation, training
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      && ((input.size(0) <= 880801 && training) // spatial, training
+        || (input.dim() == 2 && input.size(0) <= 262136 && !training) // per-activation, eval
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-        ||(input.size(0) <= 65535 && !training)) //spatial, eval
+        || (input.dim() >= 3 && input.size(0) <= 880801 && training) // spatial, training
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+        || (input.dim() >= 3 && input.size(0) <= 65535 && !training)) //spatial, eval
       && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L);

-  if (use_cudnn) {
-    auto input_c = input.contiguous(input.suggest_memory_format());
-    auto weight_c = weight.contiguous();
-    auto bias_c = bias.contiguous();
-    auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-    auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-
-    Tensor output, save_mean, save_var, reserve;
-    std::tie(output, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm(input_c, weight_c, bias_c, rmean_c, rvar_c,
-                             training, momentum, eps);
-    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
-        output, save_mean, save_var, reserve, 1);
+  if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
+    return std::tuple_cat(
+             at::cudnn_batch_norm(
+               input.contiguous(input.suggest_memory_format()), weight.contiguous(),
+               bias.contiguous(),
+               running_mean.defined() ? running_mean.contiguous() : running_mean,
+               running_var.defined() ? running_var.contiguous() : running_var,
+               training, momentum, eps),
+             std::make_tuple(1));
   }

   Tensor reserve = at::empty({0}, input.options().dtype(kByte));

View File

@ -39,6 +39,4 @@ DECLARE_DISPATCH(cum_fn, cumsum_stub);
DECLARE_DISPATCH(cum_fn, cumprod_stub); DECLARE_DISPATCH(cum_fn, cumprod_stub);
DECLARE_DISPATCH(cum_fn, logcumsumexp_stub); DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim);
}} // namespace at::native }} // namespace at::native

View File

@@ -129,7 +129,8 @@ struct WelfordOps {
     auto ret = (divisor > 0) ?
       (take_sqrt ? device_sqrt(acc.m2 / divisor) : (acc.m2 / divisor))
       : NAN;
-    return res_t(ret, mean);
+    detail::pair<scalar_t, scalar_t> results{(scalar_t) ret, (scalar_t) mean};
+    return results;
   }
   static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {

View File

@@ -1,13 +1,6 @@
 #pragma once

-#include <ATen/detail/FunctionTraits.h>
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/TensorIteratorDynamicCasting.h>
-#include <ATen/cuda/detail/OffsetCalculator.cuh>
-#include <thrust/tuple.h>
-
 #define NUM_THREADS (C10_WARP_SIZE * 2)
 #define THREAD_WORK_SIZE 4
 #define BLOCK_WORK_SIZE (THREAD_WORK_SIZE * num_threads)
@@ -16,8 +9,14 @@ constexpr int num_threads = NUM_THREADS;
 constexpr int thread_work_size = THREAD_WORK_SIZE;
 constexpr int block_work_size = BLOCK_WORK_SIZE;

+#include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/TensorIteratorDynamicCasting.h>
+#include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/MemoryAccess.cuh>
+#include <thrust/tuple.h>

 namespace at { namespace native {

 template<int N>

View File

@ -1,280 +1,72 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/cuda/Normalization.cuh> #include <ATen/native/cuda/Normalization.cuh>
#include <c10/cuda/CUDAMathCompat.h>
namespace at { namespace native {
namespace {
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) { inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2; return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
} }
enum class Impl { namespace at { namespace native {
Contiguous,
ChannelsLast,
General,
};
inline Impl batch_norm_choose_impl(const Tensor& self) {
if (!at::cuda::detail::canUse32BitIndexMath(self)) {
return Impl::General;
}
if (self.is_contiguous()) {
return self.strides()[1] == 1 ? Impl::ChannelsLast : Impl::Contiguous;
}
if (self.is_contiguous(at::MemoryFormat::ChannelsLast)) {
return Impl::ChannelsLast;
}
return Impl::General;
}
void batch_norm_elementwise(
const Tensor& out, const Tensor& self, const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& bias_opt, const Tensor& mean_, const Tensor& invstd_) {
switch (batch_norm_choose_impl(self)) {
case Impl::Contiguous: {
c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
c10::MaybeOwned<Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
"batch_norm_elementwise_cuda", [&] {
batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
out, self, *weight, *bias, mean_, invstd_);
});
return;
}
case Impl::ChannelsLast: {
auto weight = at::borrow_from_optional_tensor(weight_opt);
auto bias = at::borrow_from_optional_tensor(bias_opt);
if ((!weight->defined() || weight->is_contiguous()) &&
(!bias->defined() || bias->is_contiguous()) &&
(!mean_.defined() || mean_.is_contiguous()) &&
(!invstd_.defined() || invstd_.is_contiguous())) {
batch_norm_elemt_channels_last_cuda_template(
out, self, *weight, *bias, mean_, invstd_);
return;
}
C10_FALLTHROUGH;
}
case Impl::General: {
const int64_t ndim = self.dim();
DimVector sizes(ndim, 1), strides(ndim, 0);
// Helper to convert 1d tensors to an nd tensor that broadcasts with input
// All elements go into the channel dimension
auto as_nd = [&](const Tensor& t) {
TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
sizes[1] = t.sizes()[0];
strides[1] = t.strides()[0];
return t.as_strided(sizes, strides);
};
auto weight = weight_opt.has_value() && weight_opt->defined() ?
as_nd(*weight_opt) : at::scalar_tensor(1, mean_.options());
auto bias = bias_opt.has_value() && bias_opt->defined() ?
as_nd(*bias_opt) : at::scalar_tensor(0, mean_.options());
auto mean = as_nd(mean_);
auto invstd = as_nd(invstd_);
auto iter = TensorIteratorConfig()
.add_output(out)
.add_input(self)
.add_input(weight)
.add_input(bias)
.add_input(mean)
.add_input(invstd)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(false)
.build();
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
"batch_norm_elementwise_cuda", [&] {
using acc_t = at::acc_type<scalar_t, true>;
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t input, acc_t weight, acc_t bias,
acc_t mean, acc_t invstd) -> scalar_t {
return ((input - mean) * invstd) * weight + bias;
});
});
return;
}
}
}
void batch_norm_mean_var(const Tensor& self, Tensor& save_mean, Tensor& save_var) {
// NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored.
const double dummy_epsilon = 1e-5;
switch (batch_norm_choose_impl(self)) {
case Impl::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
batch_norm_stats_cuda_template<scalar_t, int32_t, Var>(
save_mean, save_var, self, dummy_epsilon);
});
return;
}
case Impl::ChannelsLast: {
if ((!save_mean.defined() || save_mean.is_contiguous()) &&
(!save_var.defined() || save_var.is_contiguous())) {
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
batch_norm_stats_channels_last_cuda_template<scalar_t, Var>(
save_mean, save_var, self, dummy_epsilon);
});
return;
}
C10_FALLTHROUGH;
}
case Impl::General: {
const int64_t ndim = self.dim();
DimVector reduce_dims(ndim - 1);
reduce_dims[0] = 0;
for (int64_t i = 2; i < ndim; ++i) {
reduce_dims[i - 1] = i;
}
// For some reason this isn't an actual operator but it exists anyway...
at::native::var_mean_out(save_var, save_mean, self, /*dims=*/reduce_dims,
/*unbiased=*/false, /*keepdim=*/false);
return;
}
}
}
void batch_norm_update_stats(
const Tensor& save_mean, const Tensor& save_var,
const Tensor& running_mean, const Tensor& running_var,
double momentum_, int64_t N) {
auto iter = TensorIteratorConfig()
.add_output(running_mean)
.add_output(running_var)
.add_input(save_mean)
.add_input(save_var)
.add_input(running_mean)
.add_input(running_var)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(false)
.build();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
"batch_norm_update_stats_cuda", [&] {
using acc_t = at::acc_type<scalar_t, true>;
const auto bessel_correction_factor = static_cast<acc_t>(
static_cast<double>(N) / static_cast<double>(N - 1));
const auto momentum = static_cast<acc_t>(momentum_);
gpu_kernel_multiple_outputs(
iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
-> thrust::tuple<scalar_t, scalar_t> {
const auto unbiased_var = var * bessel_correction_factor;
return thrust::tuple<scalar_t, scalar_t>{
mean * momentum + (1 - momentum) * running_mean,
unbiased_var * momentum + (1 - momentum) * running_var,
};
});
});
}
void batch_norm_update_stats_and_invert(
const Tensor& save_mean, const Tensor& save_var,
const Tensor& running_mean, const Tensor& running_var,
double momentum_, double epsilon, int64_t N) {
auto iter = TensorIteratorConfig()
.add_output(running_mean)
.add_output(running_var)
.add_output(save_var)
.add_input(save_mean)
.add_input(save_var)
.add_input(running_mean)
.add_input(running_var)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(false)
.build();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
"batch_norm_update_stats_cuda", [&] {
using acc_t = at::acc_type<scalar_t, true>;
const auto bessel_correction_factor = static_cast<acc_t>(
static_cast<double>(N) / static_cast<double>(N - 1));
const auto eps = static_cast<acc_t>(epsilon);
const auto momentum = static_cast<acc_t>(momentum_);
gpu_kernel_multiple_outputs(
iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
-> thrust::tuple<scalar_t, scalar_t, acc_t> {
const auto unbiased_var = var * bessel_correction_factor;
return thrust::tuple<scalar_t, scalar_t, acc_t>{
mean * momentum + (1 - momentum) * running_mean,
unbiased_var * momentum + (1 - momentum) * running_var,
c10::cuda::compat::rsqrt(var + eps)
};
});
});
}
void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var, double epsilon) {
auto iter = TensorIteratorConfig()
.add_output(out_invstd)
.add_input(running_var)
.check_all_same_dtype(false)
.build();
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_var.scalar_type(),
"batch_norm_invert_std_cuda", [&] {
using acc_t = at::acc_type<scalar_t, true>;
auto eps = static_cast<acc_t>(epsilon);
gpu_kernel(iter, [eps] GPU_LAMBDA (scalar_t var) -> acc_t {
return c10::cuda::compat::rsqrt(var + eps);
});
});
}
}
std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) { std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); // See [Note: hacky wrapper removal for optional tensor]
const bool has_running_var = (running_mean_opt.has_value() && running_mean_opt->defined()); c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
TORCH_CHECK(has_running_mean == has_running_var); const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
if (train) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_cuda", [&] {
batch_norm_mean_var(self, save_mean, save_invstd); auto mean_st = running_mean.dtype();
if (has_running_mean) { auto var_st = running_var.dtype();
const int64_t N = self.numel() / save_mean.numel(); TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
batch_norm_update_stats_and_invert( bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
save_mean, save_invstd, *running_mean_opt, *running_var_opt, bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
momentum, epsilon, N); if (cuda::detail::canUse32BitIndexMath(self)) {
if (is_half_float || is_bfloat16_float) {
batch_norm_cuda_template<scalar_t, float, int32_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
} else { } else {
batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); batch_norm_cuda_template<scalar_t, scalar_t, int32_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
} }
} else { } else {
TORCH_CHECK(has_running_mean); if (is_half_float || is_bfloat16_float) {
at::native::resize_output(save_mean, running_mean_opt->sizes()); batch_norm_cuda_template<scalar_t, float, int64_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
save_mean.copy_(*running_mean_opt, /*non_blocking=*/true); } else {
batch_norm_calc_invstd(save_invstd, running_var_opt.value(), epsilon); batch_norm_cuda_template<scalar_t, scalar_t, int64_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
} }
}
batch_norm_elementwise(output, self, weight_opt, bias_opt, save_mean, save_invstd); });
return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_invstd); return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_invstd);
} }
std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon) { std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
auto output = at::empty_like(self, at::MemoryFormat::Contiguous); auto output = at::empty_like(self, at::MemoryFormat::Contiguous);
int64_t n_input = self.size(1); int64_t n_input = self.size(1);
auto options = self.options().dtype( auto input_options = self.options();
at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true)); // Accumulate in higher precision if input is half/bfloat16
auto save_mean = at::empty({n_input}, options); if (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16) {
auto save_invstd = at::empty({n_input}, options); input_options = input_options.dtype(ScalarType::Float);
}
Tensor save_mean, save_invstd;
if (train) {
save_mean = at::empty({n_input}, input_options);
save_invstd = at::empty({n_input}, input_options);
} else {
save_mean = at::empty({0}, input_options);
save_invstd = at::empty({0}, input_options);
}
at::native::batch_norm_cuda_out( at::native::batch_norm_cuda_out(
self, self,
weight_opt, weight,
bias_opt, bias,
running_mean_opt, running_mean,
running_var_opt, running_var,
train, train,
momentum, momentum,
epsilon, epsilon,
@ -316,45 +108,65 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
} }
std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) { std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
auto options = self.options().dtype(
at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
auto n_channels = self.size(1);
auto save_mean = at::empty({n_channels}, options);
auto save_invstd = at::empty({n_channels}, options);
bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self); bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
self.scalar_type(), "batch_norm_stats_cuda", [&] { return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
if (cuda::detail::canUse32BitIndexMath(self)) { if (cuda::detail::canUse32BitIndexMath(self)) {
if (use_channels_last_kernel) { if (use_channels_last_kernel) {
batch_norm_stats_channels_last_cuda_template<scalar_t, InvStd>( return batch_norm_stats_channels_last_cuda_template<scalar_t>(self, epsilon);
save_mean, save_invstd, self, epsilon);
} else { } else {
batch_norm_stats_cuda_template<scalar_t, int32_t, InvStd>( return batch_norm_stats_cuda_template<scalar_t, int32_t>(self, epsilon);
save_mean, save_invstd, self, epsilon);
} }
} else { } else {
batch_norm_stats_cuda_template<scalar_t, int64_t, InvStd>( return batch_norm_stats_cuda_template<scalar_t, int64_t>(self, epsilon);
save_mean, save_invstd, self, epsilon);
} }
}); });
return std::tuple<Tensor, Tensor>(save_mean, save_invstd);
} }
Tensor batch_norm_elemt_cuda( Tensor batch_norm_elemt_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const Tensor& self, const c10::optional<Tensor>& weight_opt, const Tensor& mean, const Tensor& invstd, double epsilon) {
const c10::optional<Tensor>& bias_opt, const Tensor& mean, // See [Note: hacky wrapper removal for optional tensor]
const Tensor& invstd, double epsilon) { c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
auto output = at::empty_like(self, self.suggest_memory_format()); auto output = at::empty_like(self, self.suggest_memory_format());
// FIXME: Epsilon parameter isn't required, we don't take the reciprocal at::native::batch_norm_elemt_cuda_out(self, weight, bias, mean, invstd, epsilon, output);
batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
return output; return output;
} }
Tensor& batch_norm_elemt_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& batch_norm_elemt_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const Tensor& mean, const Tensor& invstd, double epsilon, Tensor& output) { const Tensor& mean, const Tensor& invstd, double epsilon, Tensor& output) {
// FIXME: Epsilon parameter isn't required, we don't take the reciprocal // See [Note: hacky wrapper removal for optional tensor]
batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd); c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
batch_norm_elemt_channels_last_cuda_template(output, self, weight, bias, mean, invstd, epsilon);
return output;
}
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_elemt", [&] {
auto mean_st = mean.dtype();
auto invstd_st = invstd.dtype();
TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
if (cuda::detail::canUse32BitIndexMath(self)) {
if (is_half_float || is_bfloat16_float) {
batch_norm_elemt_cuda_template<scalar_t, float, int32_t>(output, self, weight, bias, mean, invstd, epsilon);
} else {
batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(output, self, weight, bias, mean, invstd, epsilon);
}
} else {
if (is_half_float || is_bfloat16_float) {
batch_norm_elemt_cuda_template<scalar_t, float, int64_t>(output, self, weight, bias, mean, invstd, epsilon);
} else {
batch_norm_elemt_cuda_template<scalar_t, scalar_t, int64_t>(output, self, weight, bias, mean, invstd, epsilon);
}
}
});
return output; return output;
} }
@ -455,24 +267,35 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
} }
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda( std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
const Tensor& self, const c10::optional<Tensor>& running_mean_opt, const Tensor& self, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, double momentum) {
const c10::optional<Tensor>& running_var_opt, double momentum) { // See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt); c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt); const Tensor& running_mean = *running_mean_maybe_owned;
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
const int64_t n_input = self.size(1); return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward", [&] {
auto options = self.options().dtype( auto mean_st = running_mean.dtype();
at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true)); auto var_st = running_var.dtype();
auto save_mean = at::empty({n_input}, options); TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
auto save_var = at::empty({n_input}, options); // <sigh> Some workloads depend on passing in half input and float stats, which is
// usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this
batch_norm_mean_var(self, save_mean, save_var); // one so it needs to support the same case, or people start to complain.
TORCH_CHECK(running_mean->defined() == running_var->defined()); bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
if (running_mean->defined()) { bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
const int64_t N = self.numel() / save_mean.numel(); if (cuda::detail::canUse32BitIndexMath(self)) {
batch_norm_update_stats(save_mean, save_var, *running_mean, *running_var, momentum, N); if (is_half_float || is_bfloat16_float) {
return batch_norm_update_stats_cuda_template<scalar_t, float, int32_t>(self, running_mean, running_var, momentum);
} else {
return batch_norm_update_stats_cuda_template<scalar_t, scalar_t, int32_t>(self, running_mean, running_var, momentum);
} }
return std::tuple<Tensor, Tensor>(save_mean, save_var); } else {
if (is_half_float || is_bfloat16_float) {
return batch_norm_update_stats_cuda_template<scalar_t, float, int64_t>(self, running_mean, running_var, momentum);
} else {
return batch_norm_update_stats_cuda_template<scalar_t, scalar_t, int64_t>(self, running_mean, running_var, momentum);
}
}
});
} }
} } // namespace at::native } } // namespace at::native
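As a side note, a condensed sketch of the dispatch pattern the restored batch_norm entry points above share (Half/BFloat16 inputs paired with float running stats accumulate in float, and 32-bit index math is used when the tensor permits it); launch<> and dispatch_sketch are hypothetical stand-ins, not names from the diff:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <ATen/cuda/detail/IndexUtils.cuh>
    #include <type_traits>

    // Illustrative only: launch<> stands in for the batch_norm_*_cuda_template instantiations.
    template <typename input_t, typename stat_t, typename index_t>
    void launch(const at::Tensor& self) { /* hypothetical kernel launcher */ }

    void dispatch_sketch(const at::Tensor& self, const at::Tensor& running_mean) {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
          self.scalar_type(), "dispatch_sketch", [&] {
        // Mixed precision: half/bfloat16 input tensor with float running statistics.
        const bool mixed = (std::is_same<scalar_t, at::Half>::value ||
                            std::is_same<scalar_t, at::BFloat16>::value) &&
                           running_mean.scalar_type() == at::kFloat;
        if (at::cuda::detail::canUse32BitIndexMath(self)) {
          if (mixed) {
            launch<scalar_t, float, int32_t>(self);
          } else {
            launch<scalar_t, scalar_t, int32_t>(self);
          }
        } else {
          if (mixed) {
            launch<scalar_t, float, int64_t>(self);
          } else {
            launch<scalar_t, scalar_t, int64_t>(self);
          }
        }
      });
    }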

View File

@ -19,8 +19,6 @@ constexpr int MAX_BLOCK_SIZE = 256;
constexpr int MAX_BLOCK_SIZE = 512; constexpr int MAX_BLOCK_SIZE = 512;
#endif #endif
constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) { static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__) #if defined(__HIP_PLATFORM_HCC__)
@ -274,8 +272,8 @@ __global__ void batch_norm_transform_input_kernel(
} }
} }
template<typename T>
struct InvStd { struct InvStd {
template <typename T>
__device__ __forceinline__ T operator()(T var, double epsilon) const { __device__ __forceinline__ T operator()(T var, double epsilon) const {
T invstd = 0; T invstd = 0;
if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) { if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
@ -285,18 +283,20 @@ struct InvStd {
} }
}; };
template<typename T>
struct Var { struct Var {
template <typename T>
__device__ __forceinline__ T operator()(T var, double epsilon) const { __device__ __forceinline__ T operator()(T var, double epsilon) const {
return var; return var;
} }
}; };
template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> template <template<typename T> class VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel( __global__ void batch_norm_collect_statistics_kernel(
const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input, const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
const stat_accscalar_t epsilon, const stat_accscalar_t epsilon,
const stat_accscalar_t momentum, const stat_accscalar_t momentum,
GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> running_var,
GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean, GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) { GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
@ -373,7 +373,14 @@ __global__ void batch_norm_collect_statistics_kernel(
save_mean[plane] = avg; save_mean[plane] = avg;
} }
if (save_transformed_var.data() != NULL) { if (save_transformed_var.data() != NULL) {
save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
}
if (running_mean.data() != NULL) {
running_mean[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
}
if (running_var.data() != NULL) {
stat_accscalar_t unbiasedVar = var_n / (N - 1);
running_var[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_var[plane] + momentum * unbiasedVar);
} }
} }
@ -586,6 +593,74 @@ static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_acc
return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>(); return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
} }
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_cuda_template(Tensor& output_, Tensor& save_mean_, Tensor& save_invstd_, const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
const Tensor& running_mean_, const Tensor& running_var_,
bool train, double momentum, double epsilon) {
TensorArg output_arg{ output_, "output", 1 },
save_mean_arg{ save_mean_, "save_mean", 2 },
save_invstd_arg{ save_invstd_, "save_invstd", 3 },
input_arg{ input_, "input", 4 },
weight_arg{ weight_, "weight", 5 },
bias_arg{ bias_, "bias", 6 },
run_mean_arg{ running_mean_, "running_mean", 7 },
run_var_arg{ running_var_, "running_var", 8 };
CheckedFrom c = "batch_norm_cuda";
checkAllSameGPU(c, {output_arg, save_mean_arg, save_invstd_arg, input_arg, weight_arg, bias_arg, run_mean_arg, run_var_arg});
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
auto bs = input_reshaped.size(0);
auto features = input_reshaped.size(2);
auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
auto input_options = input_.options();
if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
input_options = input_options.dtype(ScalarType::Float);
}
auto output = output_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
auto bias = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
auto running_mean = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
auto running_var = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
auto save_mean = save_mean_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
auto save_invstd = save_invstd_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
auto stream = at::cuda::getCurrentCUDAStream();
// The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
// weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
// and good occupancy. Quite likely, we could go with even more blocks than 1024.
// The various planes are independent, so we use blocks for them.
int tf = std::max<int>(getNumThreads(input.size(2)/4),
std::min<int>(getNumThreads(input.size(2)), 64));
int tb = std::max<int>(64/tf, 1);
dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
(input.size(0)+tb-1)/tb)));
blocks_trans.y = std::min<int>(blocks_trans.y, 65535);
dim3 threads_trans(tf, tb);
if (!train) {
batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, false, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
(input, output, running_mean, running_var, weight, bias, epsilon);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
// for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
// the feature dimension, we'll use some threads for blocks
dim3 blocks(input.size(1));
tf = getNumThreads(input.size(2));
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
batch_norm_collect_statistics_kernel<InvStd, input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
(input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
C10_CUDA_KERNEL_LAUNCH_CHECK();
batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
(input, output, save_mean, save_invstd, weight, bias, epsilon);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
@ -634,39 +709,49 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tenso
return std::make_tuple(grad_input_, grad_weight_, grad_bias_); return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
} }
template<typename scalar_t, typename index_t, typename VarTransform> template<typename scalar_t, typename index_t>
void batch_norm_stats_cuda_template( std::tuple<Tensor, Tensor> batch_norm_stats_cuda_template(const Tensor& input_, double epsilon) {
const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
using accscalar_t = at::acc_type<scalar_t, true>; using accscalar_t = at::acc_type<scalar_t, true>;
int64_t n_input = input_.size(1); int64_t n_input = input_.size(1);
Tensor dummy_mean_; Tensor dummy_mean_;
Tensor dummy_var_; Tensor dummy_var_;
Tensor mean_;
Tensor invstd_;
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
auto bs = input_reshaped.size(0); auto bs = input_reshaped.size(0);
auto features = input_reshaped.size(2); auto features = input_reshaped.size(2);
auto input = input_reshaped.generic_packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>(); auto input = input_reshaped.generic_packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && auto input_options = input_.options();
out_invstd.sizes()[0]); dummy_mean_ = at::empty({0}, input_options);
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && dummy_var_ = at::empty({0}, input_options);
out_mean.sizes()[0]); // promote only mean_/invstd_ precision
if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean); input_options = input_options.dtype(ScalarType::Float);
auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd); }
mean_ = at::empty({n_input}, input_options);
invstd_ = at::empty({n_input}, input_options);
auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
auto dummy_mean = dummy_mean_.generic_packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
auto dummy_invstd = dummy_var_.generic_packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input.size(1)); dim3 blocks(input.size(1));
int tf = getNumThreads(input.size(2)); int tf = getNumThreads(input.size(2));
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf)); dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>> batch_norm_collect_statistics_kernel<InvStd, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
(input, epsilon, 0.0, mean, invstd); (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(mean_, invstd_);
} }
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, void batch_norm_elemt_cuda_template(Tensor& output_, const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { const Tensor& mean_, const Tensor& invstd_,
double epsilon) {
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>; using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
int64_t n_input = input_.size(1); int64_t n_input = input_.size(1);
@ -676,6 +761,10 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
auto bs = input_reshaped.size(0); auto bs = input_reshaped.size(0);
auto features = input_reshaped.size(2); auto features = input_reshaped.size(2);
auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>(); auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
auto input_options = input_.options();
if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
input_options = input_options.dtype(ScalarType::Float);
}
auto output = output_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>(); auto output = output_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_); auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
auto bias = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_); auto bias = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
@ -683,9 +772,6 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
auto invstd = packed_accessor_or_dummy<stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_); auto invstd = packed_accessor_or_dummy<stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
// NOTE: We use transform_input_kernel in training mode, which ignores epsilon
const double dummy_epsilon = 1e-5;
// The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
// weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
// and good occupancy. Quiet likely, we could go with even more blocks than 1024. // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
@ -695,10 +781,9 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
int tb = std::max<int>(64/tf, 1); int tb = std::max<int>(64/tf, 1);
dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1), dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
(input.size(0)+tb-1)/tb))); (input.size(0)+tb-1)/tb)));
blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
dim3 threads_trans(tf, tb); dim3 threads_trans(tf, tb);
batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>> batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
(input, output, mean, invstd, weight, bias, dummy_epsilon); (input, output, mean, invstd, weight, bias, epsilon);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} }
@ -830,10 +915,45 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te
return grad_input_reshaped.view(input_.sizes()); return grad_input_reshaped.view(input_.sizes());
} }
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
int64_t n_channels = input_.size(1);
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
auto input_options = input_.options();
if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
input_options = input_options.dtype(ScalarType::Float);
}
Tensor save_mean_ = at::empty({n_channels}, input_options);
Tensor save_var_ = at::empty({n_channels}, input_options);
auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
auto running_mean = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
auto running_var = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
auto save_mean = save_mean_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
auto save_var = save_var_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
auto stream = at::cuda::getCurrentCUDAStream();
// for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
// the feature dimension, we'll use some threads for blocks
dim3 blocks(input.size(1));
int tf = getNumThreads(input.size(2));
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
// NB: epsilon is unused by the Var transform, so we set it to 0
batch_norm_collect_statistics_kernel<Var, input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
(input, 0., momentum, running_mean, running_var, save_mean, save_var);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(save_mean_, save_var_);
}
// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
// original apex name: welford_kernel_c_last // original apex name: welford_kernel_c_last
template template
<typename VarTransform, <template<typename T> class VarTransform,
typename scalar_t, typename scalar_t,
typename accscalar_t, typename accscalar_t,
int PARALLEL_LOADS> int PARALLEL_LOADS>
@ -964,13 +1084,13 @@ batch_norm_collect_statistics_channels_last_kernel(
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
if (threadIdx.y == 0 && c_offset < stride) { if (threadIdx.y == 0 && c_offset < stride) {
out_mean[c_offset] = static_cast<accscalar_t>(mean_th); out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); out_invstd[c_offset] = VarTransform<accscalar_t>{}(m2_th/count_th, epsilon);
} }
} }
} else { } else {
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
out_mean[c_offset] = static_cast<accscalar_t>(mean_th); out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); out_invstd[c_offset] = VarTransform<accscalar_t>{}(m2_th/count_th, epsilon);
} }
} }
} }
@ -1260,18 +1380,18 @@ __global__ void batch_norm_backward_elemt_channels_last_kernel(
} }
} }
template<typename scalar_t, typename VarTransform> template<typename scalar_t>
void batch_norm_stats_channels_last_cuda_template( std::tuple<Tensor, Tensor> batch_norm_stats_channels_last_cuda_template(const Tensor& input, double epsilon) {
const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
using accscalar_t = at::acc_type<scalar_t, true>; using accscalar_t = at::acc_type<scalar_t, true>;
const auto stride = input.sizes()[1]; const auto stride = input.sizes()[1];
const auto reduction_size = input.numel() / stride; const auto reduction_size = input.numel() / stride;
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && auto scalar_type = input.scalar_type() == at::kHalf ? at::kFloat : input.scalar_type();
out_invstd.sizes()[0]); auto option = input.options().dtype(scalar_type);
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
out_mean.sizes()[0]); at::Tensor out_invstd = at::empty({stride}, option);
at::Tensor out_mean = at::empty({stride}, option);
dim3 block; dim3 block;
dim3 grid; dim3 grid;
@ -1280,13 +1400,13 @@ void batch_norm_stats_channels_last_cuda_template(
at::Tensor staging_data; at::Tensor staging_data;
at::Tensor semaphores; at::Tensor semaphores;
if (grid.y > 1) { if (grid.y > 1) {
staging_data = at::empty({4*stride*grid.y}, out_mean.options()); staging_data = at::empty({4*stride*grid.y}, option);
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
} }
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr; accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr; int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER> batch_norm_collect_statistics_channels_last_kernel<InvStd, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
out_mean.data_ptr<accscalar_t>(), out_mean.data_ptr<accscalar_t>(),
@ -1297,15 +1417,18 @@ void batch_norm_stats_channels_last_cuda_template(
stride, stride,
epsilon); epsilon);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(out_mean, out_invstd);
} }
void batch_norm_elemt_channels_last_cuda_template( void batch_norm_elemt_channels_last_cuda_template(
const at::Tensor& output, at::Tensor& output,
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& shift, // bias of BN const at::Tensor& shift, // bias of BN
const at::Tensor& mean, const at::Tensor& mean,
const at::Tensor& inv_std, const at::Tensor& inv_std,
double epsilon,
const at::optional<at::Tensor>& z = c10::nullopt, // bias after BN const at::optional<at::Tensor>& z = c10::nullopt, // bias after BN
const bool fuse_relu = false) { const bool fuse_relu = false) {
const auto stride = input.sizes()[1]; const auto stride = input.sizes()[1];

View File

@@ -1,4 +1,3 @@
-#include <ATen/AccumulateType.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/DispatchStub.h>
@@ -8,30 +7,31 @@
 namespace at { namespace native {

-template <typename scalar_t, typename out_t=scalar_t>
+template <typename scalar_t>
 void std_var_kernel_impl(TensorIterator& iter, bool unbiased, bool take_sqrt) {
   // reducing unrolling factor to 2 for welford kernel
   // This is necessary to lower register usage that leads to register spills.
-  using acc_t = at::acc_type<scalar_t, true>;
-  using ops_t = WelfordOps<scalar_t, acc_t, int32_t, float, thrust::pair<out_t, out_t>>;
-  gpu_reduce_kernel<scalar_t, out_t, 2>(
-      iter, ops_t{unbiased, take_sqrt}, typename ops_t::acc_t{});
+  gpu_reduce_kernel<scalar_t, scalar_t, 2>(iter, WelfordOps<scalar_t, scalar_t, int32_t, float, thrust::pair<scalar_t, scalar_t>> { unbiased, take_sqrt }, WelfordData<scalar_t, int32_t, float> {});
+}
+
+template <>
+void std_var_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  // reducing unrolling factor to 2 for welford kernel
+  // This is necessary to lower register usage that leads to register spills.
+  gpu_reduce_kernel<at::Half, at::Half, 2>(iter, WelfordOps<at::Half, float, int32_t, float, thrust::pair<at::Half, at::Half>> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
+}
+
+template <>
+void std_var_kernel_impl<at::BFloat16>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  // reducing unrolling factor to 2 for welford kernel
+  // This is necessary to lower register usage that leads to register spills.
+  gpu_reduce_kernel<at::BFloat16, at::BFloat16, 2>(iter, WelfordOps<at::BFloat16, float, int32_t, float, thrust::pair<at::BFloat16, at::BFloat16>> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
 }

 static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
-  const auto input_dtype = iter.input_dtype();
-  if (input_dtype == kHalf && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    std_var_kernel_impl<at::Half, float>(iter, unbiased, take_sqrt);
-  } else if (input_dtype == kBFloat16 && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    std_var_kernel_impl<at::BFloat16, float>(iter, unbiased, take_sqrt);
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-                                    iter.dtype(), "std_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "std_cuda", [&]() {
     std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
-  }
 }

 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>

View File

@@ -6,106 +6,44 @@ import torch.nn.functional as F

 """Microbenchmarks for batchnorm operator."""

-# Benchmark cudnn if available
-if torch.backends.cudnn.is_available:
-    def cudnn_benchmark_configs(configs):
-        result = []
-        for config in configs:
-            is_cuda = any('cuda' in attr.values() for attr in config)
-            if is_cuda:
-                result.append((*config, dict(cudnn=True)))
-            result.append((*config, dict(cudnn=False)))
-        return result
-else:
-    def cudnn_benchmark_configs(configs):
-        return [(*config, dict(cudnn=False)) for config in configs]
-
-batchnorm_configs_short = cudnn_benchmark_configs(op_bench.config_list(
+batchnorm_configs_short = op_bench.config_list(
     attr_names=["M", "N", "K"],
     attrs=[
         [1, 256, 3136],
     ],
     cross_product_configs={
         'device': ['cpu', 'cuda'],
-        'training': [True, False],
     },
     tags=["short"]
-))
+)

-batchnorm_configs_long = cudnn_benchmark_configs(op_bench.cross_product_configs(
-    M=[2, 128],
+batchnorm_configs_long = op_bench.cross_product_configs(
+    M=[1, 128],
     N=[8192, 2048],
     K=[1],
     device=['cpu', 'cuda'],
-    training=[True, False],
     tags=["long"]
-))
+)

 class BatchNormBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, M, N, K, device, training, cudnn):
+    def init(self, M, N, K, device):
         self.inputs = {
             "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()),
             "mean": torch.rand(N, device=device),
             "var": torch.rand(N, device=device),
             "weight": torch.rand(N, device=device),
-            "bias": torch.rand(N, device=device),
-            "training": training,
-            "cudnn": cudnn,
+            "bias": torch.rand(N, device=device)
         }
         self.set_module_name("batchnorm")

-    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
-        with torch.backends.cudnn.flags(enabled=cudnn):
-            return F.batch_norm(input_one, mean, var, weight, bias, training)
+    def forward(self, input_one, mean, var, weight, bias):
+        return F.batch_norm(input_one, mean, var, weight, bias)

 op_bench.generate_pt_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark)
 op_bench.generate_pt_gradient_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark)

-batchnorm1d_configs_short = cudnn_benchmark_configs(op_bench.config_list(
-    attr_names=["N", "C"],
-    attrs=[
-        [3136, 256],
-    ],
-    cross_product_configs={
-        'device': ['cpu', 'cuda'],
-        'training': [True, False],
-    },
-    tags=["short"]
-))
-
-batchnorm1d_configs_long = cudnn_benchmark_configs(op_bench.cross_product_configs(
-    N=[2, 128],
-    C=[8192, 2048],
-    device=['cpu', 'cuda'],
-    training=[True, False],
-    tags=["long"]
-))
-
-class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, N, C, device, training, cudnn):
-        self.inputs = {
-            "input_one": torch.rand(N, C, device=device, requires_grad=self.auto_set()),
-            "mean": torch.rand(C, device=device),
-            "var": torch.rand(C, device=device),
-            "weight": torch.rand(C, device=device),
-            "bias": torch.rand(C, device=device),
-            "training": training,
-            "cudnn": cudnn,
-        }
-        self.set_module_name("batchnorm")
-
-    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
-        with torch.backends.cudnn.flags(enabled=cudnn):
-            return F.batch_norm(input_one, mean, var, weight, bias, training)
-
-op_bench.generate_pt_test(batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark)
-op_bench.generate_pt_gradient_test(batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark)
-
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()

View File

@ -84,18 +84,6 @@
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) #define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__)
#endif #endif
#ifdef __has_attribute
#define C10_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define C10_HAS_ATTRIBUTE(x) (0)
#endif
#ifdef __has_cpp_attribute
#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
#else
#define C10_HAS_CPP_ATTRIBUTE(x) (0)
#endif
/// C10_NODISCARD - Warn if a type or return value is discarded. /// C10_NODISCARD - Warn if a type or return value is discarded.
// Technically, we should check if __cplusplus > 201402L here, because // Technically, we should check if __cplusplus > 201402L here, because
@ -222,14 +210,6 @@ using namespace c10::hip;
#define C10_ALWAYS_INLINE inline #define C10_ALWAYS_INLINE inline
#endif #endif
#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
#define C10_FALLTHROUGH [[fallthrough]]
#elif C10_HAS_ATTRIBUTE(fallthrough)
#define C10_FALLTHROUGH __attribute__((fallthrough))
#else
#define C10_FALLTHROUGH
#endif
#include <sstream> #include <sstream>
#include <string> #include <string>