optimize the decomposition of aten.native_group_norm (#144733)
Summary: Optimize the decomposition of aten.native_group_norm. Reduce unnecessary repeated operations by reordering the computations on `mean`, `rstd`, `weight`, `bias`, and `input`, which can improve performance when `flattened_inner_size` is large.

The original decomposition:
1. compute `mean` and `rstd`
2. out = (x - mean) * rstd, computed over the range [N, C, *]
3. out = out * weight + bias, computed over the range [N, C, *]

The new decomposition:
1. compute `mean` and `rstd`
2. new_weight = rstd * weight and new_bias = -mean * rstd * weight + bias, computed over the range [N, C]
3. out = x * new_weight + new_bias, computed over the range [N, C, *]

I tested the Inductor performance benchmark with this PR on both CPU and A100. On CPU, two TorchBench models (functorch_dp_cifar10 and opacus_cifar10) show about a 25% performance improvement, and two diffusion models (Stable Diffusion and Latent Consistency Model (LCM)) show about a 2% improvement. On A100, no performance gains or regressions were seen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144733
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
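The equivalence behind the reordering is plain algebra: (x - mean) * rstd * weight + bias == x * (rstd * weight) + (bias - mean * rstd * weight), so the affine fold can be done on the small [N, C]-sized statistics before touching the large [N, C, *] range. A minimal numerical check of this folding (a sketch with made-up shapes; per-channel statistics here for brevity, whereas the real decomposition reduces per group):

    import torch

    # Toy check that folding the affine parameters is exact:
    #   (x - mean) * rstd * weight + bias
    #     == x * (rstd * weight) + (bias - mean * rstd * weight)
    x = torch.randn(2, 4, 8, 8)
    mean = x.mean(dim=(2, 3), keepdim=True)
    rstd = torch.rsqrt(x.var(dim=(2, 3), unbiased=False, keepdim=True) + 1e-5)
    weight = torch.randn(1, 4, 1, 1)
    bias = torch.randn(1, 4, 1, 1)

    old = (x - mean) * rstd * weight + bias      # two passes over [N, C, *]
    new_w = rstd * weight                        # computed on [N, C]-sized data
    new_b = -mean * rstd * weight + bias         # computed on [N, C]-sized data
    new = x * new_w + new_b                      # one pass over [N, C, *]
    torch.testing.assert_close(old, new)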
@@ -4065,8 +4065,8 @@ class CPUReproTests(TestCase):
                 compiled_m = torch.compile(mod, dynamic=dynamic)
                 actual, code = run_and_get_cpp_code(compiled_m, x)
                 self.assertEqual(expected, actual)
-                # 2 generated kernels (one for var_mean, the other for result)
-                check_metrics_vec_kernel_count(2)
+                # 3 generated kernels (first one for var_mean, last two for result)
+                check_metrics_vec_kernel_count(3)
 
                 # check loop split optimization
                 if fmt == torch.channels_last:
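For context, check_metrics_vec_kernel_count asserts on Inductor's CPU-backend counter of generated vectorized C++ kernels. A minimal sketch of observing that counter outside the test harness, assuming the internal torch._inductor.metrics.generated_cpp_vec_kernel_count attribute (internal API, subject to change):

    import torch
    import torch._inductor.metrics as metrics

    # Compile a GroupNorm module on CPU and read the vectorized-kernel
    # counter that the test above asserts on; after this PR the test
    # expects 3 kernels instead of 2.
    mod = torch.nn.GroupNorm(32, 32).eval()
    x = torch.randn(2, 32, 16, 16)

    metrics.reset()
    with torch.no_grad():
        torch.compile(mod)(x)
    print(metrics.generated_cpp_vec_kernel_count)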
@@ -3187,27 +3187,44 @@ def native_group_norm(
         + f"but got input of shape {input.shape} and num_groups = {num_groups}",
     )
 
+    computation_dtype = utils.get_computation_dtype(input.dtype)
+    input_acc = _maybe_convert_to_dtype(input, computation_dtype)
     # num_channels / num_groups and flattened inner dimension are the reduction axes
     reduction_dims = [2, 3]
     input_reshaped = torch.reshape(
-        input,
+        input_acc,
         [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
     )
-    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
-    out = out.view(input.shape)
-
-    broadcast_dims = [0] + list(range(2, input.ndim))
-    unsqueeze_bias = None
-    if bias is not None:
-        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
-    unsqueeze_weight = None
+    reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims)
+    biased_var, mean = torch.var_mean(
+        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    )
+    rstd = torch.rsqrt(biased_var + eps)
     if weight is not None:
-        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
+        weight_reshaped = torch.reshape(
+            weight, [1, num_groups, num_channels // num_groups, 1]
+        )
+        w = rstd * weight_reshaped
+        b = -mean * w
+    else:
+        w = rstd
+        b = -mean * rstd
+    if bias is not None:
+        bias_reshaped = torch.reshape(
+            bias, [1, num_groups, num_channels // num_groups, 1]
+        )
+        b = b + bias_reshaped
 
-    if unsqueeze_weight is not None:
-        out = out * unsqueeze_weight
-    if unsqueeze_bias is not None:
-        out = out + unsqueeze_bias
+    if input.device.type == "cpu" and weight is not None:
+        w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        broadcast_dims = list(range(2, input.ndim))
+        unsqueeze_w = _unsqueeze_multiple(w, broadcast_dims)
+        unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims)
+        out = input_acc * unsqueeze_w + unsqueeze_b
+    else:
+        out = input_reshaped * w + b
+        out = out.view(input.shape)
 
     out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
     mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
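A self-contained sketch of the rewritten schedule, checked against torch.nn.functional.group_norm (toy shapes; it omits the mixed-precision handling and the CPU-specific strided path of the real decomposition):

    import torch
    import torch.nn.functional as F

    def group_norm_decomp(x, num_groups, weight, bias, eps=1e-5):
        N, C = x.shape[0], x.shape[1]
        inner = x.numel() // (N * C)  # flattened_inner_size
        xr = x.reshape(N, num_groups, C // num_groups, inner)
        var, mean = torch.var_mean(xr, dim=[2, 3], unbiased=False, keepdim=True)
        rstd = torch.rsqrt(var + eps)
        # Fold weight/bias into per-(N, C) affine parameters first ...
        w = rstd * weight.reshape(1, num_groups, C // num_groups, 1)
        b = -mean * w + bias.reshape(1, num_groups, C // num_groups, 1)
        # ... then touch the full [N, C, *] range exactly once.
        return (xr * w + b).reshape(x.shape)

    x = torch.randn(2, 32, 16, 16)
    weight, bias = torch.randn(32), torch.randn(32)
    expected = F.group_norm(x, 8, weight, bias)
    torch.testing.assert_close(group_norm_decomp(x, 8, weight, bias), expected)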