[inductor] ELEMENTS_PER_WARP_32 -> ONE_ELEMENT_PER_THREAD (#136472)

AMD devices have 64 threads per warp; this PR makes the handling of "ELEMENTS_PER_WARP_32" generic by using DeviceProperties.warp_size to determine the warp size instead of hard-coding it as 32, and renames the enum value to ONE_ELEMENT_PER_THREAD. A unit test is added for this.
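As a rough illustration (not part of the PR; this just mirrors the new unit test and assumes a CUDA-capable build where DeviceProperties.create succeeds), the hint is now resolved against the device's reported warp size:

import torch
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._inductor.runtime.triton_heuristics import autotune_hints_to_configs

# warp_size comes from the device: 32 on NVIDIA warps, 64 on AMD wavefronts.
device_props = DeviceProperties.create(torch.device("cuda"))

# The bucketize codegen path attaches this hint; each resulting config now uses
# num_elements_per_warp=device_props.warp_size instead of a fixed 32.
configs = autotune_hints_to_configs(
    {AutotuneHint.ONE_ELEMENT_PER_THREAD},
    size_hints=(1024,),
    block_size=256,
    device_props=device_props,
)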

Note: I left the old enum option (ELEMENTS_PER_WARP_32) as is instead of renaming it. I'm not sure whether we should expect caches to get invalidated here; if that concern is valid, there's a risk that the enum gets updated while some model still uses cached inductor code that references "ELEMENTS_PER_WARP_32", which would no longer exist.
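To make that concern concrete, here is a hypothetical, self-contained sketch (the AutotuneHint class below is a stand-in, not the real torch._inductor.runtime.hints enum) of how cached generated code that spells out the old member name would fail to load once only the new name exists:

from enum import Enum

class AutotuneHint(Enum):  # stand-in for torch._inductor.runtime.hints.AutotuneHint
    ONE_ELEMENT_PER_THREAD = 0  # old ELEMENTS_PER_WARP_32 member no longer defined

# Cached inductor output generated before the rename still references the old name.
stale_generated_code = "hints = {AutotuneHint.ELEMENTS_PER_WARP_32}"
try:
    exec(stale_generated_code, {"AutotuneHint": AutotuneHint})
except AttributeError as err:
    print(err)  # the cached code fails to resolve the removed enum member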

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136472
Approved by: https://github.com/jansel
Author: David Berard, 2024-09-23 22:23:30 -07:00; committed by PyTorch MergeBot
Parent: a259fbf72c
Commit: 9c2c61d2dd
4 changed files with 49 additions and 8 deletions

test/inductor/test_triton_heuristics.py

@@ -18,12 +18,17 @@ except ImportError:
 from torch._inductor import config
 from torch._inductor.runtime.hints import (
     AutotuneHint,
+    DeviceProperties,
     HeuristicType,
     TRITON_MAX_BLOCK,
 )
 from torch._inductor.runtime.triton_helpers import math as tl_math
-from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
+from torch._inductor.runtime.triton_heuristics import (
+    autotune_hints_to_configs,
+    CachingAutotuner,
+    triton_config,
+)
 from torch._inductor.test_case import run_tests, TestCase
@@ -140,6 +145,36 @@ class TestTritonHeuristics(TestCase):
         with self.assertRaisesRegex(AssertionError, "pre_hook"):
             autotuner = CachingAutotuner(**args)
 
+    def test_autotune_hints_to_configs(self):
+        device_props = DeviceProperties.create(torch.device("cuda"))
+        device_props = device_props._replace(warp_size=8)
+
+        hints = {AutotuneHint.ONE_ELEMENT_PER_THREAD}
+        size_hints = (1024,)
+        block_size = 256
+
+        seen_num_elements_per_warp = set()
+
+        def mock_triton_config(
+            size_hints,
+            x,
+            y=None,
+            z=None,
+            num_stages=None,
+            num_elements_per_warp=None,
+            min_elem_per_thread=None,
+        ):
+            seen_num_elements_per_warp.add(num_elements_per_warp)
+            return None
+
+        with unittest.mock.patch(
+            "torch._inductor.runtime.triton_heuristics.triton_config",
+            mock_triton_config,
+        ):
+            _ = autotune_hints_to_configs(hints, size_hints, block_size, device_props)
+
+        self.assertTrue(8 in seen_num_elements_per_warp)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:

torch/_inductor/codegen/triton.py

@@ -1877,8 +1877,8 @@ class TritonKernel(SIMDKernel):
         # Triton performance for bucketize_binary_search is much better when the number
         # of threads equals the number of elements.
         # If we're trying to use a bucketize kernel, we should make sure that an
-        # autotuning config with num_elements_per_warp=32 exists.
-        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)
+        # autotuning config with num_elements_per_warp=(warp_size) exists.
+        self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD)
         offsets_ptr = self.args.input(offsets_name)
         block_size = self.dense_size_str()

torch/_inductor/runtime/hints.py

@@ -85,7 +85,7 @@ class HeuristicType(Enum):
 class AutotuneHint(Enum):
-    ELEMENTS_PER_WARP_32 = 0
+    ONE_ELEMENT_PER_THREAD = 0
 
     # Triton codegen tries to codegen set of AutotuneHints.
     # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""

torch/_inductor/runtime/triton_heuristics.py

@@ -85,7 +85,10 @@ log = logging.getLogger(__name__)
 def autotune_hints_to_configs(
-    hints: Set[AutotuneHint], size_hints, block_size: int
+    hints: Set[AutotuneHint],
+    size_hints,
+    block_size: int,
+    device_props: DeviceProperties,
 ) -> List[Config]:
     """
     AutotuneHints can be attached to the metadata of triton kernels for providing
@@ -100,7 +103,7 @@
     configs = []
     for hint in hints:
-        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
+        if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
             if len(size_hints) == 1:
                 xyz_options = ((block_size // 4, None, None),)
             elif len(size_hints) == 2:
@@ -116,7 +119,7 @@
                 triton_config(
                     size_hints,
                     *xyz,
-                    num_elements_per_warp=32,
+                    num_elements_per_warp=device_props.warp_size,
                 )
             )
@@ -1340,7 +1343,10 @@ def pointwise(
     bs = max(256, min(numel // 128, 1024))
     hinted_configs = autotune_hints_to_configs(
-        inductor_meta.get("autotune_hints", set()), size_hints, bs
+        inductor_meta.get("autotune_hints", set()),
+        size_hints,
+        bs,
+        triton_meta["device"],
     )
 
     triton_config_with_settings = functools.partial(