pytorch/torch/_refs
Sun, Jiayi b2862f1435 optimize the decomposition of aten.native_group_norm (#144733)
Summary:
Optimize the decomposition of aten.native_group_norm. Reduce unnecessary repeated operations by changing the order of operations for `mean`, `rstd`, `weight`, `bias`, and `input`, which improves performance when `flattened_inner_size` is large.

The original decomposition:
1. compute `mean` and `rstd`,
2. out = (x - mean) * rstd, computed over the full [N, C, *] range,
3. out = out * weight + bias, computed over the full [N, C, *] range.
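A minimal sketch of this original ordering (illustrative only, not the actual torch/_refs implementation; the `group_norm_old` name and the shape handling are assumptions). Note that both step 2 and step 3 walk the full [N, C, *] tensor:

```python
import torch

def group_norm_old(x, weight, bias, N, C, groups, eps=1e-5):
    # x: [N, C, *], flattened to [N, groups, -1]; stats are per (sample, group)
    flattened = x.view(N, groups, -1)
    mean = flattened.mean(dim=-1, keepdim=True)                     # [N, groups, 1]
    rstd = torch.rsqrt(flattened.var(dim=-1, unbiased=False, keepdim=True) + eps)
    # step 2: normalization touches every element of [N, C, *]
    out = ((flattened - mean) * rstd).view(N, C, -1)
    # step 3: a second full-size pass applies the per-channel affine transform
    out = out * weight.view(1, C, 1) + bias.view(1, C, 1)
    return out.view_as(x)
```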

The new decomposition:
1. compute `mean` and `rstd`,
2. new_weight = rstd * weight, new_bias = -mean * rstd * weight + bias, computed over the [N, C] range only,
3. out = x * new_weight + new_bias, computed over the full [N, C, *] range.
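A matching sketch of the new ordering (again illustrative, not the exact decomposition in `__init__.py`): `mean`, `rstd`, `weight` and `bias` are first folded into per-channel `new_weight`/`new_bias` at [N, C] cost, so only a single pass remains over the full [N, C, *] range:

```python
def group_norm_new(x, weight, bias, N, C, groups, eps=1e-5):
    flattened = x.view(N, groups, -1)
    mean = flattened.mean(dim=-1)                                   # [N, groups]
    rstd = torch.rsqrt(flattened.var(dim=-1, unbiased=False) + eps) # [N, groups]
    # broadcast the per-group statistics to per-channel values: [N, C]
    mean_c = mean.repeat_interleave(C // groups, dim=1)
    rstd_c = rstd.repeat_interleave(C // groups, dim=1)
    # fold mean/rstd/weight/bias into one affine transform; this work is only [N, C]
    new_weight = rstd_c * weight                                    # [N, C]
    new_bias = -mean_c * rstd_c * weight + bias                     # [N, C]
    # a single multiply-add over the full [N, C, *] range
    out = x.view(N, C, -1) * new_weight.unsqueeze(-1) + new_bias.unsqueeze(-1)
    return out.view_as(x)
```

Both orderings produce the same result up to floating-point rounding, e.g.:

```python
x = torch.randn(2, 8, 16, 16)
w, b = torch.randn(8), torch.randn(8)
torch.testing.assert_close(group_norm_old(x, w, b, 2, 8, 4),
                           group_norm_new(x, w, b, 2, 8, 4))
```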

I ran the Inductor performance benchmarks with this PR on both CPU and A100. On CPU, two TorchBench models (functorch_dp_cifar10 and opacus_cifar10) show about a 25% performance improvement, and two diffusion models (Stable Diffusion and the Latent Consistency Model (LCM)) show about a 2% performance improvement. On A100, no performance gains or regressions were observed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144733
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
2025-03-17 09:27:01 +00:00
linalg PEP585: More UP006 fixes (#146392) 2025-02-20 06:18:13 +00:00
nn PEP585: More UP006 fixes (#146392) 2025-02-20 06:18:13 +00:00
special [Inductor][CPP] fix torch logit decomposition (#145576) 2025-01-27 19:37:51 +00:00
__init__.py optimize the decomposition of aten.native_group_norm (#144733) 2025-03-17 09:27:01 +00:00
_conversions.py [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765) 2024-07-31 10:42:50 +00:00
fft.py PEP585 update - torch/_C torch/_decomp torch/_lazy torch/_library torch/_numpy torch/_prims torch/_refs torch/_strobelight (#145102) 2025-01-18 20:47:12 +00:00