Add NHWC support for group normalization (#126635)

Fixes #111824

Currently, if the user requests NHWC (channels-last) format for group normalization, PyTorch falls back to NCHW tensors and converts. This conversion is not obvious unless the user checks the memory format themselves, which is not intuitive. This PR adds NHWC support on CUDA by adding the necessary kernels.
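
For example (a minimal sketch, not taken from the PR; the shapes, group count, and device setup are arbitrary assumptions and a CUDA build is required), a channels-last input is now expected to stay channels-last through GroupNorm on CUDA instead of being silently converted to NCHW:

```python
import torch
import torch.nn as nn

# Channels-last (NHWC) input and a GroupNorm module on CUDA.
x = torch.randn(4, 8, 32, 32, device="cuda").to(memory_format=torch.channels_last)
gn = nn.GroupNorm(num_groups=4, num_channels=8).cuda()

y = gn(x)

# With the new NHWC kernels the output keeps the channels-last layout,
# so this is expected to print True.
print(y.is_contiguous(memory_format=torch.channels_last))
```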

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126635
Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
Danial Javady 2024-11-07 01:12:05 +00:00 committed by PyTorch MergeBot
parent 59ec011855
commit ed0e63e938
3 changed files with 198 additions and 82 deletions

View File

@@ -1,19 +1,18 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/group_norm.h>
#include <type_traits>
#include <thrust/tuple.h>
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <c10/core/MemoryFormat.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <thrust/tuple.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <type_traits>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -29,8 +28,12 @@ constexpr int kCUDANumThreads = 256;
constexpr int kReduceTileSize = 32;
template <typename T>
__global__ void RowwiseMomentsCUDAKernel(
__global__ void RowwiseMomentsCUDAKernelNHWC(
int64_t N,
int64_t H,
int64_t W,
int64_t C,
int64_t G,
T eps,
const T* X,
T* mean,
@@ -40,11 +43,63 @@ __global__ void RowwiseMomentsCUDAKernel(
using WelfordOp =
WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;
const int64_t channels_per_group = C / G;
const int64_t batch_index = blockIdx.x / G;
const int64_t ng = blockIdx.x % G;
const int64_t batch_offset = batch_index * H * W * C;
const int64_t group_offset = ng * channels_per_group;
const int64_t start = batch_offset + group_offset;
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
WelfordType val(0, 0, 0, 0);
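// NHWC layout: channels are innermost, so spatial position j of this
// (batch, group) pair starts at `start + j * C` and its channels_per_group
// channel values are contiguous.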
for (int64_t j = threadIdx.x; j < H * W; j += blockDim.x) {
for (int64_t c = 0; c < channels_per_group; ++c) {
const int64_t index = start + j * C + c;
val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
}
}
if (blockDim.x <= C10_WARP_SIZE) {
val = cuda_utils::WarpReduce(val, welford_op);
} else {
__shared__ typename std::aligned_storage<
sizeof(WelfordType),
alignof(WelfordType)>::type val_shared[C10_WARP_SIZE];
WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
val = cuda_utils::BlockReduce(
val,
welford_op,
/*identity_element=*/WelfordType(0, 0, 0, 0),
val_shared_ptr);
}
if (threadIdx.x == 0) {
T_ACC m1;
T_ACC m2;
thrust::tie(m2, m1) = welford_op.project(val);
mean[blockIdx.x] = m1;
rstd[blockIdx.x] = c10::cuda::compat::rsqrt(m2 + static_cast<T_ACC>(eps));
}
}
template <typename T>
__global__ void RowwiseMomentsCUDAKernel(
int64_t group_span,
T eps,
const T* X,
T* mean,
T* rstd,
int64_t C) {
using T_ACC = acc_type<T, true>;
using WelfordType = WelfordData<T_ACC, int64_t>;
using WelfordOp =
WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;
const int64_t i = blockIdx.x;
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
WelfordType val(0, 0, 0, 0);
for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
const int64_t index = i * N + j;
for (int64_t j = threadIdx.x; j < group_span; j += blockDim.x) {
const int64_t index = i * group_span + j;
val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
}
if (blockDim.x <= C10_WARP_SIZE) {
@@ -570,20 +625,48 @@ void GroupNormKernelImplInternal(
if (N == 0) {
return;
}
const int64_t G = group;
const int64_t D = C / G;
const T* X_data = X.const_data_ptr<T>();
T* Y_data = Y.mutable_data_ptr<T>();
T* mean_data = mean.mutable_data_ptr<T>();
T* rstd_data = rstd.mutable_data_ptr<T>();
at::MemoryFormat x_format = X.suggest_memory_format();
Y.is_contiguous(x_format);
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
const int64_t num_threads = D * HxW < cuda_utils::kCUDABlockReduceNumThreads
? at::cuda::warp_size()
: cuda_utils::kCUDABlockReduceNumThreads;
RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
D * HxW, eps, X_data, mean_data, rstd_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
int height;
int width;
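// Select the moments kernel that matches the input layout: the contiguous
// (NCHW) kernel reduces a dense span of D * HxW elements per (batch, group),
// while the NHWC kernel strides over the H * W spatial positions.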
switch (x_format) {
case MemoryFormat::Contiguous: {
RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
D * HxW, eps, X_data, mean_data, rstd_data, C);
break;
}
case MemoryFormat::ChannelsLast: {
height = X.size(2);
width = X.size(3);
RowwiseMomentsCUDAKernelNHWC<T><<<N * G, num_threads, 0, cuda_stream>>>(
N, height, width, C, G, eps, X_data, mean_data, rstd_data);
break;
}
default: {
TORCH_CHECK(
false,
"Unsupported memory format for group normalization: ",
x_format);
}
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
if (HxW == 1) {
GroupNorm1dForward<T>(X, mean, rstd, gamma, beta, N, C, G, Y);
} else if (!gamma.defined() && !beta.defined()) {
@@ -594,6 +677,7 @@ void GroupNormKernelImplInternal(
.add_owned_input(mean.view({N * G, 1}))
.add_owned_input(rstd.view({N * G, 1}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
static_cast<T_ACC>(rstd);
@@ -605,6 +689,7 @@ void GroupNormKernelImplInternal(
: X.scalar_type();
Tensor a = at::empty({N, C}, X.options().dtype(kAccType));
Tensor b = at::empty({N, C}, X.options().dtype(kAccType));
const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
T_ACC* a_data = a.mutable_data_ptr<T_ACC>();
@@ -614,22 +699,49 @@ void GroupNormKernelImplInternal(
// using manual kernel here. Make it using gpu_kernel_multiple_outputs once
// the issue fixed.
const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads;
ComputeFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto iter = TensorIteratorConfig()
.check_all_same_dtype(std::is_same_v<T, T_ACC>)
.resize_outputs(false)
.add_owned_output(Y.view({N * C, HxW}))
.add_owned_const_input(X.view({N * C, HxW}))
.add_owned_input(a.view({N * C, 1}))
.add_owned_input(b.view({N * C, 1}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
return a * static_cast<T_ACC>(x) + b;
});
switch (x_format) {
case MemoryFormat::Contiguous: {
TensorIterator iter =
TensorIteratorConfig()
.check_all_same_dtype(std::is_same_v<T, T_ACC>)
.resize_outputs(false)
.add_owned_output(Y.view({N * C, HxW}))
.add_owned_const_input(X.view({N * C, HxW}))
.add_owned_input(a.view({N * C, 1}))
.add_owned_input(b.view({N * C, 1}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
return a * static_cast<T_ACC>(x) + b;
});
break;
}
case MemoryFormat::ChannelsLast: {
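// Channels-last: Y and X keep their NHWC strides, and the per-(n, c) scale
// and shift are broadcast over the spatial dimensions via (N, C, 1, 1) views.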
TensorIterator iter =
TensorIteratorConfig()
.check_all_same_dtype(std::is_same_v<T, T_ACC>)
.resize_outputs(false)
.add_owned_output(Y)
.add_owned_const_input(X)
.add_owned_input(a.view({N, C, 1, 1}))
.add_owned_input(b.view({N, C, 1, 1}))
.build();
gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
return a * static_cast<T_ACC>(x) + b;
});
break;
}
default:
break; // shouldn't hit this
}
}
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@@ -4,7 +4,6 @@
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/accumulate.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -22,7 +21,6 @@
#include <functional>
#include <tuple>
#include <vector>
namespace at::native {
template <typename T>
@@ -68,6 +66,7 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
int64_t HxW,
int64_t group,
double eps) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> gamma_maybe_owned =
at::borrow_from_optional_tensor(gamma_opt);
@@ -77,16 +76,14 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
// repeated check so expanded weights can call native_group_norm directly but
// save mean and variance from forward
check_group_norm_inputs(X, gamma, beta, C, group);
auto memory_format = X.device().is_cpu() ?
X.suggest_memory_format() : at::MemoryFormat::Contiguous;
TORCH_CHECK(X.is_contiguous(memory_format));
auto memory_format = X.suggest_memory_format();
bool mixed_type = is_mixed_type(X, gamma, beta);
if (mixed_type) {
check_mixed_data_type(X, gamma, beta);
}
Tensor Y = at::native::empty_like(
X,
std::nullopt /* dtype */,
@@ -97,6 +94,8 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
const auto dtype = param_scalar_type(X, mixed_type);
Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));
GroupNormKernel(
X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
return std::make_tuple(Y, mean, rstd);
@@ -194,11 +193,9 @@ Tensor group_norm(
const auto input_shape = input.sym_sizes();
const auto HxW =
c10::multiply_integers(input_shape.slice(2));
const Tensor kEmpty;
auto memory_format = input.suggest_memory_format();
const auto& X = input.device().is_cpu() || input.is_privateuseone() ?
input.contiguous(memory_format) : input.contiguous();
const auto& X = input.contiguous(memory_format);
const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C);

View File

@@ -2022,6 +2022,64 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
        gradcheck(fn, (m.weight_orig,))

    def test_groupnorm_nhwc(self):
        def helper(self, size, groups, memory_format, is_mixed, device, dtype):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
            input = input.contiguous(memory_format=memory_format)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=memory_format)
            if dtype == torch.bfloat16 and is_mixed:
                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            gn.weight.data.uniform_()
            gn.bias.data.uniform_()
            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
            if dtype == torch.bfloat16 and is_mixed:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            ref_gn.load_state_dict(gn.state_dict())
            out = gn(input)
            out.backward(grad)
            ref_out = ref_gn(ref_input)
            ref_out.backward(ref_grad)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
            self.assertEqual(out, ref_out)
            # parameters in bfloat16/Half is not recommended
            atol = 5e-4
            rtol = 8e-3
            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

        for device in ['cpu'] + (['cuda'] if TEST_CUDA else []):
            for dtype in [torch.float, torch.double]:
                if device == 'cuda' and dtype not in [torch.float, torch.double]:
                    continue
                for is_mixed in [True, False]:
                    helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)
                    # channels_last_3d is currently not supported for cuda
                    if device == 'cpu':
                        helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)
                        helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)
                        helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)

    @skipIfNoLapack
    def test_spectral_norm_load_state_dict(self):
        inp = torch.randn(2, 3)
@@ -8459,57 +8517,6 @@ class TestNNDeviceType(NNTestCase):
            with torch.backends.cudnn.flags(enabled=False):
                _test_module_empty_input(self, mod, inp)

    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    def test_groupnorm_nhwc(self, device, dtype):
        def helper(self, size, groups, memory_format, is_mixed):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
            input = input.contiguous(memory_format=memory_format)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=memory_format)
            if dtype == torch.bfloat16 and is_mixed:
                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            gn.weight.data.uniform_()
            gn.bias.data.uniform_()
            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
            if dtype == torch.bfloat16 and is_mixed:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            ref_gn.load_state_dict(gn.state_dict())
            out = gn(input)
            out.backward(grad)
            ref_out = ref_gn(ref_input)
            ref_out.backward(ref_grad)
            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
            self.assertEqual(out, ref_out)
            # parameters in bfloat16/Half is not recommended
            atol = 5e-4
            rtol = 8e-3
            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

        for is_mixed in [True, False]:
            helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
            helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
            helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
            helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
            helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)

    @onlyNativeDeviceTypes
    def test_GroupNorm_memory_format(self, device):
        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166