mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add fewer-warps config to inner reductions (#162447)
Add fewer warps to ensure proper vectorization + memory coalescing for inner reductions; prefer more work per thread. <img width="1717" height="731" alt="Screenshot 2025-09-17 at 10 03 25 AM" src="https://github.com/user-attachments/assets/7b1f4a30-62f2-4bee-bb9c-122501bde63e" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/162447 Approved by: https://github.com/v0i0, https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
parent
d386325ca9
commit
f05e23e1bc
|
|
@ -2399,6 +2399,7 @@ def triton_config_reduction(
|
|||
num_warps=None,
|
||||
register_intensive=False,
|
||||
dynamic_scale_rblock=True,
|
||||
reduction_hint=None,
|
||||
) -> Config:
|
||||
"""
|
||||
Construct a reduction triton config with some adjustment heuristics
|
||||
|
|
@ -2426,7 +2427,13 @@ def triton_config_reduction(
|
|||
rnumels[prefix] *= 2
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = total_numel() // 128
|
||||
if reduction_hint == ReductionHint.INNER and not is_fbcode():
|
||||
# r is contiguous, so ensure that each thread has 8 elements for
|
||||
# vectorized loads, assuming bf16/fp16
|
||||
# xblock is usually 1-2, default to giving each thread more work
|
||||
num_warps = r // 128
|
||||
else:
|
||||
num_warps = total_numel() // 128
|
||||
|
||||
max_num_warps = 16 if r <= 8192 else 32
|
||||
num_warps = _num_warps(
|
||||
|
|
@ -2696,6 +2703,7 @@ def _reduction_configs(
|
|||
num_stages=num_stages,
|
||||
register_intensive=register_intensive,
|
||||
dynamic_scale_rblock=dynamic_scale_rblock,
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
|
||||
def outer_config_opt():
|
||||
|
|
@ -2747,7 +2755,7 @@ def _reduction_configs(
|
|||
)
|
||||
|
||||
contiguous_config = make_config(
|
||||
1,
|
||||
2 if rnumel <= 2048 and not is_fbcode() else 1, # 1024 or less is persistent
|
||||
min(rnumel, MAX_R0_BLOCK),
|
||||
register_intensive=register_intensive,
|
||||
)
|
||||
|
|
@ -3085,7 +3093,13 @@ def _persistent_reduction_configs(
|
|||
|
||||
if "y" not in size_hints:
|
||||
configs = [
|
||||
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
|
||||
triton_config_reduction(
|
||||
size_hints,
|
||||
xblock,
|
||||
rnumel,
|
||||
register_intensive=True,
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
for xblock in (1, 8, 32, 128)
|
||||
if xblock == 1
|
||||
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
|
||||
|
|
@ -3125,6 +3139,7 @@ def _persistent_reduction_configs(
|
|||
x_block,
|
||||
rnumel,
|
||||
register_intensive=True,
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -3136,6 +3151,7 @@ def _persistent_reduction_configs(
|
|||
size_hints,
|
||||
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
||||
rnumel,
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
]
|
||||
for c in configs:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user