[inductor] ELEMENTS_PER_WARP_32 -> ONE_ELEMENT_PER_THREAD (#136472)
AMD devices have 64 threads per warp. This PR makes the handling of "ELEMENTS_PER_WARP_32" generic: it uses DeviceProperties.warp_size to determine the warp size instead of hard-coding 32, and renames the enum value accordingly. A unit test is added for this.

Note: should the old enum option (ELEMENTS_PER_WARP_32) have been left as-is instead of renamed? I'm not sure whether we should expect caches to get invalidated here; if they are not, some model could use cached inductor code that still references "ELEMENTS_PER_WARP_32", which would no longer exist.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136472
Approved by: https://github.com/jansel
This commit is contained in:
parent a259fbf72c
commit 9c2c61d2dd
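For context, a rough sketch (not code from this commit) of why the old name only matched the new intent on NVIDIA hardware: with a warp size of 32, "32 elements per warp" and "one element per thread" coincide, but with AMD's warp size of 64 they do not. The helper below is hypothetical and only illustrates the arithmetic.

def num_warps_for_one_element_per_thread(num_elements: int, warp_size: int) -> int:
    # One element per thread means num_elements_per_warp == warp_size,
    # so the kernel wants roughly num_elements / warp_size warps.
    return max(1, num_elements // warp_size)

# NVIDIA-style warp size: "32 elements per warp" is the same as "one element per thread"
assert num_warps_for_one_element_per_thread(1024, warp_size=32) == 32
# AMD-style warp size (64): hard-coding 32 would request twice as many warps as needed
assert num_warps_for_one_element_per_thread(1024, warp_size=64) == 16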
@@ -18,12 +18,17 @@ except ImportError:
 from torch._inductor import config
 from torch._inductor.runtime.hints import (
+    AutotuneHint,
     DeviceProperties,
     HeuristicType,
     TRITON_MAX_BLOCK,
 )
 from torch._inductor.runtime.triton_helpers import math as tl_math
-from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
+from torch._inductor.runtime.triton_heuristics import (
+    autotune_hints_to_configs,
+    CachingAutotuner,
+    triton_config,
+)
 from torch._inductor.test_case import run_tests, TestCase
@@ -140,6 +145,36 @@ class TestTritonHeuristics(TestCase):
         with self.assertRaisesRegex(AssertionError, "pre_hook"):
             autotuner = CachingAutotuner(**args)
 
+    def test_autotune_hints_to_configs(self):
+        device_props = DeviceProperties.create(torch.device("cuda"))
+        device_props = device_props._replace(warp_size=8)
+
+        hints = {AutotuneHint.ONE_ELEMENT_PER_THREAD}
+        size_hints = (1024,)
+        block_size = 256
+
+        seen_num_elements_per_warp = set()
+
+        def mock_triton_config(
+            size_hints,
+            x,
+            y=None,
+            z=None,
+            num_stages=None,
+            num_elements_per_warp=None,
+            min_elem_per_thread=None,
+        ):
+            seen_num_elements_per_warp.add(num_elements_per_warp)
+            return None
+
+        with unittest.mock.patch(
+            "torch._inductor.runtime.triton_heuristics.triton_config",
+            mock_triton_config,
+        ):
+            _ = autotune_hints_to_configs(hints, size_hints, block_size, device_props)
+
+        self.assertTrue(8 in seen_num_elements_per_warp)
+
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
@@ -1877,8 +1877,8 @@ class TritonKernel(SIMDKernel):
         # Triton performance for bucketize_binary_search is much better when the number
         # of threads equals the number of elements.
         # If we're trying to use a bucketize kernel, we should make sure that an
-        # autotuning config with num_elements_per_warp=32 exists.
-        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)
+        # autotuning config with num_elements_per_warp=(warp_size) exists.
+        self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD)
 
         offsets_ptr = self.args.input(offsets_name)
         block_size = self.dense_size_str()
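The comment in the hunk above says bucketize_binary_search performs best when the number of threads equals the number of elements. A small illustrative sketch (hypothetical helper, not Inductor's actual heuristic) of why num_elements_per_warp should equal the warp size for that to hold:

def approx_num_threads(block_elems: int, num_elements_per_warp: int, warp_size: int) -> int:
    # Each warp covers num_elements_per_warp elements and contributes warp_size threads.
    num_warps = max(1, block_elems // num_elements_per_warp)
    return num_warps * warp_size

block_elems = 256
for warp_size in (32, 64):
    # threads == elements only when num_elements_per_warp matches the hardware warp size
    assert approx_num_threads(block_elems, warp_size, warp_size) == block_elems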
@@ -85,7 +85,7 @@ class HeuristicType(Enum):
 
 
 class AutotuneHint(Enum):
-    ELEMENTS_PER_WARP_32 = 0
+    ONE_ELEMENT_PER_THREAD = 0
 
     # Triton codegen tries to codegen set of AutotuneHints.
     # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>""
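The truncated context above suggests why this enum cares about its repr: Inductor writes the set of hints into generated source, and the default Enum repr is not valid Python, while str() is (assuming the enum is importable from the generated code). A standalone sketch with a stand-in enum, not the actual AutotuneHint definition:

from enum import Enum

class Hint(Enum):  # stand-in for AutotuneHint, purely illustrative
    ONE_ELEMENT_PER_THREAD = 0

print(repr(Hint.ONE_ELEMENT_PER_THREAD))  # <Hint.ONE_ELEMENT_PER_THREAD: 0>  -> not valid Python source
print(str(Hint.ONE_ELEMENT_PER_THREAD))   # Hint.ONE_ELEMENT_PER_THREAD       -> usable in generated code

# Hence the pattern of aliasing __repr__ to Enum.__str__ for codegen-facing enums.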
@@ -85,7 +85,10 @@ log = logging.getLogger(__name__)
 
 
 def autotune_hints_to_configs(
-    hints: Set[AutotuneHint], size_hints, block_size: int
+    hints: Set[AutotuneHint],
+    size_hints,
+    block_size: int,
+    device_props: DeviceProperties,
 ) -> List[Config]:
     """
     AutotuneHints can be attached to the metadata of triton kernels for providing
@@ -100,7 +103,7 @@ def autotune_hints_to_configs(
     configs = []
 
     for hint in hints:
-        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
+        if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
             if len(size_hints) == 1:
                 xyz_options = ((block_size // 4, None, None),)
             elif len(size_hints) == 2:
@@ -116,7 +119,7 @@ def autotune_hints_to_configs(
                 triton_config(
                     size_hints,
                     *xyz,
-                    num_elements_per_warp=32,
+                    num_elements_per_warp=device_props.warp_size,
                 )
             )
 
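To make the expansion above concrete, here is a hand-written approximation (not the actual torch._inductor.runtime.triton_heuristics code) of how the 1D branch turns the hint into one extra autotuning config with a warp-size-aware num_elements_per_warp:

from typing import List, Optional, Tuple

def expand_one_element_per_thread_1d(block_size: int, warp_size: int) -> List[dict]:
    # Mirrors the 1D branch shown in the diff: shrink the x block by 4 and request
    # a config whose num_elements_per_warp equals the hardware warp size.
    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...] = (
        (block_size // 4, None, None),
    )
    return [{"xyz": xyz, "num_elements_per_warp": warp_size} for xyz in xyz_options]

print(expand_one_element_per_thread_1d(block_size=256, warp_size=64))
# [{'xyz': (64, None, None), 'num_elements_per_warp': 64}]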
@@ -1340,7 +1343,10 @@ def pointwise(
     bs = max(256, min(numel // 128, 1024))
 
     hinted_configs = autotune_hints_to_configs(
-        inductor_meta.get("autotune_hints", set()), size_hints, bs
+        inductor_meta.get("autotune_hints", set()),
+        size_hints,
+        bs,
+        triton_meta["device"],
     )
 
     triton_config_with_settings = functools.partial(