mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Expand Identity ops prior to block pattern matching (#146000)
# Feature Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity`'s. This PR adds a few features to achieve this invariance. - Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression. - Preprocess the expression with this expansion prior to pattern matching. - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`. # Test plan This PR adds a few new unit tests: - Added a unit test specifically for `expr.expand(identity=True)`. - Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops. I originally intended to add an end to end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000 Approved by: https://github.com/eellison, https://github.com/jansel
This commit is contained in:
parent
eee5622b98
commit
a1bfb39a31
102
test/inductor/test_block_analysis.py
Normal file
102
test/inductor/test_block_analysis.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import dummy_graph
|
||||
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
|
||||
|
||||
|
||||
# Some useful symbols
|
||||
x, y = sympy.symbols("x y")
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class BlockAnalysisTest(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# Create a GraphLowering, so we can access V.graph.
|
||||
cls.graph = dummy_graph()
|
||||
|
||||
@parametrize(
|
||||
"stride,symbol,expr",
|
||||
[
|
||||
(5, x, Identity(5 * x)),
|
||||
(4, y, 4 * Identity(y)),
|
||||
(3, x, Identity(3) * x),
|
||||
],
|
||||
)
|
||||
def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
|
||||
# Test that we can handle an identity expression in affine indexing.
|
||||
matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
|
||||
self.assertEqual(matched_stride, stride)
|
||||
|
||||
@parametrize(
|
||||
"dims,strides,symbol,expr",
|
||||
[
|
||||
(
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
x,
|
||||
4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
|
||||
),
|
||||
(
|
||||
(3, 9),
|
||||
(5, 2),
|
||||
x,
|
||||
5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
|
||||
),
|
||||
((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
|
||||
],
|
||||
)
|
||||
def test_mod_div_identity(
|
||||
self,
|
||||
dims: tuple[int],
|
||||
strides: tuple[int],
|
||||
symbol: sympy.Symbol,
|
||||
expr: sympy.Expr,
|
||||
):
|
||||
# Test that we can handle an identity expression in modular indexing.
|
||||
numel = int(torch.prod(torch.Tensor(dims)))
|
||||
num_dims = len(dims)
|
||||
with V.set_graph_handler(self.graph):
|
||||
match_result = BlockPatternMatcher.match_mod_div_block_expr(
|
||||
expr, symbol, numel, num_dims
|
||||
)
|
||||
|
||||
# Check the matched block dimensions.
|
||||
self.assertNotEqual(match_result, None)
|
||||
matched_dims, matched_strides, matched_block_index_exprs = match_result
|
||||
self.assertEqual(matched_dims, dims)
|
||||
self.assertEqual(matched_strides, strides)
|
||||
|
||||
@parametrize(
|
||||
"symbol,expr,subexpr",
|
||||
[
|
||||
(x, Identity(x), x),
|
||||
(x, Identity(x + 5), x),
|
||||
(y, Identity(x + 2 * y) + 5, 2 * y),
|
||||
],
|
||||
)
|
||||
def test_subexpr_identity(
|
||||
self,
|
||||
symbol: sympy.Symbol,
|
||||
expr: sympy.Expr,
|
||||
subexpr: sympy.Expr,
|
||||
):
|
||||
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
|
||||
self.assertEqual(matched_subexpr, subexpr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
|
|||
)
|
||||
from torch.utils._sympy.functions import (
|
||||
FloorDiv,
|
||||
Identity,
|
||||
OpaqueUnaryFn_cos,
|
||||
simple_floordiv_gcd,
|
||||
)
|
||||
|
|
@ -955,6 +956,17 @@ class TestSingletonInt(TestCase):
|
|||
|
||||
self.assertEqual(j1.free_symbols, set())
|
||||
|
||||
class TestIdentity(TestCase):
|
||||
def test_expand_identity(self):
|
||||
"""
|
||||
Test removing an identity via expansion.
|
||||
"""
|
||||
x = sympy.Symbol("x")
|
||||
arg = x + sympy.S.One
|
||||
expr = Identity(arg)
|
||||
expanded = expr.expand(identity=True)
|
||||
self.assertEqual(expanded.count(Identity), 0)
|
||||
self.assertEqual(expanded, arg)
|
||||
|
||||
instantiate_parametrized_tests(TestValueRanges)
|
||||
instantiate_parametrized_tests(TestSympyInterp)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ class BlockPatternMatcher:
|
|||
Matches block indexing expressions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_subexpr_involving_symbol(expr: Expr, symbol: Symbol) -> Expr:
|
||||
@classmethod
|
||||
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
|
||||
"""
|
||||
Given a sympy expression, return the subexpression comprised only of terms
|
||||
involving the specified symbol.
|
||||
|
|
@ -26,6 +26,7 @@ class BlockPatternMatcher:
|
|||
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
|
||||
this returns `x * 5 + x ** 2`.
|
||||
"""
|
||||
expr = cls._preprocess(expr)
|
||||
return sympy.S.Zero + sum(
|
||||
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
|
||||
)
|
||||
|
|
@ -42,6 +43,11 @@ class BlockPatternMatcher:
|
|||
numels.appendleft(numel)
|
||||
return [*numels]
|
||||
|
||||
@staticmethod
|
||||
def _preprocess(expr: Expr) -> Expr:
|
||||
# Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
|
||||
return expr.expand(identity=True)
|
||||
|
||||
@classmethod
|
||||
def match_mod_div_block_expr(
|
||||
cls,
|
||||
|
|
@ -54,6 +60,7 @@ class BlockPatternMatcher:
|
|||
Matches modular indexing expressions, converting them to implied block dimensions and strides.
|
||||
See triton.py for more information.
|
||||
"""
|
||||
index = cls._preprocess(index)
|
||||
|
||||
# Pattern match to find the strides and offset.
|
||||
wild = functools.partial(sympy.Wild, exclude=[index_var])
|
||||
|
|
@ -141,3 +148,21 @@ class BlockPatternMatcher:
|
|||
)
|
||||
|
||||
return dims, strides, block_index_exprs
|
||||
|
||||
@classmethod
|
||||
def match_affine_block_expr(
|
||||
cls,
|
||||
index: Expr,
|
||||
index_var: Symbol,
|
||||
) -> Optional[Expr]:
|
||||
"""
|
||||
Matches simple expressions of the form stride * index, returning the
|
||||
stride.
|
||||
"""
|
||||
index = cls._preprocess(index)
|
||||
stride = sympy.Wild("stride", exclude=[index_var])
|
||||
m = index.match(index_var * stride)
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return m[stride]
|
||||
|
|
|
|||
|
|
@ -1790,7 +1790,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
and self.index_dtype == "tl.int32"
|
||||
):
|
||||
|
||||
def match_strided_block(
|
||||
def match_affine_block(
|
||||
index: sympy.Expr, range_tree: IterationRangesRoot
|
||||
) -> Optional[BlockParameters]:
|
||||
"""
|
||||
|
|
@ -1799,16 +1799,16 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
|
||||
This implies stride (s,), and shape (XBLOCK,).
|
||||
"""
|
||||
symbol = range_tree.symbol()
|
||||
stride = sympy.Wild("stride", exclude=[symbol])
|
||||
m = index.match(symbol * stride)
|
||||
if m is None:
|
||||
stride = BlockPatternMatcher.match_affine_block_expr(
|
||||
index, range_tree.symbol()
|
||||
)
|
||||
if stride is None:
|
||||
return None
|
||||
|
||||
return BlockParameters(
|
||||
shape=[range_tree.numel],
|
||||
block_shape=[TritonSymbols.get_block_size(range_tree)],
|
||||
strides=[m[stride]],
|
||||
strides=[stride],
|
||||
offsets=[TritonSymbols.get_block_offset(range_tree)],
|
||||
)
|
||||
|
||||
|
|
@ -1917,7 +1917,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
Match a block indexing subexpression involving a single range tree.
|
||||
"""
|
||||
for match_func in (
|
||||
match_strided_block,
|
||||
match_affine_block,
|
||||
match_mod_div_block,
|
||||
):
|
||||
match = match_func(expr, range_tree)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ import os
|
|||
from subprocess import CalledProcessError
|
||||
import sys
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.compile_fx import shape_env_from_inputs
|
||||
from torch._inductor.codecache import CppCodeCache
|
||||
from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
|
||||
from torch._inductor.utils import GPU_TYPES, get_gpu_type
|
||||
|
|
@ -142,6 +145,21 @@ IS_H100 = LazyVal(
|
|||
|
||||
IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
|
||||
|
||||
def dummy_graph() -> GraphLowering:
|
||||
"""
|
||||
Create a graph. This is useful for unit testing code which accesses
|
||||
V.graph.sizevars.
|
||||
"""
|
||||
example_inputs = [torch.randn(10) for _ in range(2)]
|
||||
gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
|
||||
shape_env = shape_env_from_inputs(example_inputs)
|
||||
graph = GraphLowering(
|
||||
gm,
|
||||
shape_env=shape_env,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
def maybe_skip_size_asserts(op):
|
||||
"""
|
||||
For certain ops, there meta and eager implementation returns differents
|
||||
|
|
|
|||
|
|
@ -1286,6 +1286,10 @@ class Identity(sympy.Function):
|
|||
def _eval_is_integer(self):
|
||||
return self.args[0].is_integer # type: ignore[attr-defined]
|
||||
|
||||
def _eval_expand_identity(self, **hints):
|
||||
# Removes the identity op.
|
||||
return self.args[0]
|
||||
|
||||
|
||||
def make_opaque_unary_fn(name):
|
||||
class OpaqueUnaryFn(sympy.Function):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user