mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
optimize the decomposition of aten.native_group_norm (#144733)
Summary: Optimize the decomposition of aten.native_group_norm. Reduce unnecessary repeated operations by changing the order of operations for `mean`, `rstd`, `weight`, `bias` and `input`, which can improve performance when `flattened_inner_size` is large. The original decomposition: 1. compute `mean` and `rstd`, 2. out = (x - mean) * rstd, computed over the range [N, C, *], 3. out = out * weight + bias, computed over the range [N, C, *]. The new decomposition: 1. compute `mean` and `rstd`, 2. new_weight = rstd * weight, new_bias = -mean * rstd * weight + bias, computed over the range [N, C], 3. out = x * new_weight + new_bias, computed over the range [N, C, *]. I tested the Inductor performance benchmark with this PR on both CPU and A100. On CPU, two torchbench models (functorch_dp_cifar10 and opacus_cifar10) have about 25% performance improvement, and two diffusion models (Stable Diffusion and Latent Consistency Model (LCM)) have about 2% performance improvement. On A100, no performance gains or regressions were seen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144733 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
parent
12112fd198
commit
b533bb4b13
|
|
@ -4065,8 +4065,8 @@ class CPUReproTests(TestCase):
|
|||
compiled_m = torch.compile(mod, dynamic=dynamic)
|
||||
actual, code = run_and_get_cpp_code(compiled_m, x)
|
||||
self.assertEqual(expected, actual)
|
||||
# 2 generated kernels (one for var_mean, the other for result)
|
||||
check_metrics_vec_kernel_count(2)
|
||||
# 3 generated kernels (first one for var_mean, last two for result)
|
||||
check_metrics_vec_kernel_count(3)
|
||||
|
||||
# check loop split optimization
|
||||
if fmt == torch.channels_last:
|
||||
|
|
|
|||
|
|
@ -3187,27 +3187,44 @@ def native_group_norm(
|
|||
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
||||
)
|
||||
|
||||
computation_dtype = utils.get_computation_dtype(input.dtype)
|
||||
input_acc = _maybe_convert_to_dtype(input, computation_dtype)
|
||||
# num_channels / num_groups and flattened inner dimension are the reduction axes
|
||||
reduction_dims = [2, 3]
|
||||
input_reshaped = torch.reshape(
|
||||
input,
|
||||
input_acc,
|
||||
[batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
|
||||
)
|
||||
out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
|
||||
out = out.view(input.shape)
|
||||
|
||||
broadcast_dims = [0] + list(range(2, input.ndim))
|
||||
unsqueeze_bias = None
|
||||
if bias is not None:
|
||||
unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
|
||||
unsqueeze_weight = None
|
||||
reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims)
|
||||
biased_var, mean = torch.var_mean(
|
||||
input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
|
||||
)
|
||||
rstd = torch.rsqrt(biased_var + eps)
|
||||
if weight is not None:
|
||||
unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
|
||||
weight_reshaped = torch.reshape(
|
||||
weight, [1, num_groups, num_channels // num_groups, 1]
|
||||
)
|
||||
w = rstd * weight_reshaped
|
||||
b = -mean * w
|
||||
else:
|
||||
w = rstd
|
||||
b = -mean * rstd
|
||||
if bias is not None:
|
||||
bias_reshaped = torch.reshape(
|
||||
bias, [1, num_groups, num_channels // num_groups, 1]
|
||||
)
|
||||
b = b + bias_reshaped
|
||||
|
||||
if unsqueeze_weight is not None:
|
||||
out = out * unsqueeze_weight
|
||||
if unsqueeze_bias is not None:
|
||||
out = out + unsqueeze_bias
|
||||
if input.device.type == "cpu" and weight is not None:
|
||||
w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
|
||||
b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
|
||||
broadcast_dims = list(range(2, input.ndim))
|
||||
unsqueeze_w = _unsqueeze_multiple(w, broadcast_dims)
|
||||
unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims)
|
||||
out = input_acc * unsqueeze_w + unsqueeze_b
|
||||
else:
|
||||
out = input_reshaped * w + b
|
||||
out = out.view(input.shape)
|
||||
|
||||
out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
|
||||
mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user