Add a fewer-warps config for inner reductions (#162447)

Use fewer warps for inner reductions to ensure proper vectorization and memory coalescing, preferring more work per thread.

<img width="1717" height="731" alt="Screenshot 2025-09-17 at 10 03 25 AM" src="https://github.com/user-attachments/assets/7b1f4a30-62f2-4bee-bb9c-122501bde63e" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162447
Approved by: https://github.com/v0i0, https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
PaulZhang12 2025-10-08 13:00:40 -07:00 committed by PyTorch MergeBot
parent d386325ca9
commit f05e23e1bc

View File

@ -2399,6 +2399,7 @@ def triton_config_reduction(
num_warps=None,
register_intensive=False,
dynamic_scale_rblock=True,
reduction_hint=None,
) -> Config:
"""
Construct a reduction triton config with some adjustment heuristics
@ -2426,7 +2427,13 @@ def triton_config_reduction(
rnumels[prefix] *= 2
if num_warps is None:
num_warps = total_numel() // 128
if reduction_hint == ReductionHint.INNER and not is_fbcode():
# r is contiguous, so ensure that each thread has 8 elements for
# vectorized loads, assuming bf16/fp16
# xblock is usually 1-2, default to giving each thread more work
num_warps = r // 128
else:
num_warps = total_numel() // 128
max_num_warps = 16 if r <= 8192 else 32
num_warps = _num_warps(
@ -2696,6 +2703,7 @@ def _reduction_configs(
num_stages=num_stages,
register_intensive=register_intensive,
dynamic_scale_rblock=dynamic_scale_rblock,
reduction_hint=reduction_hint,
)
def outer_config_opt():
@ -2747,7 +2755,7 @@ def _reduction_configs(
)
contiguous_config = make_config(
1,
2 if rnumel <= 2048 and not is_fbcode() else 1, # 1024 or less is persistent
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
@ -3085,7 +3093,13 @@ def _persistent_reduction_configs(
if "y" not in size_hints:
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
triton_config_reduction(
size_hints,
xblock,
rnumel,
register_intensive=True,
reduction_hint=reduction_hint,
)
for xblock in (1, 8, 32, 128)
if xblock == 1
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
@ -3125,6 +3139,7 @@ def _persistent_reduction_configs(
x_block,
rnumel,
register_intensive=True,
reduction_hint=reduction_hint,
)
]
@ -3136,6 +3151,7 @@ def _persistent_reduction_configs(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
reduction_hint=reduction_hint,
)
]
for c in configs: