Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add NHWC support for group normalization (#126635)
Fixes #111824. Currently, if a user asks for their group normalization to run in NHWC (channels_last) format, PyTorch silently defaults to NCHW tensors and converts. This conversion is not obvious to the user unless they check the memory format themselves, which is not intuitive. This PR adds NHWC support for group normalization on CUDA by adding the necessary kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126635 Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
This commit is contained in:
parent 59ec011855
commit ed0e63e938
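For orientation, the user-facing behavior this change enables can be sketched with the public PyTorch API (an illustrative example, not part of the diff; it assumes a CUDA device is available):

import torch
import torch.nn as nn

# NHWC (channels_last) input: 2 images, 32 channels, normalized in 4 groups.
x = torch.randn(2, 32, 16, 16, device='cuda').to(memory_format=torch.channels_last)
gn = nn.GroupNorm(num_groups=4, num_channels=32).cuda()

y = gn(x)
# With NHWC CUDA kernels, the output is expected to keep the channels_last
# layout instead of being silently converted through NCHW.
print(y.is_contiguous(memory_format=torch.channels_last))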
@@ -1,19 +1,18 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/group_norm.h>

#include <type_traits>

#include <thrust/tuple.h>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <c10/core/MemoryFormat.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <thrust/tuple.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -29,8 +28,12 @@ constexpr int kCUDANumThreads = 256;
constexpr int kReduceTileSize = 32;

template <typename T>
__global__ void RowwiseMomentsCUDAKernel(
__global__ void RowwiseMomentsCUDAKernelNHWC(
    int64_t N,
    int64_t H,
    int64_t W,
    int64_t C,
    int64_t G,
    T eps,
    const T* X,
    T* mean,
@@ -40,11 +43,63 @@ __global__ void RowwiseMomentsCUDAKernel(
  using WelfordOp =
      WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;

  const int64_t channels_per_group = C / G;
  const int64_t batch_index = blockIdx.x / G;
  const int64_t ng = blockIdx.x % G;
  const int64_t batch_offset = batch_index * H * W * C;
  const int64_t group_offset = ng * channels_per_group;
  const int64_t start = batch_offset + group_offset;

  WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
  WelfordType val(0, 0, 0, 0);
  for (int64_t j = threadIdx.x; j < H * W; j += blockDim.x) {
    for (int64_t c = 0; c < channels_per_group; ++c) {
      const int64_t index = start + j * C + c;
      val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
    }
  }

  if (blockDim.x <= C10_WARP_SIZE) {
    val = cuda_utils::WarpReduce(val, welford_op);
  } else {
    __shared__ typename std::aligned_storage<
        sizeof(WelfordType),
        alignof(WelfordType)>::type val_shared[C10_WARP_SIZE];
    WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);
    val = cuda_utils::BlockReduce(
        val,
        welford_op,
        /*identity_element=*/WelfordType(0, 0, 0, 0),
        val_shared_ptr);
  }

  if (threadIdx.x == 0) {
    T_ACC m1;
    T_ACC m2;
    thrust::tie(m2, m1) = welford_op.project(val);
    mean[blockIdx.x] = m1;
    rstd[blockIdx.x] = c10::cuda::compat::rsqrt(m2 + static_cast<T_ACC>(eps));
  }
}

template <typename T>
__global__ void RowwiseMomentsCUDAKernel(
    int64_t group_span,
    T eps,
    const T* X,
    T* mean,
    T* rstd,
    int64_t C) {
  using T_ACC = acc_type<T, true>;
  using WelfordType = WelfordData<T_ACC, int64_t>;
  using WelfordOp =
      WelfordOps<T_ACC, T_ACC, int64_t, thrust::pair<T_ACC, T_ACC>>;

  const int64_t i = blockIdx.x;
  WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
  WelfordType val(0, 0, 0, 0);
  for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
    const int64_t index = i * N + j;
  for (int64_t j = threadIdx.x; j < group_span; j += blockDim.x) {
    const int64_t index = i * group_span + j;
    val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
  }
  if (blockDim.x <= C10_WARP_SIZE) {
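The NHWC moments kernel above assigns one CUDA block per (batch, group) pair and walks the spatial positions, visiting the group's channels at index = batch_offset + j * C + c. The statistics it produces correspond to this plain PyTorch sketch (an illustration of the math only, with a hypothetical helper name; not the kernel itself):

import torch

def group_moments(x, G, eps):
    # x: logical (N, C, H, W) tensor; storage may be NCHW or channels_last,
    # the per-group statistics are identical either way.
    N, C, H, W = x.shape
    v = x.reshape(N, G, -1).double()
    mean = v.mean(dim=2)                    # one value per (batch, group)
    var = v.var(dim=2, unbiased=False)      # population variance, i.e. correction=0
    rstd = (var + eps).rsqrt()              # matches rsqrt(m2 + eps) in the kernel
    return mean, rstd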
@@ -570,20 +625,48 @@ void GroupNormKernelImplInternal(
  if (N == 0) {
    return;
  }

  const int64_t G = group;
  const int64_t D = C / G;
  const T* X_data = X.const_data_ptr<T>();
  T* Y_data = Y.mutable_data_ptr<T>();
  T* mean_data = mean.mutable_data_ptr<T>();
  T* rstd_data = rstd.mutable_data_ptr<T>();

  at::MemoryFormat x_format = X.suggest_memory_format();
  Y.is_contiguous(x_format);

  cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
  const int64_t num_threads = D * HxW < cuda_utils::kCUDABlockReduceNumThreads
      ? at::cuda::warp_size()
      : cuda_utils::kCUDABlockReduceNumThreads;
  RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
      D * HxW, eps, X_data, mean_data, rstd_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  int height;
  int width;

  switch (x_format) {
    case MemoryFormat::Contiguous: {
      RowwiseMomentsCUDAKernel<T><<<N * G, num_threads, 0, cuda_stream>>>(
          D * HxW, eps, X_data, mean_data, rstd_data, C);
      break;
    }
    case MemoryFormat::ChannelsLast: {
      height = X.size(2);
      width = X.size(3);

      RowwiseMomentsCUDAKernelNHWC<T><<<N * G, num_threads, 0, cuda_stream>>>(
          N, height, width, C, G, eps, X_data, mean_data, rstd_data);

      break;
    }
    default: {
      TORCH_CHECK(
          false,
          "Unsupported memory format for group normalization: ",
          x_format);
    }
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  if (HxW == 1) {
    GroupNorm1dForward<T>(X, mean, rstd, gamma, beta, N, C, G, Y);
  } else if (!gamma.defined() && !beta.defined()) {
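The launch configuration in this hunk is one block per (batch, group), with the thread count chosen from the per-group reduction size D * HxW. A small arithmetic sketch (the constants are assumptions for illustration; the real values come from at::cuda::warp_size() and cuda_utils::kCUDABlockReduceNumThreads):

# Hypothetical sizes for illustration.
N, C, HxW, G = 4, 32, 16 * 16, 4          # batch, channels, spatial elements, groups
D = C // G                                # channels per group
warp_size = 32                            # assumed NVIDIA warp size
block_reduce_threads = 512                # assumed kCUDABlockReduceNumThreads

num_blocks = N * G                        # one block per (batch, group)
num_threads = warp_size if D * HxW < block_reduce_threads else block_reduce_threads
print(num_blocks, num_threads)            # 16 512 for these sizes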
@@ -594,6 +677,7 @@ void GroupNormKernelImplInternal(
        .add_owned_input(mean.view({N * G, 1}))
        .add_owned_input(rstd.view({N * G, 1}))
        .build();

    gpu_kernel(iter, [] GPU_LAMBDA(T x, T mean, T rstd) -> T {
      return (static_cast<T_ACC>(x) - static_cast<T_ACC>(mean)) *
          static_cast<T_ACC>(rstd);
@@ -605,6 +689,7 @@ void GroupNormKernelImplInternal(
        : X.scalar_type();
    Tensor a = at::empty({N, C}, X.options().dtype(kAccType));
    Tensor b = at::empty({N, C}, X.options().dtype(kAccType));

    const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
    const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
    T_ACC* a_data = a.mutable_data_ptr<T_ACC>();
@@ -614,22 +699,49 @@ void GroupNormKernelImplInternal(
    // using manual kernel here. Make it using gpu_kernel_multiple_outputs once
    // the issue fixed.
    const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads;

    ComputeFusedParamsCUDAKernel<T><<<B, kCUDANumThreads, 0, cuda_stream>>>(
        N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    auto iter = TensorIteratorConfig()
        .check_all_same_dtype(std::is_same_v<T, T_ACC>)
        .resize_outputs(false)
        .add_owned_output(Y.view({N * C, HxW}))
        .add_owned_const_input(X.view({N * C, HxW}))
        .add_owned_input(a.view({N * C, 1}))
        .add_owned_input(b.view({N * C, 1}))
        .build();
    gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
      return a * static_cast<T_ACC>(x) + b;
    });
    switch (x_format) {
      case MemoryFormat::Contiguous: {
        TensorIterator iter =
            TensorIteratorConfig()
                .check_all_same_dtype(std::is_same_v<T, T_ACC>)
                .resize_outputs(false)
                .add_owned_output(Y.view({N * C, HxW}))
                .add_owned_const_input(X.view({N * C, HxW}))
                .add_owned_input(a.view({N * C, 1}))
                .add_owned_input(b.view({N * C, 1}))
                .build();
        gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
          return a * static_cast<T_ACC>(x) + b;
        });

        break;
      }
      case MemoryFormat::ChannelsLast: {
        TensorIterator iter =
            TensorIteratorConfig()
                .check_all_same_dtype(std::is_same_v<T, T_ACC>)
                .resize_outputs(false)
                .add_owned_output(Y)
                .add_owned_const_input(X)
                .add_owned_input(a.view({N, C, 1, 1}))
                .add_owned_input(b.view({N, C, 1, 1}))
                .build();
        gpu_kernel(iter, [] GPU_LAMBDA(T x, T_ACC a, T_ACC b) -> T {
          return a * static_cast<T_ACC>(x) + b;
        });

        break;
      }
      default:
        break; // shouldn't hit this
    }
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
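Both TensorIterator branches above apply the same fused affine transform y = a * x + b, where ComputeFusedParamsCUDAKernel folds the group statistics and the optional gamma/beta into per-(n, c) coefficients; the channels_last branch simply broadcasts a and b as (N, C, 1, 1) views over the NHWC tensors. A rough Python equivalent of that folding (an illustration assuming gamma and beta are defined; in the kernel they act as 1 and 0 when absent):

import torch

def fused_group_norm(x, G, gamma, beta, eps):
    # x: (N, C, H, W); gamma, beta: (C,)
    N, C, _, _ = x.shape
    v = x.reshape(N, G, -1)
    mean = v.mean(dim=2)                                  # (N, G)
    rstd = (v.var(dim=2, unbiased=False) + eps).rsqrt()   # (N, G)
    # Expand the per-group statistics to per-channel, then fold in gamma/beta.
    mean_c = mean.repeat_interleave(C // G, dim=1)        # (N, C)
    a = rstd.repeat_interleave(C // G, dim=1) * gamma     # a = rstd * gamma
    b = beta - mean_c * a                                 # b = beta - mean * rstd * gamma
    return a.view(N, C, 1, 1) * x + b.view(N, C, 1, 1)

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)
gamma, beta = torch.rand(8), torch.rand(8)
ref = torch.nn.functional.group_norm(x, 4, gamma, beta, eps=1e-5)
print(torch.allclose(fused_group_norm(x, 4, gamma, beta, 1e-5), ref, atol=1e-4))  # True, up to rounding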
@@ -4,7 +4,6 @@
#include <ATen/Parallel.h>
#include <ATen/native/cpu/mixed_data_type.h>
#include <c10/util/accumulate.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -22,7 +21,6 @@
#include <functional>
#include <tuple>
#include <vector>

namespace at::native {

template <typename T>
@@ -68,6 +66,7 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
    int64_t HxW,
    int64_t group,
    double eps) {

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> gamma_maybe_owned =
      at::borrow_from_optional_tensor(gamma_opt);
@@ -77,16 +76,14 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
  // repeated check so expanded weights can call native_group_norm directly but
  // save mean and variance from forward
  check_group_norm_inputs(X, gamma, beta, C, group);
  auto memory_format = X.device().is_cpu() ?
      X.suggest_memory_format() : at::MemoryFormat::Contiguous;

  TORCH_CHECK(X.is_contiguous(memory_format));
  auto memory_format = X.suggest_memory_format();

  bool mixed_type = is_mixed_type(X, gamma, beta);
  if (mixed_type) {
    check_mixed_data_type(X, gamma, beta);
  }


  Tensor Y = at::native::empty_like(
      X,
      std::nullopt /* dtype */,
@@ -97,6 +94,8 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
  const auto dtype = param_scalar_type(X, mixed_type);
  Tensor mean = at::empty({N, group}, X.options().dtype(dtype));
  Tensor rstd = at::empty({N, group}, X.options().dtype(dtype));


  GroupNormKernel(
      X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd);
  return std::make_tuple(Y, mean, rstd);
@@ -194,11 +193,9 @@ Tensor group_norm(
  const auto input_shape = input.sym_sizes();
  const auto HxW =
      c10::multiply_integers(input_shape.slice(2));

  const Tensor kEmpty;
  auto memory_format = input.suggest_memory_format();
  const auto& X = input.device().is_cpu() || input.is_privateuseone() ?
      input.contiguous(memory_format) : input.contiguous();
  const auto& X = input.contiguous(memory_format);
  const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty;
  const auto& beta = bias.defined() ? bias.contiguous() : kEmpty;
  TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C);
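In the group_norm wrapper above, the input is now made contiguous in its suggested memory format on every device, instead of only on CPU/PrivateUse1, so a channels_last input keeps its NHWC strides on CUDA as well. The memory-format behavior relied on here looks like this from Python (a small illustration, not code from the diff):

import torch

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)

# contiguous() with an explicit memory_format preserves NHWC striding...
nhwc = x.contiguous(memory_format=torch.channels_last)
print(nhwc.stride())                                          # (128, 1, 32, 8)
print(nhwc.is_contiguous(memory_format=torch.channels_last))  # True

# ...while a plain contiguous() call repacks the data into default NCHW order.
nchw = x.contiguous()
print(nchw.stride())                                          # (128, 16, 4, 1)
print(nchw.is_contiguous())                                   # True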
test/test_nn.py (109 changed lines)
@@ -2022,6 +2022,64 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")

        gradcheck(fn, (m.weight_orig,))

    def test_groupnorm_nhwc(self):
        def helper(self, size, groups, memory_format, is_mixed, device, dtype):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
            input = input.contiguous(memory_format=memory_format)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=memory_format)
            if dtype == torch.bfloat16 and is_mixed:
                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            gn.weight.data.uniform_()
            gn.bias.data.uniform_()

            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
            if dtype == torch.bfloat16 and is_mixed:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            ref_gn.load_state_dict(gn.state_dict())
            out = gn(input)
            out.backward(grad)
            ref_out = ref_gn(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            print(f'{memory_format}')
            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))

            self.assertEqual(out, ref_out)
            # parameters in bfloat16/Half is not recommended
            atol = 5e-4
            rtol = 8e-3

            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

        for device in ['cpu'] + (['cuda'] if TEST_CUDA else []):
            for dtype in [torch.float, torch.double]:
                if device == 'cuda' and dtype not in [torch.float, torch.double]:
                    continue
                for is_mixed in [True, False]:
                    helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)
                    helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)

                    # channels_last_3d is currently not supported for cuda
                    if device == 'cpu':
                        helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)
                        helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)
                        helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed, device, dtype)

    @skipIfNoLapack
    def test_spectral_norm_load_state_dict(self):
        inp = torch.randn(2, 3)
@@ -8459,57 +8517,6 @@ class TestNNDeviceType(NNTestCase):
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, mod, inp)

    @onlyCPU
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    def test_groupnorm_nhwc(self, device, dtype):
        def helper(self, size, groups, memory_format, is_mixed):
            channels = size[1]
            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
            input = input.contiguous(memory_format=memory_format)
            input.retain_grad()
            grad = torch.randn(size, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=memory_format)
            if dtype == torch.bfloat16 and is_mixed:
                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            gn.weight.data.uniform_()
            gn.bias.data.uniform_()

            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
            if dtype == torch.bfloat16 and is_mixed:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
            else:
                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
            ref_gn.load_state_dict(gn.state_dict())
            out = gn(input)
            out.backward(grad)
            ref_out = ref_gn(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
            self.assertEqual(out, ref_out)
            # parameters in bfloat16/Half is not recommended
            atol = 5e-4
            rtol = 8e-3

            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)

        for is_mixed in [True, False]:
            helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
            helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
            helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
            helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
            helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
            helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
            helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)

    @onlyNativeDeviceTypes
    def test_GroupNorm_memory_format(self, device):
        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166