Fix launch grid calculation (#159497)
Summary: The launch grid calculation code uses a Python trick to achieve CeilDiv() through negative integer division with FloorDiv(). This is language-dependent behaviour that does not hold in all languages. In the FXIR backend we undo this trick and replace the expression with a CeilDiv() operation, so the computation is correct regardless of the target language. We do not change the original computation directly, as that leads to a performance degradation.

Test Plan: CI

Rollback Plan:

Differential Revision: D79275534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159497
Approved by: https://github.com/blaine-rister
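For reference, a minimal standalone sketch of the division trick the summary describes (hypothetical helper names; not code from this PR):

import math

def ceil_div_python_trick(a: int, b: int) -> int:
    # Python's // floors toward negative infinity, so negating the
    # numerator and the result turns floor division into ceiling division.
    return -(-a // b)

def ceil_div_portable(a: int, b: int) -> int:
    # Equivalent form that does not rely on floor semantics; this is the
    # expression the new test asserts for the launch grid.
    return (a + b - 1) // b

for n, block in [(84, 32), (7, 12), (1, 1)]:
    assert ceil_div_python_trick(n, block) == math.ceil(n / block)
    assert ceil_div_portable(n, block) == math.ceil(n / block)

# In C-family languages integer division truncates toward zero, so
# -(-a / b) evaluates back to a / b for positive operands: the trick is
# Python-specific, which is why a target-language-agnostic backend must
# emit an explicit ceiling division instead.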
This commit is contained in:
parent
d33a484763
commit
426f249f20
@@ -411,6 +411,28 @@ class FxirTestCase(InductorTestCase):
         )
         self.assertIn("ks0", triton_node.kwargs["kwargs"])
 
+    def test_dynamic_launch_grid_calc(self):
+        """
+        Test the dynamic launch grid calculation for the Triton kernel wrapper.
+        """
+        func = torch.add
+        args = [torch.randn(shape, device=self.device) for shape in [(7, 12), (7, 1)]]
+        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
+
+        # Check for the precomputed size arg.
+        (triton_node,) = gm.graph.find_nodes(
+            op="call_function", target=triton_kernel_wrapper_mutation
+        )
+        self.assertIn("grid", triton_node.kwargs)
+        self.assertIn("xnumel", triton_node.kwargs["kwargs"])
+        self.assertIn("XBLOCK", triton_node.kwargs["kwargs"])
+        grid = triton_node.kwargs["grid"][0]
+        xblock = triton_node.kwargs["kwargs"]["XBLOCK"]
+        xnumel = triton_node.kwargs["kwargs"]["xnumel"]
+        self.assertEqual(grid[0].node.expr, ((xnumel + xblock - 1) // xblock))
+        self.assertEqual(grid[1], 1)
+        self.assertEqual(grid[2], 1)
+
     @config.patch({"trace.enabled": True})
     @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
     def test_debug(self, mock_output_code):
@@ -22,7 +22,7 @@ from torch._inductor.virtualized import V
 from torch._library.triton import wrap_triton
 from torch.fx import GraphModule
 from torch.utils import _pytree as pytree
-from torch.utils._sympy.functions import FloorDiv
+from torch.utils._sympy.functions import CeilDiv
 
 from .. import config, ir
 from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
@@ -581,8 +581,8 @@ class FxConverter:
             assert V.graph.sizevars.statically_known_equals(new_expr, expr), (
                 f"Unsound replacement: '{new_expr}' != '{expr}'"
             )
-
-            return FloorDiv(numerator, denominator)
+            # Undo the python division trick and replace with explicit CeilDiv
+            return -CeilDiv(-numerator, denominator)
         else:
             return sympy.floor(expr)
 
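As a sanity check on the rewrite above, a small sketch (plain Python, independent of the PyTorch sympy wrappers) of the identity that makes the replacement value-preserving:

import math

# For any integer n and positive integer d:
#     -ceil(-n / d) == floor(n / d)
# so FloorDiv(numerator, denominator) and -CeilDiv(-numerator, denominator)
# denote the same value; the rewritten form simply makes the ceiling-division
# intent explicit instead of encoding it through Python floor semantics.
for n in range(-20, 21):
    for d in range(1, 8):
        assert -math.ceil(-n / d) == n // d  # // floors in Python

Judging by the commit summary and the new test, when the numerator already carries a negation from the original grid trick, the outer minus signs cancel and the FX IR ends up with a plain CeilDiv, matching the (xnumel + xblock - 1) // xblock expression asserted in the test.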