From dfebdcab86acbaa0eaa996b47595e5f27a66492e Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Fri, 31 Oct 2025 21:24:05 +0000
Subject: [PATCH] [GraphPartition] cache get_free_symbol_uses (#166338)

Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/scheduler.py#L4869-L4885

I empirically observed that `get_free_symbol_uses()` becomes slower for larger
graphs. Specifically, I tried aten fallback for torchtitan, which results in
10k+ aten nodes. When processing the 600th node, a single call to
`get_free_symbol_uses()` took seconds.

Why? Because `get_free_symbol_uses()` may recursively call
`get_free_symbol_uses()` on other IR nodes, so a single call can fan out into
many recursive calls.
https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/ir.py#L4541-L4543

This PR fixes the issue by caching the results of `get_free_symbol_uses()`
(an illustrative sketch of the caching pattern follows the patch). I validated
on torchtitan that the issue is fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison
---
 test/inductor/test_torchinductor.py |  23 ++++++
 torch/_inductor/ir.py               | 111 +++++++++++++++++++++++++++-
 torch/_inductor/utils.py            |  47 ++++++++++++
 3 files changed, 177 insertions(+), 4 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 219a2585e34..161399bb667 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -11013,6 +11013,29 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         p = torch.tensor(0.50, device=self.device)
         get_mask(x, p)
 
+    def test_flexible_layout_immutable_free_symbols(self):
+        import sympy
+
+        x = sympy.Symbol("x")
+        y = sympy.Symbol("y")
+        z = sympy.Symbol("z")
+
+        layout = torch._inductor.ir.FlexibleLayout(
+            self.device, torch.float32, size=(x, y)
+        )
+
+        # pad_strides works since it does not add new symints
+        layout.pad_strides()
+
+        # same symints in a different order should work
+        layout.size = (y, x)
+
+        # adding new symints should fail
+        with self.assertRaisesRegex(
+            AssertionError, "Expected free symbols unchanged, but got"
+        ):
+            layout.size = (z,)
+
     def test_sqrt_dynamic_shapes(self):
         # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
         # TODO: support cuda path.
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 39d32e41b4e..8e9bf06ddb7 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     free_symbols,
     free_unbacked_symbols,
+    IterateExprs,
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
@@ -97,6 +98,7 @@ from .utils import (
     argsort,
     argsort_sym,
     cache_on_self,
+    cache_on_self_and_args,
     ceildiv,
     convert_shape_to_inductor,
     convert_shape_to_symint,
@@ -933,6 +935,7 @@ class Loops(IRNode):
     inner_fn: Callable[..., Any]
     ranges: Sequence[_IntLike]
 
+    @cache_on_self_and_args("Loops")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -1228,6 +1231,7 @@ class Reduction(Loops):
 
     __repr__ = __str__
 
+    @cache_on_self_and_args("Reduction")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
             *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
@@ -2327,6 +2331,7 @@ class Scan(Loops):
 
     # HACK we mimic reduction
 
+    @cache_on_self_and_args("Scan")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
         # need to explicitly represent the closure so we can pull out unbacked
@@ -2537,6 +2542,7 @@ class Sort(Loops):
 
     # HACK we mimic reduction
 
+    @cache_on_self_and_args("Sort")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return (
             super().get_free_symbol_uses(unbacked_only)
@@ -2785,6 +2791,7 @@ def is_unaligned(node: IRNode) -> bool:
 class BaseView(IRNode):
     data: IRNode
 
+    @cache_on_self_and_args("BaseView")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return self.data.get_free_symbol_uses(unbacked_only)
 
@@ -3359,6 +3366,7 @@ class ReinterpretView(BaseView):
     def freeze_layout(self) -> None:
         pass
 
+    @cache_on_self_and_args("ReinterpretView")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -3643,13 +3651,37 @@ class Layout(OutputSpec):
         self.dtype = dtype
         assert len(size) == len(stride), f"size={size}, stride={stride}"
         assert all(isinstance(s, (Expr, int)) for s in size)
-        self.size = size
-        self.stride = stride
-        self.offset = offset
+        self._size = size
+        self._stride = stride
+        self._offset = offset
         self.is_pinned = is_pinned
 
         # is_pinned implies cpu
         assert (not self.is_pinned) or (self.device.type == "cpu")
 
+    @property
+    def size(self) -> Sequence[Expr]:
+        return self._size
+
+    @size.setter
+    def size(self, value: Sequence[Expr]) -> None:
+        self._size = value
+
+    @property
+    def stride(self) -> Sequence[Expr]:
+        return self._stride
+
+    @stride.setter
+    def stride(self, value: Sequence[Expr]) -> None:
+        self._stride = value
+
+    @property
+    def offset(self) -> Expr:
+        return self._offset
+
+    @offset.setter
+    def offset(self, value: Expr) -> None:
+        self._offset = value
+
     def __str__(self) -> str:
         offset = ""
         if self.offset != 0:
@@ -3869,6 +3901,7 @@ class Layout(OutputSpec):
     def storage_size(self) -> Expr:
         return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type]
 
+    @cache_on_self_and_args("Layout")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -3888,7 +3921,11 @@ class FixedLayout(Layout):
 
 
 class FlexibleLayout(Layout):
-    """A Tensor layout that we are allowed to change"""
+    """
+    A Tensor layout that we are allowed to change
+
+    Assumption: layout change should NOT add or remove free symbols
+    """
 
     allow_indexing = False
 
@@ -3973,6 +4010,33 @@ class FlexibleLayout(Layout):
         fill_order = sorted(range(len(stride)), key=stride.__getitem__)
         return FlexibleLayout.fill_ordered(sizes, fill_order)
 
+    @property
+    def size(self) -> Sequence[Expr]:
+        return self._size
+
+    @size.setter
+    def size(self, value: Sequence[Expr]) -> None:
+        self.assert_free_symbol_uses_unchanged("size", value)
+        self._size = value
+
+    @property
+    def stride(self) -> Sequence[Expr]:
+        return self._stride
+
+    @stride.setter
+    def stride(self, value: Sequence[Expr]) -> None:
+        self.assert_free_symbol_uses_unchanged("stride", value)
+        self._stride = value
+
+    @property
+    def offset(self) -> Expr:
+        return self._offset
+
+    @offset.setter
+    def offset(self, value: Expr) -> None:
+        self.assert_free_symbol_uses_unchanged("offset", value)
+        self._offset = value
+
     def as_stride_order(
         self, order: Sequence[int], allow_padding: bool = False
     ) -> FixedLayout:
@@ -4031,6 +4095,25 @@ class FlexibleLayout(Layout):
             self.is_pinned,
         )
 
+    def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], OrderedSet[sympy.Symbol]]:
+        initial_free_symbols = {}
+        for name in ["size", "stride", "offset"]:
+            for unbacked_only in [True, False]:
+                key = (name, unbacked_only)
+                initial_free_symbols[key] = OrderedSet(
+                    get_free_symbols(getattr(self, name), unbacked_only)
+                )
+
+        return initial_free_symbols
+
+    def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
+        for unbacked_only in [True, False]:
+            old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
+            new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
+            assert new_free_symbols == old_free_symbols, (
+                f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
+            )
+
     def __init__(
         self,
         device: torch.device,
@@ -4045,6 +4128,10 @@ class FlexibleLayout(Layout):
             strides = FlexibleLayout.contiguous_strides(size)
         super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
 
+        # record the initial free symbols to check that we do not add new free symbols
+        # later when modifying sizes, strides, and offsets.
+        self.initial_free_symbols = self.get_initial_free_symbol_uses()
+
 
 class NonOwningLayout(Layout):
     """Is a view into the storage of another tensor"""
@@ -4070,6 +4157,7 @@ class NonOwningLayout(Layout):
 
         return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
 
+    @cache_on_self_and_args("NonOwningLayout")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4358,6 +4446,7 @@ class Buffer(IRNode, CodegenSymbol):
     def get_read_names(self) -> OrderedSet[str]:
         return OrderedSet([self.get_name()])
 
+    @cache_on_self_and_args("Buffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4430,6 +4519,7 @@ class NoneAsConstantBuffer(IRNode):
     def get_reads(self) -> OrderedSet[Dep]:
         return OrderedSet()
 
+    @cache_on_self_and_args("NoneAsConstantBuffer")
    def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4449,6 +4539,7 @@ class NoneAsConstantBuffer(IRNode):
 class ShapeAsConstantBuffer(IRNode):
     expr: Expr
 
+    @cache_on_self_and_args("ShapeAsConstantBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4521,6 +4612,7 @@ class ComputedBuffer(OperationBuffer):
             self.data.get_size(),
         )
 
+    @cache_on_self_and_args("ComputedBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4974,6 +5066,7 @@ class TritonTemplateBuffer(TemplateBuffer):
         self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
         self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
 
+    @cache_on_self_and_args("TritonTemplateBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -5340,6 +5433,7 @@ class InputsKernel(OperationBuffer):
     def num_reads(self) -> int:
         return 1
 
+    @cache_on_self_and_args("InputsKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -5514,6 +5608,7 @@ class ConcatKernel(NopKernel):
             and not isinstance(src.data, ExternKernelAlloc)
         )
 
+    @cache_on_self_and_args("ConcatKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -6430,6 +6525,7 @@ class ExternKernel(InputsKernel):
         index = sympy_subs(sympy.expand(index), replacement)
         return index, tuple(new_sizes)
 
+    @cache_on_self_and_args("ExternKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -6889,6 +6985,7 @@ class UserDefinedTritonKernel(ExternKernel):
             original_fxnode_name=self.fx_node.name,
         )
 
+    @cache_on_self_and_args("UserDefinedTritonKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7327,6 +7424,7 @@ class DynamicSelectStorageOffset(ExternKernel):
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_offset_symbol])
 
+    @cache_on_self_and_args("DynamicSelectStorageOffset")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7377,6 +7475,7 @@ class DynamicSliceSize(ExternKernel):
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_size_symbol])
 
+    @cache_on_self_and_args("DynamicSliceSize")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7441,6 +7540,7 @@ class AssertScalar(ExternKernel):
     def has_side_effects(self) -> bool:
         return True
 
+    @cache_on_self_and_args("AssertScalar")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -8115,6 +8215,7 @@ class MultiOutput(ExternKernel):
         self.indices = indices
         self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
 
+    @cache_on_self_and_args("MultiOutput")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -8237,6 +8338,7 @@ class MutableBox(IRNode):
     def realize(self) -> Optional[str]:
         return self.data.realize()
 
+    @cache_on_self_and_args("MutableBox")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -9073,6 +9175,7 @@ class EffectfulKernel(FallbackKernel):
 
 
 class NonTensorObj(IRNode):
+    @cache_on_self_and_args("NonTensorObj")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 3162e002a75..353d5fdd4cc 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -662,6 +662,7 @@ def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
 
 P = ParamSpec("P")
 RV = TypeVar("RV", covariant=True)
+FN_TYPE = Callable[Concatenate[Any, P], RV]
 
 
 class CachedMethod(Protocol, Generic[P, RV]):
@@ -709,6 +710,52 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
     return cache_on_self(fn)
 
 
+def cache_on_self_and_args(
+    class_name: str,
+) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
+    # include both class_name and fn_name in the key to support `super().fn(self, *args, **kwargs)` calls.
+
+    def wrapper(
+        fn: FN_TYPE[P, RV],
+    ) -> FN_TYPE[P, RV]:
+        key = f"__{class_name}_{fn.__name__}_cache"
+
+        # wrapper is likely on the hot path, compile a specialized version of it
+        ctx = {"fn": fn}
+        exec(
+            f"""\
+def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
+    args_kwargs = (args, tuple(sorted(kwargs.items())))
+
+    if not hasattr(self, "{key}"):
+        object.__setattr__(self, "{key}", {{}})
+
+    cache = self.{key}
+
+    try:
+        return cache[args_kwargs]
+    except KeyError:
+        pass
+
+    rv = fn(self, *args, **kwargs)
+
+    cache[args_kwargs] = rv
+    return rv
+""".lstrip(),
+            ctx,
+        )
+        inner = functools.wraps(fn)(ctx["inner"])
+
+        def clear_cache(self: Any) -> None:
+            if hasattr(self, key):
+                delattr(self, key)
+
+        inner.clear_cache = clear_cache  # type: ignore[attr-defined]
+        return inner
+
+    return wrapper
+
+
 def aggregate_origins(
     node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
 ) -> OrderedSet[Node]:
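
Illustrative note (not part of the patch): below is a minimal, self-contained sketch of the memoization pattern that `cache_on_self_and_args` implements — caching a method's result on the instance, keyed by the method's positional and keyword arguments — which is what turns repeated `get_free_symbol_uses()` calls into dictionary lookups. The decorator here is a simplified stand-in written for this note (not the `exec`-specialized version added in `torch/_inductor/utils.py`), and the `Node` class with its `free_symbols()` method is a hypothetical recursive structure used only to demonstrate the fan-out being avoided.

```python
import functools
from typing import Any, Callable, TypeVar

R = TypeVar("R")


def cache_on_self_and_args(class_name: str) -> Callable[[Callable[..., R]], Callable[..., R]]:
    """Memoize a method per instance, keyed by (args, sorted kwargs)."""

    def wrapper(fn: Callable[..., R]) -> Callable[..., R]:
        # Include the class name and function name in the attribute name so that
        # super().fn(...) calls use a separate cache from the subclass override.
        attr = f"__{class_name}_{fn.__name__}_cache"

        @functools.wraps(fn)
        def inner(self: Any, *args: Any, **kwargs: Any) -> R:
            key = (args, tuple(sorted(kwargs.items())))
            cache = getattr(self, attr, None)
            if cache is None:
                cache = {}
                object.__setattr__(self, attr, cache)
            if key not in cache:
                cache[key] = fn(self, *args, **kwargs)
            return cache[key]

        return inner

    return wrapper


class Node:
    """Hypothetical IR-like node with a single child, used only for this sketch."""

    def __init__(self, symbol: str, child: "Node | None" = None) -> None:
        self.symbol = symbol
        self.child = child
        self.calls = 0  # counts how often the uncached body runs

    @cache_on_self_and_args("Node")
    def free_symbols(self, unbacked_only: bool = False) -> frozenset[str]:
        # Recursive walk: without caching, a chain of N nodes re-walks its whole
        # suffix on every query, which is what made large graphs slow.
        self.calls += 1
        child_syms = self.child.free_symbols(unbacked_only) if self.child else frozenset()
        return child_syms | frozenset({self.symbol})


if __name__ == "__main__":
    leaf = Node("s0")
    root = Node("s1", child=leaf)
    assert root.free_symbols() == {"s0", "s1"}
    assert root.free_symbols() == {"s0", "s1"}  # second call served from the cache
    assert root.calls == 1 and leaf.calls == 1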
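
A second sketch, also not part of the patch, shows why the cached result stays valid for `FlexibleLayout`: its new property setters assert that a mutation does not add or remove free symbols, so the symbol set computed once can be reused. `GuardedLayout` and `_free_symbols` are hypothetical simplified stand-ins for `FlexibleLayout` and Inductor's `get_free_symbols` helper.

```python
import sympy


class GuardedLayout:
    """Hypothetical stand-in for the FlexibleLayout setter guard in this PR."""

    def __init__(self, size) -> None:
        self._size = tuple(size)
        # Record the free symbols once; later mutations must preserve this set.
        self._initial_free_symbols = self._free_symbols(self._size)

    @staticmethod
    def _free_symbols(exprs) -> set:
        out = set()
        for e in exprs:
            if isinstance(e, sympy.Expr):
                out |= e.free_symbols
        return out

    @property
    def size(self):
        return self._size

    @size.setter
    def size(self, value) -> None:
        new = self._free_symbols(value)
        assert new == self._initial_free_symbols, (
            f"Expected free symbols unchanged, but got {new} vs {self._initial_free_symbols}"
        )
        self._size = tuple(value)


x, y, z = sympy.symbols("x y z")
layout = GuardedLayout((x, y))
layout.size = (y, x)  # same symbols, different order: allowed
try:
    layout.size = (z,)  # introduces a new symbol: rejected
except AssertionError:
    pass
```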