Revert D28142447: Improve BatchNorm1d performance (CUDA)

Test Plan: revert-hammer

Differential Revision: D28142447 (b2936ad8fa)

Original commit changeset: c70109780e20

fbshipit-source-id: e93f6d00d644697b106f5ea8ab79872f353b51c6

This commit is contained in: parent 3948ce2fd9, commit 2992ff3fb8
@@ -1,20 +0,0 @@
-#include <ATen/AccumulateType.h>
-
-namespace at {
-
-c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
-  switch (type) {
-#define DEFINE_CASE(scalar_t, TypeNum) \
-    case ScalarType::TypeNum: \
-      return is_cuda ? \
-          CppTypeToScalarType<at::acc_type<scalar_t, true>>::value : \
-          CppTypeToScalarType<at::acc_type<scalar_t, false>>::value;
-
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CASE)
-#undef DEFINE_CASE
-
-    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
-  }
-}
-
-}
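Aside (not part of the patch): the deleted file above maps a runtime scalar type to the type used for accumulation, mirroring the compile-time at::acc_type alias. A minimal standalone C++ sketch of the GPU-side mapping visible in these hunks, with an illustrative enum standing in for c10::ScalarType:

    #include <cassert>

    // Illustrative stand-in for c10::ScalarType; only the cases visible in the
    // hunks above are modelled here.
    enum class Kind { Half, BFloat16, Float, Double };

    // On the GPU path (is_cuda == true in the deleted helper), reduced-precision
    // floating types accumulate in float and double stays double.
    Kind gpu_accumulate_kind(Kind k) {
      switch (k) {
        case Kind::Half:
        case Kind::BFloat16:
          return Kind::Float;   // widen low-precision floats
        case Kind::Float:
          return Kind::Float;
        case Kind::Double:
          return Kind::Double;
      }
      return k;
    }

    int main() {
      assert(gpu_accumulate_kind(Kind::Half) == Kind::Float);
      assert(gpu_accumulate_kind(Kind::BFloat16) == Kind::Float);
      assert(gpu_accumulate_kind(Kind::Double) == Kind::Double);
      return 0;
    }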
@@ -2,7 +2,6 @@
 #include <ATen/Config.h>
 #include <c10/util/Half.h>
 #include <c10/util/BFloat16.h>
-#include <c10/core/ScalarType.h>

 // Defines the accumulation type for a scalar type.
 // Example:
@@ -23,8 +22,8 @@ struct AccumulateType { };

 #if defined(__CUDACC__) || defined(__HIPCC__)
 template <> struct AccumulateType<half, true> { using type = float; };
-#endif
 template <> struct AccumulateType<BFloat16, true> {using type = float; };
+#endif
 template <> struct AccumulateType<Half, true> { using type = float; };
 template <> struct AccumulateType<float, true> { using type = float; };
 template <> struct AccumulateType<double, true> { using type = double; };
@@ -35,7 +34,6 @@ template <> struct AccumulateType<int16_t, true> { using type = int64_t; };
 template <> struct AccumulateType<int32_t, true> { using type = int64_t; };
 template <> struct AccumulateType<int64_t, true> { using type = int64_t; };
 template <> struct AccumulateType<bool, true> {using type = bool; };
-template <> struct AccumulateType<Half, false> { using type = float; };
 template <> struct AccumulateType<BFloat16, false> { using type = float; };
 template <> struct AccumulateType<c10::complex<float>, false> { using type = c10::complex<double>; };
 template <> struct AccumulateType<c10::complex<double>, false> { using type = c10::complex<double>; };
@@ -49,11 +47,8 @@ template <> struct AccumulateType<char, false> { using type = int64_t; };
 template <> struct AccumulateType<int16_t, false> { using type = int64_t; };
 template <> struct AccumulateType<int32_t, false> { using type = int64_t; };
 template <> struct AccumulateType<int64_t, false> { using type = int64_t; };
-template <> struct AccumulateType<bool, false> {using type = bool; };

 template<typename T, bool is_cuda>
 using acc_type = typename AccumulateType<T, is_cuda>::type;

-TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
-
 } // namespace at
@@ -423,8 +423,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", num_features, bias.numel());
   }

-  const bool use_cudnn = (
-      input.is_cuda()
+  bool use_cudnn = false;
+  use_cudnn = (input.is_cuda()
       && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
       && (input.scalar_type() != at::kHalf
         || weight.scalar_type() == at::kFloat)
@@ -432,29 +432,26 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      && (input.dim() >= 3)
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      && ((input.size(0) <= 880801 && training) // spatial, training
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-        ||(input.size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && ((input.dim() == 2 && input.size(0) <= 131070 && training) // per-activation, training
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+        || (input.dim() == 2 && input.size(0) <= 262136 && !training) // per-activation, eval
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+        || (input.dim() >= 3 && input.size(0) <= 880801 && training) // spatial, training
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+        || (input.dim() >= 3 && input.size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
       && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L);

-  if (use_cudnn) {
-    auto input_c = input.contiguous(input.suggest_memory_format());
-    auto weight_c = weight.contiguous();
-    auto bias_c = bias.contiguous();
-    auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-    auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-
-    Tensor output, save_mean, save_var, reserve;
-    std::tie(output, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm(input_c, weight_c, bias_c, rmean_c, rvar_c,
-                             training, momentum, eps);
-
-    return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
-        output, save_mean, save_var, reserve, 1);
+  if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
+    return std::tuple_cat(
+        at::cudnn_batch_norm(
+            input.contiguous(input.suggest_memory_format()), weight.contiguous(),
+            bias.contiguous(),
+            running_mean.defined() ? running_mean.contiguous() : running_mean,
+            running_var.defined() ? running_var.contiguous() : running_var,
+            training, momentum, eps),
+        std::make_tuple(1));
   }

   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
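Aside (a sketch, not library code): the restored condition gates the cuDNN path on batch size as well as layout, with separate limits for the per-activation (2-d input) and spatial (3-d and higher) kernels and for training versus eval. The size check alone, using exactly the thresholds that appear in the '+' lines above, can be written as:

    #include <cassert>
    #include <cstdint>

    // dim is input.dim(), n is input.size(0); thresholds are the ones visible
    // in the restored cuDNN eligibility expression.
    bool batch_size_ok_for_cudnn(int64_t dim, int64_t n, bool training) {
      if (dim == 2) {                      // per-activation kernels
        return training ? n <= 131070 : n <= 262136;
      }
      if (dim >= 3) {                      // spatial kernels
        return training ? n <= 880801 : n <= 65535;
      }
      return false;
    }

    int main() {
      assert(batch_size_ok_for_cudnn(2, 131070, /*training=*/true));
      assert(!batch_size_ok_for_cudnn(2, 131071, /*training=*/true));
      assert(batch_size_ok_for_cudnn(4, 65535, /*training=*/false));
      return 0;
    }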
@@ -39,6 +39,4 @@ DECLARE_DISPATCH(cum_fn, cumsum_stub);
 DECLARE_DISPATCH(cum_fn, cumprod_stub);
 DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);

-TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim);
-
 }} // namespace at::native
@@ -129,7 +129,8 @@ struct WelfordOps {
     auto ret = (divisor > 0) ?
       (take_sqrt ? device_sqrt(acc.m2 / divisor) : (acc.m2 / divisor))
       : NAN;
-    return res_t(ret, mean);
+    detail::pair<scalar_t, scalar_t> results{(scalar_t) ret, (scalar_t) mean};
+    return results;
   }

   static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
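Context for the WelfordOps change (a standalone sketch, not the PyTorch implementation): Welford's algorithm tracks a running mean and the sum of squared deviations m2; project() then yields the mean together with m2/divisor (or its square root), which is exactly the pair of values the modified return statement packs up.

    #include <cassert>
    #include <cmath>

    // Minimal single-threaded Welford accumulator; WelfordOps in the diff is the
    // reduction-framework version of the same recurrence.
    struct Welford {
      double mean = 0.0, m2 = 0.0;
      long long n = 0;

      void add(double x) {
        ++n;
        double delta = x - mean;
        mean += delta / n;
        m2 += delta * (x - mean);   // uses the updated mean, as in Welford's method
      }

      // Matches the shape of project(): (variance or stddev, mean).
      void project(bool take_sqrt, double& out_stat, double& out_mean) const {
        double var = n > 0 ? m2 / n : NAN;
        out_stat = take_sqrt ? std::sqrt(var) : var;
        out_mean = mean;
      }
    };

    int main() {
      Welford w;
      const double xs[] = {1.0, 2.0, 3.0, 4.0};
      for (double x : xs) w.add(x);
      double var = 0, mean = 0;
      w.project(/*take_sqrt=*/false, var, mean);
      assert(mean == 2.5 && var == 1.25);   // population variance of {1,2,3,4}
      return 0;
    }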
@@ -1,13 +1,6 @@

 #pragma once

-#include <ATen/detail/FunctionTraits.h>
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/TensorIteratorDynamicCasting.h>
-#include <ATen/cuda/detail/OffsetCalculator.cuh>
-
-#include <thrust/tuple.h>
-
 #define NUM_THREADS (C10_WARP_SIZE * 2)
 #define THREAD_WORK_SIZE 4
 #define BLOCK_WORK_SIZE (THREAD_WORK_SIZE * num_threads)
@@ -16,8 +9,14 @@ constexpr int num_threads = NUM_THREADS;
 constexpr int thread_work_size = THREAD_WORK_SIZE;
 constexpr int block_work_size = BLOCK_WORK_SIZE;

+#include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/TensorIteratorDynamicCasting.h>
+#include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/MemoryAccess.cuh>

+#include <thrust/tuple.h>
+
 namespace at { namespace native {

 template<int N>
@@ -1,280 +1,72 @@
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/ReduceOps.h>
-#include <ATen/native/Resize.h>
-#include <ATen/native/cuda/Loops.cuh>
-#include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/cuda/Normalization.cuh>
-#include <c10/cuda/CUDAMathCompat.h>

-namespace at { namespace native {
-
-namespace {
-
 inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
   return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
 }

-enum class Impl {
-  Contiguous,
-  ChannelsLast,
-  General,
-};
-
-inline Impl batch_norm_choose_impl(const Tensor& self) {
-  if (!at::cuda::detail::canUse32BitIndexMath(self)) {
-    return Impl::General;
-  }
-
-  if (self.is_contiguous()) {
-    return self.strides()[1] == 1 ? Impl::ChannelsLast : Impl::Contiguous;
-  }
-
-  if (self.is_contiguous(at::MemoryFormat::ChannelsLast)) {
-    return Impl::ChannelsLast;
-  }
-
-  return Impl::General;
-}
-
-void batch_norm_elementwise(
-    const Tensor& out, const Tensor& self, const c10::optional<Tensor>& weight_opt,
-    const c10::optional<Tensor>& bias_opt, const Tensor& mean_, const Tensor& invstd_) {
-  switch (batch_norm_choose_impl(self)) {
-  case Impl::Contiguous: {
-    c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
-    c10::MaybeOwned<Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
-    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
-        "batch_norm_elementwise_cuda", [&] {
-      batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
-          out, self, *weight, *bias, mean_, invstd_);
-    });
-    return;
-  }
-  case Impl::ChannelsLast: {
-    auto weight = at::borrow_from_optional_tensor(weight_opt);
-    auto bias = at::borrow_from_optional_tensor(bias_opt);
-    if ((!weight->defined() || weight->is_contiguous()) &&
-        (!bias->defined() || bias->is_contiguous()) &&
-        (!mean_.defined() || mean_.is_contiguous()) &&
-        (!invstd_.defined() || invstd_.is_contiguous())) {
-      batch_norm_elemt_channels_last_cuda_template(
-          out, self, *weight, *bias, mean_, invstd_);
-      return;
-    }
-    C10_FALLTHROUGH;
-  }
-  case Impl::General: {
-    const int64_t ndim = self.dim();
-    DimVector sizes(ndim, 1), strides(ndim, 0);
-    // Helper to convert 1d tensors to an nd tensor that broadcasts with input
-    // All elements go into the channel dimension
-    auto as_nd = [&](const Tensor& t) {
-      TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
-      sizes[1] = t.sizes()[0];
-      strides[1] = t.strides()[0];
-      return t.as_strided(sizes, strides);
-    };
-
-    auto weight = weight_opt.has_value() && weight_opt->defined() ?
-        as_nd(*weight_opt) : at::scalar_tensor(1, mean_.options());
-    auto bias = bias_opt.has_value() && bias_opt->defined() ?
-        as_nd(*bias_opt) : at::scalar_tensor(0, mean_.options());
-    auto mean = as_nd(mean_);
-    auto invstd = as_nd(invstd_);
-
-    auto iter = TensorIteratorConfig()
-        .add_output(out)
-        .add_input(self)
-        .add_input(weight)
-        .add_input(bias)
-        .add_input(mean)
-        .add_input(invstd)
-        .check_all_same_dtype(false)
-        .promote_inputs_to_common_dtype(false)
-        .build();
-
-    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(),
-        "batch_norm_elementwise_cuda", [&] {
-      using acc_t = at::acc_type<scalar_t, true>;
-      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t input, acc_t weight, acc_t bias,
-                                      acc_t mean, acc_t invstd) -> scalar_t {
-        return ((input - mean) * invstd) * weight + bias;
-      });
-    });
-    return;
-  }
-  }
-}
-
-void batch_norm_mean_var(const Tensor& self, Tensor& save_mean, Tensor& save_var) {
-  // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored.
-  const double dummy_epsilon = 1e-5;
-  switch (batch_norm_choose_impl(self)) {
-  case Impl::Contiguous: {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
-      batch_norm_stats_cuda_template<scalar_t, int32_t, Var>(
-          save_mean, save_var, self, dummy_epsilon);
-    });
-    return;
-  }
-  case Impl::ChannelsLast: {
-    if ((!save_mean.defined() || save_mean.is_contiguous()) &&
-        (!save_var.defined() || save_var.is_contiguous())) {
-      AT_DISPATCH_FLOATING_TYPES_AND2(
-          kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
-        batch_norm_stats_channels_last_cuda_template<scalar_t, Var>(
-            save_mean, save_var, self, dummy_epsilon);
-      });
-      return;
-    }
-    C10_FALLTHROUGH;
-  }
-  case Impl::General: {
-    const int64_t ndim = self.dim();
-    DimVector reduce_dims(ndim - 1);
-    reduce_dims[0] = 0;
-    for (int64_t i = 2; i < ndim; ++i) {
-      reduce_dims[i - 1] = i;
-    }
-
-    // For some reason this isn't an actual operator but it exists anyway...
-    at::native::var_mean_out(save_var, save_mean, self, /*dims=*/reduce_dims,
-                             /*unbiased=*/false, /*keepdim=*/false);
-    return;
-  }
-  }
-}
-
-void batch_norm_update_stats(
-    const Tensor& save_mean, const Tensor& save_var,
-    const Tensor& running_mean, const Tensor& running_var,
-    double momentum_, int64_t N) {
-
-  auto iter = TensorIteratorConfig()
-      .add_output(running_mean)
-      .add_output(running_var)
-      .add_input(save_mean)
-      .add_input(save_var)
-      .add_input(running_mean)
-      .add_input(running_var)
-      .check_all_same_dtype(false)
-      .promote_inputs_to_common_dtype(false)
-      .build();
-
-  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
-      "batch_norm_update_stats_cuda", [&] {
-    using acc_t = at::acc_type<scalar_t, true>;
-    const auto bessel_correction_factor = static_cast<acc_t>(
-        static_cast<double>(N) / static_cast<double>(N - 1));
-    const auto momentum = static_cast<acc_t>(momentum_);
-    gpu_kernel_multiple_outputs(
-        iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
-            -> thrust::tuple<scalar_t, scalar_t> {
-      const auto unbiased_var = var * bessel_correction_factor;
-      return thrust::tuple<scalar_t, scalar_t>{
-        mean * momentum + (1 - momentum) * running_mean,
-        unbiased_var * momentum + (1 - momentum) * running_var,
-      };
-    });
-  });
-}
-
-void batch_norm_update_stats_and_invert(
-    const Tensor& save_mean, const Tensor& save_var,
-    const Tensor& running_mean, const Tensor& running_var,
-    double momentum_, double epsilon, int64_t N) {
-
-  auto iter = TensorIteratorConfig()
-      .add_output(running_mean)
-      .add_output(running_var)
-      .add_output(save_var)
-      .add_input(save_mean)
-      .add_input(save_var)
-      .add_input(running_mean)
-      .add_input(running_var)
-      .check_all_same_dtype(false)
-      .promote_inputs_to_common_dtype(false)
-      .build();
-
-  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_mean.scalar_type(),
-      "batch_norm_update_stats_cuda", [&] {
-    using acc_t = at::acc_type<scalar_t, true>;
-    const auto bessel_correction_factor = static_cast<acc_t>(
-        static_cast<double>(N) / static_cast<double>(N - 1));
-    const auto eps = static_cast<acc_t>(epsilon);
-    const auto momentum = static_cast<acc_t>(momentum_);
-    gpu_kernel_multiple_outputs(
-        iter, [=] GPU_LAMBDA (acc_t mean, acc_t var, scalar_t running_mean, scalar_t running_var)
-            -> thrust::tuple<scalar_t, scalar_t, acc_t> {
-      const auto unbiased_var = var * bessel_correction_factor;
-      return thrust::tuple<scalar_t, scalar_t, acc_t>{
-        mean * momentum + (1 - momentum) * running_mean,
-        unbiased_var * momentum + (1 - momentum) * running_var,
-        c10::cuda::compat::rsqrt(var + eps)
-      };
-    });
-  });
-}
-
-void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var, double epsilon) {
-  auto iter = TensorIteratorConfig()
-      .add_output(out_invstd)
-      .add_input(running_var)
-      .check_all_same_dtype(false)
-      .build();
-
-  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, running_var.scalar_type(),
-      "batch_norm_invert_std_cuda", [&] {
-    using acc_t = at::acc_type<scalar_t, true>;
-    auto eps = static_cast<acc_t>(epsilon);
-    gpu_kernel(iter, [eps] GPU_LAMBDA (scalar_t var) -> acc_t {
-      return c10::cuda::compat::rsqrt(var + eps);
-    });
-  });
-}
-}
+namespace at { namespace native {

 std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) {
-  const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined());
-  const bool has_running_var = (running_mean_opt.has_value() && running_mean_opt->defined());
-  TORCH_CHECK(has_running_mean == has_running_var);
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

-  if (train) {
-    batch_norm_mean_var(self, save_mean, save_invstd);
-    if (has_running_mean) {
-      const int64_t N = self.numel() / save_mean.numel();
-      batch_norm_update_stats_and_invert(
-          save_mean, save_invstd, *running_mean_opt, *running_var_opt,
-          momentum, epsilon, N);
-    } else {
-      batch_norm_calc_invstd(save_invstd, save_invstd, epsilon);
-    }
-  } else {
-    TORCH_CHECK(has_running_mean);
-    at::native::resize_output(save_mean, running_mean_opt->sizes());
-    save_mean.copy_(*running_mean_opt, /*non_blocking=*/true);
-    batch_norm_calc_invstd(save_invstd, running_var_opt.value(), epsilon);
-  }
-  batch_norm_elementwise(output, self, weight_opt, bias_opt, save_mean, save_invstd);
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_cuda", [&] {
+    auto mean_st = running_mean.dtype();
+    auto var_st = running_var.dtype();
+    TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
+    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
+    bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
+    if (cuda::detail::canUse32BitIndexMath(self)) {
+      if (is_half_float || is_bfloat16_float) {
+        batch_norm_cuda_template<scalar_t, float, int32_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      } else {
+        batch_norm_cuda_template<scalar_t, scalar_t, int32_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      }
+    } else {
+      if (is_half_float || is_bfloat16_float) {
+        batch_norm_cuda_template<scalar_t, float, int64_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      } else {
+        batch_norm_cuda_template<scalar_t, scalar_t, int64_t>(output, save_mean, save_invstd, self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      }
+    }
+  });
   return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_invstd);
 }

 std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, bool train, double momentum, double epsilon) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+
   auto output = at::empty_like(self, at::MemoryFormat::Contiguous);
   int64_t n_input = self.size(1);
-  auto options = self.options().dtype(
-      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
-  auto save_mean = at::empty({n_input}, options);
-  auto save_invstd = at::empty({n_input}, options);
+  auto input_options = self.options();
+  // Accumulate in higher precision if input is half/bfloat16
+  if (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  Tensor save_mean, save_invstd;
+  if (train) {
+    save_mean = at::empty({n_input}, input_options);
+    save_invstd = at::empty({n_input}, input_options);
+  } else {
+    save_mean = at::empty({0}, input_options);
+    save_invstd = at::empty({0}, input_options);
+  }

   at::native::batch_norm_cuda_out(
     self,
-    weight_opt,
-    bias_opt,
-    running_mean_opt,
-    running_var_opt,
+    weight,
+    bias,
+    running_mean,
+    running_var,
     train,
     momentum,
     epsilon,
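For reference (a scalar sketch, not the library code): the removed batch_norm_update_stats_and_invert lambda above folds the batch statistics into the running statistics with the usual momentum rule, applies the Bessel correction N/(N-1) to the batch variance before updating running_var, and produces invstd = 1/sqrt(var + eps) from the biased batch variance. In plain C++:

    #include <cassert>
    #include <cmath>

    // Scalar sketch of the per-channel update done by the removed kernel.
    struct StatsUpdate {
      double running_mean;
      double running_var;
      double invstd;
    };

    StatsUpdate update_stats_and_invert(double batch_mean, double batch_var /*biased*/,
                                        double running_mean, double running_var,
                                        double momentum, double eps, long long N) {
      const double bessel = static_cast<double>(N) / static_cast<double>(N - 1);
      const double unbiased_var = batch_var * bessel;
      return {
          batch_mean * momentum + (1 - momentum) * running_mean,
          unbiased_var * momentum + (1 - momentum) * running_var,
          1.0 / std::sqrt(batch_var + eps),  // note: biased batch variance here
      };
    }

    int main() {
      const auto s = update_stats_and_invert(1.0, 4.0, 0.0, 1.0, 0.1, 0.0, 2);
      assert(std::abs(s.running_mean - 0.1) < 1e-12);
      assert(std::abs(s.running_var - (0.1 * 8.0 + 0.9 * 1.0)) < 1e-12);  // bessel = 2
      assert(std::abs(s.invstd - 0.5) < 1e-12);
      return 0;
    }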
@@ -316,45 +108,65 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
 }

 std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
-  auto options = self.options().dtype(
-      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
-  auto n_channels = self.size(1);
-  auto save_mean = at::empty({n_channels}, options);
-  auto save_invstd = at::empty({n_channels}, options);

   bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self);
-  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-      self.scalar_type(), "batch_norm_stats_cuda", [&] {
-    if (cuda::detail::canUse32BitIndexMath(self)) {
-      if (use_channels_last_kernel) {
-        batch_norm_stats_channels_last_cuda_template<scalar_t, InvStd>(
-            save_mean, save_invstd, self, epsilon);
-      } else {
-        batch_norm_stats_cuda_template<scalar_t, int32_t, InvStd>(
-            save_mean, save_invstd, self, epsilon);
-      }
-    } else {
-      batch_norm_stats_cuda_template<scalar_t, int64_t, InvStd>(
-          save_mean, save_invstd, self, epsilon);
-    }
-  });
-  return std::tuple<Tensor, Tensor>(save_mean, save_invstd);
+  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_stats_cuda", [&] {
+    if (cuda::detail::canUse32BitIndexMath(self)) {
+      if (use_channels_last_kernel) {
+        return batch_norm_stats_channels_last_cuda_template<scalar_t>(self, epsilon);
+      } else {
+        return batch_norm_stats_cuda_template<scalar_t, int32_t>(self, epsilon);
+      }
+    } else {
+      return batch_norm_stats_cuda_template<scalar_t, int64_t>(self, epsilon);
+    }
+  });
 }

-Tensor batch_norm_elemt_cuda(
-    const Tensor& self, const c10::optional<Tensor>& weight_opt,
-    const c10::optional<Tensor>& bias_opt, const Tensor& mean,
-    const Tensor& invstd, double epsilon) {
+Tensor batch_norm_elemt_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const Tensor& mean, const Tensor& invstd, double epsilon) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+
   auto output = at::empty_like(self, self.suggest_memory_format());
-  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
-  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
+  at::native::batch_norm_elemt_cuda_out(self, weight, bias, mean, invstd, epsilon, output);
   return output;
 }

 Tensor& batch_norm_elemt_cuda_out(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const Tensor& mean, const Tensor& invstd, double epsilon, Tensor& output) {
-  // FIXME: Epsilon parameter isn't required, we don't take the reciprocal
-  batch_norm_elementwise(output, self, weight_opt, bias_opt, mean, invstd);
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
+  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+
+  if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
+    batch_norm_elemt_channels_last_cuda_template(output, self, weight, bias, mean, invstd, epsilon);
+    return output;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_elemt", [&] {
+    auto mean_st = mean.dtype();
+    auto invstd_st = invstd.dtype();
+    TORCH_CHECK(mean_st == invstd_st, "mean and invstd need to have the same data types");
+    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
+    bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
+    if (cuda::detail::canUse32BitIndexMath(self)) {
+      if (is_half_float || is_bfloat16_float) {
+        batch_norm_elemt_cuda_template<scalar_t, float, int32_t>(output, self, weight, bias, mean, invstd, epsilon);
+      } else {
+        batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(output, self, weight, bias, mean, invstd, epsilon);
+      }
+    } else {
+      if (is_half_float || is_bfloat16_float) {
+        batch_norm_elemt_cuda_template<scalar_t, float, int64_t>(output, self, weight, bias, mean, invstd, epsilon);
+      } else {
+        batch_norm_elemt_cuda_template<scalar_t, scalar_t, int64_t>(output, self, weight, bias, mean, invstd, epsilon);
+      }
+    }
+  });
   return output;
 }

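For context (a scalar sketch, not the CUDA kernel): both the removed batch_norm_elementwise helper and the restored batch_norm_elemt_cuda_template apply the same per-element transform, ((x - mean) * invstd) * weight + bias, with the statistics and affine parameters broadcast over the channel dimension. A plain C++ reference over a contiguous (N, C, HW) layout:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // mean/invstd/weight/bias are per-channel, as in the kernels above.
    void batch_norm_elemt_reference(std::vector<float>& out, const std::vector<float>& in,
                                    const std::vector<float>& mean, const std::vector<float>& invstd,
                                    const std::vector<float>& weight, const std::vector<float>& bias,
                                    int64_t N, int64_t C, int64_t HW) {
      for (int64_t n = 0; n < N; ++n)
        for (int64_t c = 0; c < C; ++c)
          for (int64_t i = 0; i < HW; ++i) {
            int64_t idx = (n * C + c) * HW + i;
            out[idx] = ((in[idx] - mean[c]) * invstd[c]) * weight[c] + bias[c];
          }
    }

    int main() {
      // One sample, two channels, two spatial positions.
      std::vector<float> in = {1, 3, 10, 14}, out(4);
      std::vector<float> mean = {2, 12}, invstd = {1, 0.5f}, weight = {1, 2}, bias = {0, 1};
      batch_norm_elemt_reference(out, in, mean, invstd, weight, bias, 1, 2, 2);
      assert(out[0] == -1 && out[1] == 1);   // channel 0: (x-2)*1*1+0
      assert(out[2] == -1 && out[3] == 3);   // channel 1: (x-12)*0.5*2+1
      return 0;
    }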
@@ -455,24 +267,35 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
 }

 std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
-    const Tensor& self, const c10::optional<Tensor>& running_mean_opt,
-    const c10::optional<Tensor>& running_var_opt, double momentum) {
-  c10::MaybeOwned<Tensor> running_mean = at::borrow_from_optional_tensor(running_mean_opt);
-  c10::MaybeOwned<Tensor> running_var = at::borrow_from_optional_tensor(running_var_opt);
+    const Tensor& self, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, double momentum) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
+  const Tensor& running_mean = *running_mean_maybe_owned;
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

-  const int64_t n_input = self.size(1);
-  auto options = self.options().dtype(
-      at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true));
-  auto save_mean = at::empty({n_input}, options);
-  auto save_var = at::empty({n_input}, options);
-
-  batch_norm_mean_var(self, save_mean, save_var);
-  TORCH_CHECK(running_mean->defined() == running_var->defined());
-  if (running_mean->defined()) {
-    const int64_t N = self.numel() / save_mean.numel();
-    batch_norm_update_stats(save_mean, save_var, *running_mean, *running_var, momentum, N);
-  }
-  return std::tuple<Tensor, Tensor>(save_mean, save_var);
+  return AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "batch_norm_backward", [&] {
+    auto mean_st = running_mean.dtype();
+    auto var_st = running_var.dtype();
+    TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
+    // <sigh> Some workloads depend on passing in half input and float stats, which is
+    // usually handled by cuDNN. However, the JIT sometimes replaces cuDNN calls with this
+    // one so it needs to support the same case, or people start to complain.
+    bool is_half_float = std::is_same<scalar_t, at::Half>::value && mean_st == at::kFloat;
+    bool is_bfloat16_float = std::is_same<scalar_t, at::BFloat16>::value && mean_st == at::kFloat;
+    if (cuda::detail::canUse32BitIndexMath(self)) {
+      if (is_half_float || is_bfloat16_float) {
+        return batch_norm_update_stats_cuda_template<scalar_t, float, int32_t>(self, running_mean, running_var, momentum);
+      } else {
+        return batch_norm_update_stats_cuda_template<scalar_t, scalar_t, int32_t>(self, running_mean, running_var, momentum);
+      }
+    } else {
+      if (is_half_float || is_bfloat16_float) {
+        return batch_norm_update_stats_cuda_template<scalar_t, float, int64_t>(self, running_mean, running_var, momentum);
+      } else {
+        return batch_norm_update_stats_cuda_template<scalar_t, scalar_t, int64_t>(self, running_mean, running_var, momentum);
+      }
+    }
+  });
 }

 } } // namespace at::native
@@ -19,8 +19,6 @@ constexpr int MAX_BLOCK_SIZE = 256;
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif

-constexpr unsigned MAX_GRID_SIZE = 65535u;
-
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(__HIP_PLATFORM_HCC__)
@@ -274,8 +272,8 @@ __global__ void batch_norm_transform_input_kernel(
   }
 }

+template<typename T>
 struct InvStd {
-  template <typename T>
   __device__ __forceinline__ T operator()(T var, double epsilon) const {
     T invstd = 0;
     if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
@@ -285,18 +283,20 @@ struct InvStd {
   }
 };

+template<typename T>
 struct Var {
-  template <typename T>
   __device__ __forceinline__ T operator()(T var, double epsilon) const {
     return var;
   }
 };

-template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
+template <template<typename T> class VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
 __global__ void batch_norm_collect_statistics_kernel(
     const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
     const stat_accscalar_t epsilon,
     const stat_accscalar_t momentum,
+    GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
+    GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> running_var,
     GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
     GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {

@@ -373,7 +373,14 @@ __global__ void batch_norm_collect_statistics_kernel(
       save_mean[plane] = avg;
     }
     if (save_transformed_var.data() != NULL) {
-      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
+      save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
+    }
+    if (running_mean.data() != NULL) {
+      running_mean[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
+    }
+    if (running_var.data() != NULL) {
+      stat_accscalar_t unbiasedVar = var_n / (N - 1);
+      running_var[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_var[plane] + momentum * unbiasedVar);
     }
   }

@@ -586,6 +593,74 @@ static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_acc
   return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
 }

+template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
+void batch_norm_cuda_template(Tensor& output_, Tensor& save_mean_, Tensor& save_invstd_, const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
+    const Tensor& running_mean_, const Tensor& running_var_,
+    bool train, double momentum, double epsilon) {
+
+  TensorArg output_arg{ output_, "output", 1 },
+            save_mean_arg{ save_mean_, "save_mean", 2 },
+            save_invstd_arg{ save_invstd_, "save_invstd", 3 },
+            input_arg{ input_, "input", 4 },
+            weight_arg{ weight_, "weight", 5 },
+            bias_arg{ bias_, "bias", 6 },
+            run_mean_arg{ running_mean_, "running_mean", 7 },
+            run_var_arg{ running_var_, "running_var", 8 };
+  CheckedFrom c = "batch_norm_cuda";
+  checkAllSameGPU(c, {output_arg, save_mean_arg, save_invstd_arg, input_arg, weight_arg, bias_arg, run_mean_arg, run_var_arg});
+
+  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
+
+  auto bs = input_reshaped.size(0);
+  auto features = input_reshaped.size(2);
+  auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto input_options = input_.options();
+  if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  auto output = output_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
+  auto bias = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
+  auto running_mean = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+  auto save_mean = save_mean_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto save_invstd = save_invstd_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max<int>(getNumThreads(input.size(2)/4),
+                         std::min<int>(getNumThreads(input.size(2)), 64));
+  int tb = std::max<int>(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  blocks_trans.y = std::min<int>(blocks_trans.y, 65535);
+  dim3 threads_trans(tf, tb);
+  if (!train) {
+    batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, false, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+      (input, output, running_mean, running_var, weight, bias, epsilon);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
+    // the feature dimension, we'll use some threads for blocks
+    dim3 blocks(input.size(1));
+    tf = getNumThreads(input.size(2));
+    dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+
+    batch_norm_collect_statistics_kernel<InvStd, input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+      (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+      (input, output, save_mean, save_invstd, weight, bias, epsilon);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
     const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
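A note on the launch configuration in the restored batch_norm_cuda_template (a sketch only; the body of getNumThreads is not shown in these hunks, so it is assumed here to round the element count up to a power of two capped at MAX_BLOCK_SIZE): tf threads cover the per-plane features, tb threads cover the batch, and the grid's y extent is clamped to 65535.

    #include <algorithm>
    #include <cassert>
    #include <cstdio>

    constexpr int MAX_BLOCK_SIZE = 512;

    // Assumption: rounds nElem up to a power of two between 32 and MAX_BLOCK_SIZE.
    int getNumThreads(int nElem) {
      const int sizes[] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
      for (int t : sizes)
        if (nElem <= t) return t;
      return MAX_BLOCK_SIZE;
    }

    // Mirrors the tf/tb/blocks_trans arithmetic in the restored template.
    void transform_launch_dims(int N, int C, int features,
                               int& tf, int& tb, int& grid_x, int& grid_y) {
      tf = std::max(getNumThreads(features / 4), std::min(getNumThreads(features), 64));
      tb = std::max(64 / tf, 1);
      grid_x = C;
      grid_y = std::max(1, std::min((256 * 1024) / C, (N + tb - 1) / tb));
      grid_y = std::min(grid_y, 65535);
    }

    int main() {
      int tf, tb, gx, gy;
      transform_launch_dims(/*N=*/128, /*C=*/64, /*features=*/1, tf, tb, gx, gy);
      printf("threads=(%d,%d) grid=(%d,%d)\n", tf, tb, gx, gy);  // BatchNorm1d-like shape
      assert(tf * tb >= 64 && gy <= 65535);
      return 0;
    }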
@@ -634,39 +709,49 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tenso
   return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
 }

-template<typename scalar_t, typename index_t, typename VarTransform>
-void batch_norm_stats_cuda_template(
-    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_stats_cuda_template(const Tensor& input_, double epsilon) {

   using accscalar_t = at::acc_type<scalar_t, true>;
   int64_t n_input = input_.size(1);
   Tensor dummy_mean_;
   Tensor dummy_var_;
+  Tensor mean_;
+  Tensor invstd_;
   auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions

   auto bs = input_reshaped.size(0);
   auto features = input_reshaped.size(2);
   auto input = input_reshaped.generic_packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
-  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
-                        out_invstd.sizes()[0]);
-  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
-                        out_mean.sizes()[0]);
-  auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean);
-  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd);
+  auto input_options = input_.options();
+  dummy_mean_ = at::empty({0}, input_options);
+  dummy_var_ = at::empty({0}, input_options);
+  // promote only mean_/invstd_ precision
+  if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  mean_ = at::empty({n_input}, input_options);
+  invstd_ = at::empty({n_input}, input_options);
+  auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
+  auto dummy_mean = dummy_mean_.generic_packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
+  auto dummy_invstd = dummy_var_.generic_packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
   auto stream = at::cuda::getCurrentCUDAStream();

   dim3 blocks(input.size(1));
   int tf = getNumThreads(input.size(2));
   dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
-  batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
-    (input, epsilon, 0.0, mean, invstd);
+  batch_norm_collect_statistics_kernel<InvStd, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd);
   C10_CUDA_KERNEL_LAUNCH_CHECK();

+  return std::make_tuple(mean_, invstd_);
 }

 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
-void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
-    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
+void batch_norm_elemt_cuda_template(Tensor& output_, const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
+    const Tensor& mean_, const Tensor& invstd_,
+    double epsilon) {

   using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
   int64_t n_input = input_.size(1);
@@ -676,6 +761,10 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
   auto bs = input_reshaped.size(0);
   auto features = input_reshaped.size(2);
   auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto input_options = input_.options();
+  if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
   auto output = output_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
   auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
   auto bias = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
@@ -683,9 +772,6 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
   auto invstd = packed_accessor_or_dummy<stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
   auto stream = at::cuda::getCurrentCUDAStream();

-  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
-  const double dummy_epsilon = 1e-5;
-
   // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
   // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
   // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
@@ -695,10 +781,9 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
   int tb = std::max<int>(64/tf, 1);
   dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                   (input.size(0)+tb-1)/tb)));
-  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
   dim3 threads_trans(tf, tb);
   batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
-    (input, output, mean, invstd, weight, bias, dummy_epsilon);
+    (input, output, mean, invstd, weight, bias, epsilon);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

@@ -830,10 +915,45 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te
   return grad_input_reshaped.view(input_.sizes());
 }

+template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
+    const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
+
+  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
+  int64_t n_channels = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+  auto input_options = input_.options();
+  if (input_.scalar_type() == at::ScalarType::Half || input_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  Tensor save_mean_ = at::empty({n_channels}, input_options);
+  Tensor save_var_ = at::empty({n_channels}, input_options);
+
+  auto input = input_reshaped.generic_packed_accessor<input_scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto running_mean = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<stat_scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+  auto save_mean = save_mean_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto save_var = save_var_.generic_packed_accessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
+  // the feature dimension, we'll use some threads for blocks
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+  // NB: epsilon is unused by the Var transform, so we set it to 0
+  batch_norm_collect_statistics_kernel<Var, input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    (input, 0., momentum, running_mean, running_var, save_mean, save_var);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return std::make_tuple(save_mean_, save_var_);
+}
+
 // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
 // original apex name: welford_kernel_c_last
 template
-  <typename VarTransform,
+  <template<typename T> class VarTransform,
    typename scalar_t,
    typename accscalar_t,
    int PARALLEL_LOADS>
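Not part of the commit, but for readers of the restored batch_norm_update_stats_cuda_template above: conceptually, the kernel it launches reduces each channel to a mean and variance and folds them into the running buffers with the usual momentum rule. A CPU sketch under that assumption (the exact biased/unbiased bookkeeping inside the kernel is not visible in this hunk):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// One channel's reduction: mean and biased variance over all of its elements.
struct ChannelStats { double mean; double biased_var; };

static ChannelStats collect(const std::vector<float>& x) {
  double mean = 0.0;
  for (float v : x) mean += v;
  mean /= x.size();
  double m2 = 0.0;
  for (float v : x) m2 += (v - mean) * (v - mean);
  return {mean, m2 / x.size()};
}

// Conventional BatchNorm running-stats update; using the unbiased correction
// for the running variance is an assumption, not read off this diff.
static void update_running(double& running_mean, double& running_var,
                           const ChannelStats& s, double momentum, std::size_t n) {
  const double unbiased_var = s.biased_var * n / std::max<std::size_t>(n - 1, 1);
  running_mean = (1.0 - momentum) * running_mean + momentum * s.mean;
  running_var  = (1.0 - momentum) * running_var  + momentum * unbiased_var;
}

int main() {
  std::vector<float> channel = {1.f, 2.f, 3.f, 4.f};
  double running_mean = 0.0, running_var = 1.0;
  const ChannelStats s = collect(channel);
  update_running(running_mean, running_var, s, /*momentum=*/0.1, channel.size());
  std::printf("batch mean %.3f, running mean %.3f, running var %.3f\n",
              s.mean, running_mean, running_var);
}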
@@ -964,13 +1084,13 @@ batch_norm_collect_statistics_channels_last_kernel(
       welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
       if (threadIdx.y == 0 && c_offset < stride) {
         out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
-        out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
+        out_invstd[c_offset] = VarTransform<accscalar_t>{}(m2_th/count_th, epsilon);
       }
     }
   } else {
     if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
       out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
-      out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
+      out_invstd[c_offset] = VarTransform<accscalar_t>{}(m2_th/count_th, epsilon);
     }
   }
 }
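A note on the VarTransform argument this hunk re-parameterizes: the two functors used with the statistics kernels either pass the variance through or convert it to an inverse standard deviation. A hedged sketch of what they compute (the exact zero/epsilon guards in the real functors may differ):

#include <cmath>
#include <cstdio>

// `Var` forwards the accumulated variance unchanged (epsilon unused);
// `InvStd` maps it to 1/sqrt(var + eps), which is what the normalization
// step consumes. Sketch only; not copied from Normalization.cuh.
template <typename T>
struct Var {
  T operator()(T var, double /*epsilon*/) const { return var; }
};

template <typename T>
struct InvStd {
  T operator()(T var, double epsilon) const {
    return static_cast<T>(1.0 / std::sqrt(static_cast<double>(var) + epsilon));
  }
};

int main() {
  std::printf("var=4 -> invstd=%.4f\n", InvStd<float>{}(4.0f, 1e-5));
}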
@@ -1260,18 +1380,18 @@ __global__ void batch_norm_backward_elemt_channels_last_kernel(
   }
 }

-template<typename scalar_t, typename VarTransform>
-void batch_norm_stats_channels_last_cuda_template(
-    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
+template<typename scalar_t>
+std::tuple<Tensor, Tensor> batch_norm_stats_channels_last_cuda_template(const Tensor& input, double epsilon) {
   using accscalar_t = at::acc_type<scalar_t, true>;

   const auto stride = input.sizes()[1];
   const auto reduction_size = input.numel() / stride;

-  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
-                        out_invstd.sizes()[0]);
-  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
-                        out_mean.sizes()[0]);
+  auto scalar_type = input.scalar_type() == at::kHalf ? at::kFloat : input.scalar_type();
+  auto option = input.options().dtype(scalar_type);
+
+  at::Tensor out_invstd = at::empty({stride}, option);
+  at::Tensor out_mean = at::empty({stride}, option);

   dim3 block;
   dim3 grid;
@@ -1280,13 +1400,13 @@ void batch_norm_stats_channels_last_cuda_template(
   at::Tensor staging_data;
   at::Tensor semaphores;
   if (grid.y > 1) {
-    staging_data = at::empty({4*stride*grid.y}, out_mean.options());
+    staging_data = at::empty({4*stride*grid.y}, option);
     semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
   }

   accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
   int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
-  batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
+  batch_norm_collect_statistics_channels_last_kernel<InvStd, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
       input.data_ptr<scalar_t>(),
       out_mean.data_ptr<accscalar_t>(),
@@ -1297,15 +1417,18 @@ void batch_norm_stats_channels_last_cuda_template(
       stride,
       epsilon);
   C10_CUDA_KERNEL_LAUNCH_CHECK();

+  return std::make_tuple(out_mean, out_invstd);
 }

 void batch_norm_elemt_channels_last_cuda_template(
-    const at::Tensor& output,
+    at::Tensor& output,
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& shift, // bias of BN
     const at::Tensor& mean,
     const at::Tensor& inv_std,
+    double epsilon,
     const at::optional<at::Tensor>& z = c10::nullopt, // bias after BN
     const bool fuse_relu = false) {
   const auto stride = input.sizes()[1];
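For readers unfamiliar with what the batch_norm_elemt path applies per element (not part of this commit): normalize with the precomputed mean and inv_std, scale and shift with weight and bias, optionally add the extra z tensor, and clamp at zero when fuse_relu is set. A plain CPU sketch over a flat (rows x C) buffer, assuming weight and bias are always defined here for simplicity:

#include <algorithm>
#include <cstddef>

// Reference for the channels-last elementwise transform: one normalized,
// scaled and shifted value per element, with optional residual add and ReLU.
void batch_norm_elemt_ref(float* out, const float* in,
                          const float* weight, const float* bias,
                          const float* mean, const float* inv_std,
                          const float* z, bool fuse_relu,
                          std::size_t rows, std::size_t C) {
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < C; ++c) {
      const std::size_t i = r * C + c;
      float y = (in[i] - mean[c]) * inv_std[c] * weight[c] + bias[c];
      if (z != nullptr) y += z[i];   // "bias after BN" in the signature above
      if (fuse_relu) y = std::max(y, 0.0f);
      out[i] = y;
    }
  }
}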
@@ -1,4 +1,3 @@
-#include <ATen/AccumulateType.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/DispatchStub.h>

@@ -8,30 +7,31 @@

 namespace at { namespace native {

-template <typename scalar_t, typename out_t=scalar_t>
+template <typename scalar_t>
 void std_var_kernel_impl(TensorIterator& iter, bool unbiased, bool take_sqrt) {
   // reducing unrolling factor to 2 for welford kernel
   // This is necessary to lower register usage that leads to register spills.
-  using acc_t = at::acc_type<scalar_t, true>;
-  using ops_t = WelfordOps<scalar_t, acc_t, int32_t, float, thrust::pair<out_t, out_t>>;
-  gpu_reduce_kernel<scalar_t, out_t, 2>(
-      iter, ops_t{unbiased, take_sqrt}, typename ops_t::acc_t{});
+  gpu_reduce_kernel<scalar_t, scalar_t, 2>(iter, WelfordOps<scalar_t, scalar_t, int32_t, float, thrust::pair<scalar_t, scalar_t>> { unbiased, take_sqrt }, WelfordData<scalar_t, int32_t, float> {});
+}
+
+template <>
+void std_var_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  // reducing unrolling factor to 2 for welford kernel
+  // This is necessary to lower register usage that leads to register spills.
+  gpu_reduce_kernel<at::Half, at::Half, 2>(iter, WelfordOps<at::Half, float, int32_t, float, thrust::pair<at::Half, at::Half>> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
+}
+
+template <>
+void std_var_kernel_impl<at::BFloat16>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  // reducing unrolling factor to 2 for welford kernel
+  // This is necessary to lower register usage that leads to register spills.
+  gpu_reduce_kernel<at::BFloat16, at::BFloat16, 2>(iter, WelfordOps<at::BFloat16, float, int32_t, float, thrust::pair<at::BFloat16, at::BFloat16>> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
 }

 static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
-  const auto input_dtype = iter.input_dtype();
-  if (input_dtype == kHalf && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    std_var_kernel_impl<at::Half, float>(iter, unbiased, take_sqrt);
-  } else if (input_dtype == kBFloat16 && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    std_var_kernel_impl<at::BFloat16, float>(iter, unbiased, take_sqrt);
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-                                    iter.dtype(), "std_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "std_cuda", [&]() {
     std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
-  }
 }

 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
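The WelfordOps/WelfordData reductions above rest on Welford's streaming mean/variance update plus the parallel merge formula for combining partial results across threads and blocks. A scalar sketch of both steps (illustrative only, not the library's implementation):

#include <cstdint>
#include <cstdio>

// Welford accumulator: track (count, mean, m2); fold in one value at a time,
// and merge two partial accumulators the way the block/warp reductions do.
struct Welford {
  int64_t n = 0;
  double mean = 0.0;
  double m2 = 0.0;

  void add(double x) {
    ++n;
    const double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);
  }

  static Welford merge(const Welford& a, const Welford& b) {
    if (a.n == 0) return b;
    if (b.n == 0) return a;
    Welford r;
    r.n = a.n + b.n;
    const double delta = b.mean - a.mean;
    r.mean = a.mean + delta * b.n / r.n;
    r.m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / r.n;
    return r;
  }

  double var(bool unbiased) const { return m2 / (unbiased ? n - 1 : n); }
};

int main() {
  Welford left, right;
  for (double x : {1.0, 2.0, 3.0}) left.add(x);
  for (double x : {4.0, 5.0, 6.0}) right.add(x);
  const Welford all = Welford::merge(left, right);
  std::printf("mean %.3f, unbiased var %.3f\n", all.mean, all.var(true));
}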
@@ -6,106 +6,44 @@ import torch.nn.functional as F

 """Microbenchmarks for batchnorm operator."""

-# Benchmark cudnn if available
-if torch.backends.cudnn.is_available:
-    def cudnn_benchmark_configs(configs):
-        result = []
-        for config in configs:
-            is_cuda = any('cuda' in attr.values() for attr in config)
-            if is_cuda:
-                result.append((*config, dict(cudnn=True)))
-            result.append((*config, dict(cudnn=False)))
-        return result
-else:
-    def cudnn_benchmark_configs(configs):
-        return [(*config, dict(cudnn=False)) for config in configs]
-
-
-batchnorm_configs_short = cudnn_benchmark_configs(op_bench.config_list(
+batchnorm_configs_short = op_bench.config_list(
     attr_names=["M", "N", "K"],
     attrs=[
         [1, 256, 3136],
     ],
     cross_product_configs={
         'device': ['cpu', 'cuda'],
-        'training': [True, False],
     },
     tags=["short"]
-))
+)

-batchnorm_configs_long = cudnn_benchmark_configs(op_bench.cross_product_configs(
-    M=[2, 128],
+batchnorm_configs_long = op_bench.cross_product_configs(
+    M=[1, 128],
     N=[8192, 2048],
     K=[1],
     device=['cpu', 'cuda'],
-    training=[True, False],
     tags=["long"]
-))
+)


 class BatchNormBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, M, N, K, device, training, cudnn):
+    def init(self, M, N, K, device):
         self.inputs = {
             "input_one": torch.rand(M, N, K, device=device, requires_grad=self.auto_set()),
             "mean": torch.rand(N, device=device),
             "var": torch.rand(N, device=device),
             "weight": torch.rand(N, device=device),
-            "bias": torch.rand(N, device=device),
-            "training": training,
-            "cudnn": cudnn,
+            "bias": torch.rand(N, device=device)
         }
         self.set_module_name("batchnorm")

-    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
-        with torch.backends.cudnn.flags(enabled=cudnn):
-            return F.batch_norm(input_one, mean, var, weight, bias, training)
+    def forward(self, input_one, mean, var, weight, bias):
+        return F.batch_norm(input_one, mean, var, weight, bias)


 op_bench.generate_pt_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark)
 op_bench.generate_pt_gradient_test(batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark)


-batchnorm1d_configs_short = cudnn_benchmark_configs(op_bench.config_list(
-    attr_names=["N", "C"],
-    attrs=[
-        [3136, 256],
-    ],
-    cross_product_configs={
-        'device': ['cpu', 'cuda'],
-        'training': [True, False],
-    },
-    tags=["short"]
-))
-
-batchnorm1d_configs_long = cudnn_benchmark_configs(op_bench.cross_product_configs(
-    N=[2, 128],
-    C=[8192, 2048],
-    device=['cpu', 'cuda'],
-    training=[True, False],
-    tags=["long"]
-))
-
-
-class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, N, C, device, training, cudnn):
-        self.inputs = {
-            "input_one": torch.rand(N, C, device=device, requires_grad=self.auto_set()),
-            "mean": torch.rand(C, device=device),
-            "var": torch.rand(C, device=device),
-            "weight": torch.rand(C, device=device),
-            "bias": torch.rand(C, device=device),
-            "training": training,
-            "cudnn": cudnn,
-        }
-        self.set_module_name("batchnorm")
-
-    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
-        with torch.backends.cudnn.flags(enabled=cudnn):
-            return F.batch_norm(input_one, mean, var, weight, bias, training)
-
-
-op_bench.generate_pt_test(batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark)
-op_bench.generate_pt_gradient_test(batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark)
-
-
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
@@ -84,18 +84,6 @@
 #define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__)
 #endif

-#ifdef __has_attribute
-#define C10_HAS_ATTRIBUTE(x) __has_attribute(x)
-#else
-#define C10_HAS_ATTRIBUTE(x) (0)
-#endif
-
-#ifdef __has_cpp_attribute
-#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
-#else
-#define C10_HAS_CPP_ATTRIBUTE(x) (0)
-#endif
-
 /// C10_NODISCARD - Warn if a type or return value is discarded.

 // Technically, we should check if __cplusplus > 201402L here, because

@@ -222,14 +210,6 @@ using namespace c10::hip;
 #define C10_ALWAYS_INLINE inline
 #endif

-#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
-#define C10_FALLTHROUGH [[fallthrough]]
-#elif C10_HAS_ATTRIBUTE(fallthrough)
-#define C10_FALLTHROUGH __attribute__((fallthrough))
-#else
-#define C10_FALLTHROUGH
-#endif
-
 #include <sstream>
 #include <string>