optimize the decomposition of aten.native_group_norm (#144733)

Summary:
Optimize the decomposition of `aten.native_group_norm`. Reduce unnecessary repeated operations by reordering the computation over `mean`, `rstd`, `weight`, `bias`, and `input`, which improves performance when `flattened_inner_size` is large.

The original decomposition:
1. compute `mean` and `rstd`,
2. `out = (x - mean) * rstd`, computed over the full `[N, C, *]` range,
3. `out = out * weight + bias`, computed over the full `[N, C, *]` range.

The new decomposition (a numeric sketch follows the list):
1. compute `mean` and `rstd`,
2. `new_weight = rstd * weight`, `new_bias = -mean * rstd * weight + bias`, computed over only the `[N, C]` range,
3. `out = x * new_weight + new_bias`, computed over the full `[N, C, *]` range.
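
The two orderings are algebraically identical, since `(x - mean) * rstd * weight + bias = x * (rstd * weight) + (bias - mean * rstd * weight)`; the new order folds `mean`/`rstd` into the affine parameters on the small `[N, C]` range, leaving a single multiply-add over `[N, C, *]`. A minimal standalone sketch of the rewrite (shapes and variable names here are illustrative, not taken from the PR):

```python
import torch

x = torch.randn(2, 6, 8, 8)                       # input of shape [N, C, H, W]
num_groups = 3
xr = x.reshape(2, num_groups, 2, -1)              # [N, G, C // G, H * W]
mean = xr.mean(dim=(2, 3), keepdim=True)
rstd = torch.rsqrt(xr.var(dim=(2, 3), unbiased=False, keepdim=True) + 1e-5)
weight = torch.randn(6).reshape(1, num_groups, 2, 1)
bias = torch.randn(6).reshape(1, num_groups, 2, 1)

# Original order: normalize over [N, C, *], then apply the affine over [N, C, *].
out_old = (xr - mean) * rstd * weight + bias

# New order: fold mean/rstd into the affine parameters first (cheap, [N, C]-sized),
# then do one multiply-add over [N, C, *].
new_weight = rstd * weight
new_bias = -mean * rstd * weight + bias
out_new = xr * new_weight + new_bias

torch.testing.assert_close(out_old, out_new, rtol=1e-5, atol=1e-5)
```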

I tested the Inductor performance benchmark with this PR on both CPU and A100. On CPU, two torchbench models (functorch_dp_cifar10 and opacus_cifar10) improve by about 25%, and two diffusion models (Stable Diffusion and Latent Consistency Model (LCM)) improve by about 2%. On A100, no performance gains or regressions were seen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144733
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
Sun, Jiayi 2025-02-24 23:01:24 -08:00 committed by PyTorch MergeBot
parent 12112fd198
commit b533bb4b13
2 changed files with 33 additions and 16 deletions


@@ -4065,8 +4065,8 @@ class CPUReproTests(TestCase):
             compiled_m = torch.compile(mod, dynamic=dynamic)
             actual, code = run_and_get_cpp_code(compiled_m, x)
             self.assertEqual(expected, actual)
-            # 2 generated kernels (one for var_mean, the other for result)
-            check_metrics_vec_kernel_count(2)
+            # 3 generated kernels (first one for var_mean, last two for result)
+            check_metrics_vec_kernel_count(3)
             # check loop split optimization
             if fmt == torch.channels_last:


@@ -3187,27 +3187,44 @@ def native_group_norm(
         + f"but got input of shape {input.shape} and num_groups = {num_groups}",
     )
+    computation_dtype = utils.get_computation_dtype(input.dtype)
+    input_acc = _maybe_convert_to_dtype(input, computation_dtype)
     # num_channels / num_groups and flattened inner dimension are the reduction axes
     reduction_dims = [2, 3]
     input_reshaped = torch.reshape(
-        input,
+        input_acc,
         [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
     )
-    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
-    out = out.view(input.shape)
-    broadcast_dims = [0] + list(range(2, input.ndim))
-    unsqueeze_bias = None
-    if bias is not None:
-        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
-    unsqueeze_weight = None
+    reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims)
+    biased_var, mean = torch.var_mean(
+        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    )
+    rstd = torch.rsqrt(biased_var + eps)
     if weight is not None:
-        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
+        weight_reshaped = torch.reshape(
+            weight, [1, num_groups, num_channels // num_groups, 1]
+        )
+        w = rstd * weight_reshaped
+        b = -mean * w
+    else:
+        w = rstd
+        b = -mean * rstd
+    if bias is not None:
+        bias_reshaped = torch.reshape(
+            bias, [1, num_groups, num_channels // num_groups, 1]
+        )
+        b = b + bias_reshaped
-    if unsqueeze_weight is not None:
-        out = out * unsqueeze_weight
-    if unsqueeze_bias is not None:
-        out = out + unsqueeze_bias
+    if input.device.type == "cpu" and weight is not None:
+        w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        broadcast_dims = list(range(2, input.ndim))
+        unsqueeze_w = _unsqueeze_multiple(w, broadcast_dims)
+        unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims)
+        out = input_acc * unsqueeze_w + unsqueeze_b
+    else:
+        out = input_reshaped * w + b
+    out = out.view(input.shape)
     out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
     mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
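
For reference, a quick end-to-end sanity check that exercises this decomposition through `torch.compile` (a sketch under assumed defaults: Inductor backend, CPU tensors; the module and shapes are arbitrary):

```python
import torch

gn = torch.nn.GroupNorm(num_groups=8, num_channels=32).eval()
x = torch.randn(4, 32, 56, 56)

with torch.no_grad():
    expected = gn(x)               # eager aten.native_group_norm
    actual = torch.compile(gn)(x)  # decomposed and compiled by Inductor

torch.testing.assert_close(expected, actual, rtol=1e-4, atol=1e-4)
```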