pytorch/test/inductor/test_codegen_triton.py
David Berard becb8dc91a [inductor] triton_utils.config_of: check for divisibility by 16, even when expr is not an Integer (#105743)
TL;DR: triton_utils.config_of determines divisibility by 16 for each kernel input (pointer alignment for pointers, divisibility by 16 for sizes). For sizes, the check could previously only return true when the expr representing the size was an Integer. However, non-integral exprs can also be divisible by 16, e.g. an expr like 16*s0.
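
For illustration, here is a minimal sketch of how divisibility by 16 can be decided for a symbolic size expression with sympy. This is not the actual config_of implementation, and `divisible_by_16` is a hypothetical helper name:

```python
import sympy


def divisible_by_16(expr: sympy.Expr) -> bool:
    # True when `expr` is statically a multiple of 16, even if it is not a
    # plain Integer (e.g. 16*s0 for an integer symbol s0).
    return sympy.gcd(expr, 16) == 16


s0 = sympy.Symbol("s0", positive=True, integer=True)
assert divisible_by_16(sympy.Integer(32))  # concrete multiple of 16
assert divisible_by_16(16 * s0)            # symbolic, but provably divisible
assert not divisible_by_16(8 * s0)         # only guaranteed divisible by 8
```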

Motivation: Knowing that a value is divisible by 16 allows loads and stores to be vectorized, which can improve memory bandwidth. If we have, for example, kernels operating on shape [s0, 16] (dynamic batch size; static, divisible-by-16 other dimensions), we still want to be able to vectorize those loads and stores.

Dashboard results suggest that this improves dynamic-shape training performance for timm, and possibly gives a small improvement for torchbench as well. More details are provided in a comment below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105743
Approved by: https://github.com/ezyang, https://github.com/aakhundov
2023-07-24 22:41:50 +00:00

# Owner(s): ["module: inductor"]
import contextlib
import sympy
import torch
import torch._inductor.config as inductor_config
from torch._inductor.codegen import triton_utils
from torch._inductor.codegen.common import SizeArg
from torch._inductor.graph import GraphLowering
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

class TestCodegenTriton(TorchTestCase):
    def setUp(self):
        super().setUp()

        class DummyModule(torch.nn.Module):
            def forward(self, x):
                return x * 2

        self._gm = torch.fx.symbolic_trace(DummyModule())
        self._graph = GraphLowering(self._gm)

        self._stack = contextlib.ExitStack()
        self._stack.enter_context(V.set_graph_handler(self._graph))

    def tearDown(self):
        self._stack.close()
        super().tearDown()

    @inductor_config.patch("triton.divisible_by_16", True)
    def test_config_of_sizearg(self):
        two = sympy.Integer(2)
        eight = sympy.Integer(8)
        sixteen = sympy.Integer(16)
        s0 = sympy.Symbol("s0", positive=True, integer=True)
        s1 = sympy.Symbol("s1", positive=True, integer=True)

        self.assertEqual(
            (2,),
            triton_utils.config_of(
                [
                    SizeArg("A", two),  # no
                    SizeArg("B", eight),  # no
                    SizeArg("C", sixteen),  # yes
                    SizeArg("D", s0),  # no
                    SizeArg("E", s1),  # no
                ]
            ).divisible_by_16,
        )

        self.assertEqual(
            (0, 2, 4, 5, 6),
            triton_utils.config_of(
                [
                    SizeArg("A", two * eight),  # 0: yes
                    SizeArg("B", eight * s0),  # 1: no
                    SizeArg("C", two * eight * s0),  # 2: yes
                    SizeArg("D", s0 * s1),  # 3: no
                    SizeArg("E", sixteen * s0),  # 4: yes
                    SizeArg("F", sixteen * eight * s0 * s1),  # 5: yes
                    SizeArg("G", two * eight * s0 * s1),  # 6: yes
                ]
            ).divisible_by_16,
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests("sympy")