pytorch/test/inductor/test_triton_heuristics.py
Shunting Zhang 95cacb7fa9 [reland][inductor] make thread order consistent with loop order (#107902)
This PR relands https://github.com/pytorch/pytorch/pull/106827, which was reverted because it caused a compilation error for some ads models.

Yanbo provided a repro in one of the 14k models (`pytest ./generated/test_KaiyangZhou_deep_person_reid.py -k test_044`). This is also the model I used to confirm the fix and to write a unit test. In this model, we call `triton_heuristics.triton_config` with size_hints [2048, 2]. Previously this would result in a triton config with XBLOCK=2048 and YBLOCK=2. But since we changed the mapping between size_hints and the XYZ dimensions, we now generate a triton config with XBLOCK=2 and YBLOCK=2048. This fails compilation since the maximum YBLOCK is set to 1024.
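
For context, a minimal illustration of the call described above (the block values shown are the ones quoted in this PR description, not freshly re-run output):

```python
from torch._inductor.triton_heuristics import triton_config

# size_hints taken from the deep_person_reid repro: 2048 elements on the
# first axis and 2 on the second.
cfg = triton_config([2048, 2], 64, 64)
print(cfg.kwargs)
# Before #106827:           XBLOCK=2048, YBLOCK=2
# After the remap (broken): XBLOCK=2,    YBLOCK=2048  -> exceeds the max YBLOCK of 1024
```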

My fix is to make sure we never generate a triton config that exceeds the maximum block size.
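
A minimal sketch of what such a cap could look like (a hypothetical helper for illustration, not the actual diff in this PR; it only relies on `config.triton.max_block`, which the test below also uses):

```python
from torch._inductor import config

def _capped_block(label: str, block: int) -> int:
    # Clamp a proposed block size for dimension "X"/"Y"/"Z" to the maximum
    # allowed by the inductor config, so the generated triton config never
    # exceeds it (e.g. YBLOCK is limited to config.triton.max_block["Y"]).
    return min(block, config.triton.max_block[label])
```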

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107902
Approved by: https://github.com/jansel
2023-08-26 02:56:20 +00:00

# Owner(s): ["module: inductor"]
import sys
import unittest
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

try:
    import triton  # noqa: F401
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")

from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.triton_heuristics import triton_config


class TestTritonHeuristics(TestCase):
    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            if key not in cfg.kwargs:
                continue
            self.assertTrue(cfg.kwargs[key] <= config.triton.max_block[label])


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()