Fix launch grid calculation (#159497)

Summary:

The launch grid calculation code uses a Python trick to achieve CeilDiv() via FloorDiv() on a negated numerator. This relies on language-dependent behaviour: Python's integer division floors (rounds toward negative infinity), which is not true in all languages.

In the FXIR backend we undo this trick and replace the expression with an explicit CeilDiv() operation, so the computation is correct regardless of the language used. We deliberately do not change the original computation directly, as doing so leads to a performance degradation.
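
For concreteness, a minimal standalone sketch of the trick and of an explicit ceiling division (plain Python, not the Inductor code itself; the helper names and example sizes are illustrative only):

# Python's // floors (rounds toward negative infinity), so negating the
# numerator turns floor division into ceiling division. Languages whose
# integer division truncates toward zero (e.g. C/C++) break this trick.
def ceil_div_via_floor_trick(numerator: int, denominator: int) -> int:
    return -(-numerator // denominator)

# Explicit ceiling division that does not rely on floor semantics
# (assumes numerator >= 0 and denominator > 0, as for launch grids).
def ceil_div_explicit(numerator: int, denominator: int) -> int:
    return (numerator + denominator - 1) // denominator

xnumel, XBLOCK = 84, 32  # e.g. 7 * 12 elements with a block size of 32
assert ceil_div_via_floor_trick(xnumel, XBLOCK) == ceil_div_explicit(xnumel, XBLOCK) == 3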

Test Plan:
CI

Rollback Plan:

Differential Revision: D79275534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159497
Approved by: https://github.com/blaine-rister
nandesuka 2025-08-02 01:12:58 +00:00 committed by PyTorch MergeBot
parent d33a484763
commit 426f249f20
2 changed files with 25 additions and 3 deletions


@@ -411,6 +411,28 @@ class FxirTestCase(InductorTestCase):
         )
         self.assertIn("ks0", triton_node.kwargs["kwargs"])
 
+    def test_dynamic_launch_grid_calc(self):
+        """
+        Test the dynamic launch grid calculation for Triton kernel wrapper
+        """
+        func = torch.add
+        args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]]
+        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
+
+        # Check for the precomputed size arg.
+        (triton_node,) = gm.graph.find_nodes(
+            op="call_function", target=triton_kernel_wrapper_mutation
+        )
+        self.assertIn("grid", triton_node.kwargs)
+        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
+        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])
+        grid = triton_node.kwargs["grid"][0]
+        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
+        xnumel = triton_node.kwargs["kwargs"]["xnumel"]
+        self.assertEqual(grid[0].node.expr, ((xnumel + xblock - 1) // xblock))
+        self.assertEqual(grid[1], 1)
+        self.assertEqual(grid[2], 1)
+
     @config.patch({"trace.enabled": True})
     @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
     def test_debug(self, mock_output_code):


@@ -22,7 +22,7 @@ from torch._inductor.virtualized import V
 from torch._library.triton import wrap_triton
 from torch.fx import GraphModule
 from torch.utils import _pytree as pytree
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import CeilDiv
 
 from .. import config, ir
 from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
@@ -581,8 +581,8 @@ class FxConverter:
             assert V.graph.sizevars.statically_known_equals(new_expr, expr), (
                 f"Unsound replacement: '{new_expr}' != '{expr}'"
             )
-            return FloorDiv(numerator, denominator)
+            # Undo the python division trick and replace with explicit CeilDiv
+            return -CeilDiv(-numerator, denominator)
         else:
             return sympy.floor(expr)
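
Sanity check of the rewrite (a standalone sketch, not part of the diff; the concrete sizes are illustrative and it assumes the sympy helpers evaluate eagerly for integer arguments): floor(n/d) == -ceil(-n/d) for integers, so the new expression is value-equivalent while making the ceiling division explicit.

from torch.utils._sympy.functions import CeilDiv, FloorDiv

xnumel, XBLOCK = 84, 32  # illustrative sizes

# The launch grid trick writes ceil(xnumel / XBLOCK) as -(-xnumel // XBLOCK),
# so the converter sees FloorDiv(numerator, denominator) with numerator = -xnumel.
# The patched code emits -CeilDiv(-numerator, denominator) instead, which has
# the same value but makes the ceiling division explicit.
numerator, denominator = -xnumel, XBLOCK
assert FloorDiv(numerator, denominator) == -CeilDiv(-numerator, denominator) == -3

# With the enclosing negation applied, both spellings yield the grid size,
# matching the test's (xnumel + xblock - 1) // xblock.
assert -FloorDiv(numerator, denominator) == CeilDiv(xnumel, XBLOCK) == 3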