[Inductor] Expand Identity ops prior to block pattern matching (#146000)

# Feature

Inductor sometimes uses `Identity` functions to group terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we match an indexing expression to tell whether it can be represented as a block pointer, that analysis should be invariant to `Identity` ops.

This PR adds a few features to achieve this invariance.
 - Create a new expansion mode, `expr.expand(identity=True)`, which removes all `Identity` functions from the expression. (A short sketch follows this list.)
 - Preprocess the expression with this expansion prior to pattern matching.
 - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, since we need to initialize `V.graph` before we can access `V.graph.sizevars`.
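
As a quick illustration of the new expansion mode, here is a minimal sketch mirroring the `test_expand_identity` unit test added in this diff:

```python
import sympy

from torch.utils._sympy.functions import Identity

x = sympy.Symbol("x")
expr = Identity(x + 1)

# The new hint dispatches to Identity._eval_expand_identity, which
# simply returns the wrapped argument, removing the Identity node.
expanded = expr.expand(identity=True)
assert expanded == x + 1
assert expanded.count(Identity) == 0
```

Since `expand` recurses into subexpressions by default, nested `Identity` wrappers are removed as well, which is what lets the pattern matcher preprocess arbitrary indexing expressions with a single call.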

# Test plan
This PR adds a few new unit tests:
 - Added a unit test specifically for `expr.expand(identity=True)`.
 - Added a new unit test module for the block pattern matcher, checking that we can correctly match some example patterns containing `Identity` ops (see the sketch below).
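
For a flavor of what those tests check, here is a minimal sketch based on the new test module; the expected dims and strides come from the parametrized cases in this diff:

```python
import sympy

from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing

x = sympy.Symbol("x")

# Affine case: the Identity wrapper no longer hides the stride.
assert BlockPatternMatcher.match_affine_block_expr(Identity(5 * x), x) == 5

# Modular indexing case: this matcher accesses V.graph.sizevars, so a
# graph handler must be active; dummy_graph() provides one for tests.
expr = 4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4)
with V.set_graph_handler(dummy_graph()):
    match = BlockPatternMatcher.match_mod_div_block_expr(expr, x, 8, 2)
assert match is not None
dims, strides, block_index_exprs = match  # expect dims (2, 4), strides (4, 1)
```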

I originally intended to add an end-to-end test compiling pointwise cat and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000
Approved by: https://github.com/eellison, https://github.com/jansel
Commit a1bfb39a31 by Blaine Burton Rister, 2025-02-08 18:11:53 +00:00 (committed by PyTorch MergeBot)
Parent: eee5622b98
6 changed files with 170 additions and 9 deletions

test/inductor/test_block_analysis.py

@@ -0,0 +1,102 @@
# Owner(s): ["module: inductor"]
import sympy
import torch
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
# Some useful symbols
x, y = sympy.symbols("x y")
@instantiate_parametrized_tests
class BlockAnalysisTest(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Create a GraphLowering, so we can access V.graph.
cls.graph = dummy_graph()
@parametrize(
"stride,symbol,expr",
[
(5, x, Identity(5 * x)),
(4, y, 4 * Identity(y)),
(3, x, Identity(3) * x),
],
)
def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
# Test that we can handle an identity expression in affine indexing.
matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
self.assertEqual(matched_stride, stride)
@parametrize(
"dims,strides,symbol,expr",
[
(
(2, 4),
(4, 1),
x,
4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
),
(
(3, 9),
(5, 2),
x,
5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
),
((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
],
)
def test_mod_div_identity(
self,
dims: tuple[int],
strides: tuple[int],
symbol: sympy.Symbol,
expr: sympy.Expr,
):
# Test that we can handle an identity expression in modular indexing.
numel = int(torch.prod(torch.Tensor(dims)))
num_dims = len(dims)
with V.set_graph_handler(self.graph):
match_result = BlockPatternMatcher.match_mod_div_block_expr(
expr, symbol, numel, num_dims
)
# Check the matched block dimensions.
self.assertNotEqual(match_result, None)
matched_dims, matched_strides, matched_block_index_exprs = match_result
self.assertEqual(matched_dims, dims)
self.assertEqual(matched_strides, strides)
@parametrize(
"symbol,expr,subexpr",
[
(x, Identity(x), x),
(x, Identity(x + 5), x),
(y, Identity(x + 2 * y) + 5, 2 * y),
],
)
def test_subexpr_identity(
self,
symbol: sympy.Symbol,
expr: sympy.Expr,
subexpr: sympy.Expr,
):
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
self.assertEqual(matched_subexpr, subexpr)
if __name__ == "__main__":
run_tests()

test/test_sympy_utils.py

@@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
)
from torch.utils._sympy.functions import (
FloorDiv,
Identity,
OpaqueUnaryFn_cos,
simple_floordiv_gcd,
)
@@ -955,6 +956,17 @@ class TestSingletonInt(TestCase):
self.assertEqual(j1.free_symbols, set())
class TestIdentity(TestCase):
def test_expand_identity(self):
"""
Test removing an identity via expansion.
"""
x = sympy.Symbol("x")
arg = x + sympy.S.One
expr = Identity(arg)
expanded = expr.expand(identity=True)
self.assertEqual(expanded.count(Identity), 0)
self.assertEqual(expanded, arg)
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)

torch/_inductor/codegen/block_analysis.py

@@ -17,8 +17,8 @@ class BlockPatternMatcher:
Matches block indexing expressions.
"""
@staticmethod
def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
@classmethod
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
"""
Given a sympy expression, return the subexpression comprised only of terms
involving the specified symbol.
@@ -26,6 +26,7 @@ class BlockPatternMatcher:
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
this returns `x * 5 + x ** 2`.
"""
expr = cls._preprocess(expr)
return sympy.S.Zero + sum(
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
)
@@ -42,6 +43,11 @@
numels.appendleft(numel)
return [*numels]
@staticmethod
def _preprocess(expr: Expr) -> Expr:
# Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
return expr.expand(identity=True)
@classmethod
def match_mod_div_block_expr(
cls,
@@ -54,6 +60,7 @@
Matches modular indexing expressions, converting them to implied block dimensions and strides.
See triton.py for more information.
"""
index = cls._preprocess(index)
# Pattern match to find the strides and offset.
wild = functools.partial(sympy.Wild, exclude=[index_var])
@@ -141,3 +148,21 @@
)
return dims, strides, block_index_exprs
@classmethod
def match_affine_block_expr(
cls,
index: Expr,
index_var: Symbol,
) -> Optional[Expr]:
"""
Matches simple expressions of the form stride * index, returning the
stride.
"""
index = cls._preprocess(index)
stride = sympy.Wild("stride", exclude=[index_var])
m = index.match(index_var * stride)
if m is None:
return None
return m[stride]

torch/_inductor/codegen/triton.py

@@ -1790,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
and self.index_dtype == "tl.int32"
):
def match_strided_block(
def match_affine_block(
index: sympy.Expr, range_tree: IterationRangesRoot
) -> Optional[BlockParameters]:
"""
@@ -1799,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
This implies stride (s,), and shape (XBLOCK,).
"""
symbol = range_tree.symbol()
stride = sympy.Wild("stride", exclude=[symbol])
m = index.match(symbol * stride)
if m is None:
stride = BlockPatternMatcher.match_affine_block_expr(
index, range_tree.symbol()
)
if stride is None:
return None
return BlockParameters(
shape=[range_tree.numel],
block_shape=[TritonSymbols.get_block_size(range_tree)],
strides=[m[stride]],
strides=[stride],
offsets=[TritonSymbols.get_block_offset(range_tree)],
)
@@ -1917,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
Match a block indexing subexpression involving a single range tree.
"""
for match_func in (
match_strided_block,
match_affine_block,
match_mod_div_block,
):
match = match_func(expr, range_tree)

torch/testing/_internal/inductor_utils.py

@@ -10,6 +10,9 @@ import os
from subprocess import CalledProcessError
import sys
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.codecache import CppCodeCache
from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
from torch._inductor.utils import GPU_TYPES, get_gpu_type
@@ -142,6 +145,21 @@ IS_H100 = LazyVal(
IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
def dummy_graph() -> GraphLowering:
"""
Create a graph. This is useful for unit testing code which accesses
V.graph.sizevars.
"""
example_inputs = [torch.randn(10) for _ in range(2)]
gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
shape_env = shape_env_from_inputs(example_inputs)
graph = GraphLowering(
gm,
shape_env=shape_env,
)
return graph
def maybe_skip_size_asserts(op):
"""
For certain ops, the meta and eager implementations return different

torch/utils/_sympy/functions.py

@@ -1286,6 +1286,10 @@ class Identity(sympy.Function):
def _eval_is_integer(self):
return self.args[0].is_integer # type: ignore[attr-defined]
def _eval_expand_identity(self, **hints):
# Removes the identity op.
return self.args[0]
def make_opaque_unary_fn(name):
class OpaqueUnaryFn(sympy.Function):