mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[GraphPartition] cache get_free_symbol_uses (#166338)
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs (see ee7434be82/torch/_inductor/scheduler.py, L4869-L4885). I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried the aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, a single `get_free_symbol_uses()` call for one node takes seconds. Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which can in turn recurse many times (see ee7434be82/torch/_inductor/ir.py, L4541-L4543). This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338 Approved by: https://github.com/eellison
This commit is contained in:
parent
b09fb481e0
commit
dfebdcab86
|
|
@ -11013,6 +11013,29 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
p = torch.tensor(0.50, device=self.device)
|
p = torch.tensor(0.50, device=self.device)
|
||||||
get_mask(x, p)
|
get_mask(x, p)
|
||||||
|
|
||||||
|
def test_flexible_layout_immutable_free_symbols(self):
    """Check which FlexibleLayout mutations the free-symbol assertion allows."""
    import sympy

    x, y, z = (sympy.Symbol(name) for name in ("x", "y", "z"))

    layout = torch._inductor.ir.FlexibleLayout(
        self.device, torch.float32, size=(x, y)
    )

    # pad_strides is accepted because it does not add new symints
    layout.pad_strides()

    # the same symints in a different order are accepted as well
    layout.size = (y, x)

    # a size that brings in a new symint must trip the assertion
    expected_message = "Expected free symbols unchanged, but got"
    with self.assertRaisesRegex(AssertionError, expected_message):
        layout.size = (z,)
|
||||||
|
|
||||||
def test_sqrt_dynamic_shapes(self):
|
def test_sqrt_dynamic_shapes(self):
|
||||||
# TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
|
# TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
|
||||||
# TODO: support cuda path.
|
# TODO: support cuda path.
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||||
compute_unbacked_bindings,
|
compute_unbacked_bindings,
|
||||||
free_symbols,
|
free_symbols,
|
||||||
free_unbacked_symbols,
|
free_unbacked_symbols,
|
||||||
|
IterateExprs,
|
||||||
rebind_unbacked,
|
rebind_unbacked,
|
||||||
resolve_unbacked_bindings,
|
resolve_unbacked_bindings,
|
||||||
ShapeEnv,
|
ShapeEnv,
|
||||||
|
|
@ -97,6 +98,7 @@ from .utils import (
|
||||||
argsort,
|
argsort,
|
||||||
argsort_sym,
|
argsort_sym,
|
||||||
cache_on_self,
|
cache_on_self,
|
||||||
|
cache_on_self_and_args,
|
||||||
ceildiv,
|
ceildiv,
|
||||||
convert_shape_to_inductor,
|
convert_shape_to_inductor,
|
||||||
convert_shape_to_symint,
|
convert_shape_to_symint,
|
||||||
|
|
@ -933,6 +935,7 @@ class Loops(IRNode):
|
||||||
inner_fn: Callable[..., Any]
|
inner_fn: Callable[..., Any]
|
||||||
ranges: Sequence[_IntLike]
|
ranges: Sequence[_IntLike]
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Loops")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -1228,6 +1231,7 @@ class Reduction(Loops):
|
||||||
|
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Reduction")
|
||||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||||
return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
|
return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
|
||||||
*(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
|
*(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
|
||||||
|
|
@ -2327,6 +2331,7 @@ class Scan(Loops):
|
||||||
|
|
||||||
# HACK we mimic reduction
|
# HACK we mimic reduction
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Scan")
|
||||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||||
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
|
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
|
||||||
# need to explicitly represent the closure so we can pull out unbacked
|
# need to explicitly represent the closure so we can pull out unbacked
|
||||||
|
|
@ -2537,6 +2542,7 @@ class Sort(Loops):
|
||||||
|
|
||||||
# HACK we mimic reduction
|
# HACK we mimic reduction
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Sort")
|
||||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||||
return (
|
return (
|
||||||
super().get_free_symbol_uses(unbacked_only)
|
super().get_free_symbol_uses(unbacked_only)
|
||||||
|
|
@ -2785,6 +2791,7 @@ def is_unaligned(node: IRNode) -> bool:
|
||||||
class BaseView(IRNode):
|
class BaseView(IRNode):
|
||||||
data: IRNode
|
data: IRNode
|
||||||
|
|
||||||
|
@cache_on_self_and_args("BaseView")
|
||||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||||
return self.data.get_free_symbol_uses(unbacked_only)
|
return self.data.get_free_symbol_uses(unbacked_only)
|
||||||
|
|
||||||
|
|
@ -3359,6 +3366,7 @@ class ReinterpretView(BaseView):
|
||||||
def freeze_layout(self) -> None:
|
def freeze_layout(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@cache_on_self_and_args("ReinterpretView")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -3643,13 +3651,37 @@ class Layout(OutputSpec):
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
assert len(size) == len(stride), f"size={size}, stride={stride}"
|
assert len(size) == len(stride), f"size={size}, stride={stride}"
|
||||||
assert all(isinstance(s, (Expr, int)) for s in size)
|
assert all(isinstance(s, (Expr, int)) for s in size)
|
||||||
self.size = size
|
self._size = size
|
||||||
self.stride = stride
|
self._stride = stride
|
||||||
self.offset = offset
|
self._offset = offset
|
||||||
self.is_pinned = is_pinned
|
self.is_pinned = is_pinned
|
||||||
# is_pinned implies cpu
|
# is_pinned implies cpu
|
||||||
assert (not self.is_pinned) or (self.device.type == "cpu")
|
assert (not self.is_pinned) or (self.device.type == "cpu")
|
||||||
|
|
||||||
|
@property
def size(self) -> Sequence[Expr]:
    """Return the sizes held in ``_size``."""
    return self._size

@size.setter
def size(self, value: Sequence[Expr]) -> None:
    """Replace the stored sizes with ``value``."""
    self._size = value

@property
def stride(self) -> Sequence[Expr]:
    """Return the strides held in ``_stride``."""
    return self._stride

@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
    """Replace the stored strides with ``value``."""
    self._stride = value

@property
def offset(self) -> Expr:
    """Return the offset held in ``_offset``."""
    return self._offset

@offset.setter
def offset(self, value: Expr) -> None:
    """Replace the stored offset with ``value``."""
    self._offset = value
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
offset = ""
|
offset = ""
|
||||||
if self.offset != 0:
|
if self.offset != 0:
|
||||||
|
|
@ -3869,6 +3901,7 @@ class Layout(OutputSpec):
|
||||||
def storage_size(self) -> Expr:
|
def storage_size(self) -> Expr:
|
||||||
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
|
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Layout")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -3888,7 +3921,11 @@ class FixedLayout(Layout):
|
||||||
|
|
||||||
|
|
||||||
class FlexibleLayout(Layout):
|
class FlexibleLayout(Layout):
|
||||||
"""A Tensor layout that we are allowed to change"""
|
"""
|
||||||
|
A Tensor layout that we are allowed to change
|
||||||
|
|
||||||
|
Assumption: layout change should NOT add or remove free symbols
|
||||||
|
"""
|
||||||
|
|
||||||
allow_indexing = False
|
allow_indexing = False
|
||||||
|
|
||||||
|
|
@ -3973,6 +4010,33 @@ class FlexibleLayout(Layout):
|
||||||
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
|
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
|
||||||
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
||||||
|
|
||||||
|
@property
def size(self) -> Sequence[Expr]:
    """Return the sizes held in ``_size``."""
    return self._size

@size.setter
def size(self, value: Sequence[Expr]) -> None:
    """Validate ``value`` against the recorded free symbols, then store it."""
    self.assert_free_symbol_uses_unchanged("size", value)
    self._size = value

@property
def stride(self) -> Sequence[Expr]:
    """Return the strides held in ``_stride``."""
    return self._stride

@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
    """Validate ``value`` against the recorded free symbols, then store it."""
    self.assert_free_symbol_uses_unchanged("stride", value)
    self._stride = value

@property
def offset(self) -> Expr:
    """Return the offset held in ``_offset``."""
    return self._offset

@offset.setter
def offset(self, value: Expr) -> None:
    """Validate ``value`` against the recorded free symbols, then store it."""
    self.assert_free_symbol_uses_unchanged("offset", value)
    self._offset = value
|
||||||
|
|
||||||
def as_stride_order(
|
def as_stride_order(
|
||||||
self, order: Sequence[int], allow_padding: bool = False
|
self, order: Sequence[int], allow_padding: bool = False
|
||||||
) -> FixedLayout:
|
) -> FixedLayout:
|
||||||
|
|
@ -4031,6 +4095,25 @@ class FlexibleLayout(Layout):
|
||||||
self.is_pinned,
|
self.is_pinned,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_initial_free_symbol_uses(
    self,
) -> dict[tuple[str, bool], OrderedSet[sympy.Symbol]]:
    """
    Snapshot the free symbols currently used by size, stride and offset.

    Returns:
        A dict keyed by ``(attribute_name, unbacked_only)``. Each value is the
        ordered set of symbols that ``get_free_symbols`` reports for that
        attribute with that ``unbacked_only`` flag.
    """
    initial_free_symbols: dict[tuple[str, bool], OrderedSet[sympy.Symbol]] = {}
    for name in ("size", "stride", "offset"):
        # Read the attribute once; both flag values inspect the same value.
        current_value = getattr(self, name)
        for unbacked_only in (True, False):
            initial_free_symbols[(name, unbacked_only)] = OrderedSet(
                get_free_symbols(current_value, unbacked_only)
            )

    return initial_free_symbols
|
||||||
|
|
||||||
|
def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
    """
    Assert that ``value`` uses the same free symbols as recorded for ``name``.

    The comparison is made twice, once with ``unbacked_only=True`` and once
    with ``unbacked_only=False``, against the matching entries of
    ``self.initial_free_symbols``.

    Args:
        name: Key of the attribute being assigned ("size", "stride" or
            "offset" in the callers visible here).
        value: The candidate new value for that attribute.

    Raises:
        AssertionError: if the symbol sets differ for either flag value.
    """
    for unbacked_only in (True, False):
        lookup_key = (name, unbacked_only)
        old_free_symbols = self.initial_free_symbols[lookup_key]
        new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
        assert new_free_symbols == old_free_symbols, (
            f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
        )
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
|
@ -4045,6 +4128,10 @@ class FlexibleLayout(Layout):
|
||||||
strides = FlexibleLayout.contiguous_strides(size)
|
strides = FlexibleLayout.contiguous_strides(size)
|
||||||
super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
|
super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
|
||||||
|
|
||||||
|
# record the initial free symbols to check that we do not add new free symbols
|
||||||
|
# later when modifying sizes, strides, and offsets.
|
||||||
|
self.initial_free_symbols = self.get_initial_free_symbol_uses()
|
||||||
|
|
||||||
|
|
||||||
class NonOwningLayout(Layout):
|
class NonOwningLayout(Layout):
|
||||||
"""Is a view into the storage of another tensor"""
|
"""Is a view into the storage of another tensor"""
|
||||||
|
|
@ -4070,6 +4157,7 @@ class NonOwningLayout(Layout):
|
||||||
|
|
||||||
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
|
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
|
||||||
|
|
||||||
|
@cache_on_self_and_args("NonOwningLayout")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -4358,6 +4446,7 @@ class Buffer(IRNode, CodegenSymbol):
|
||||||
def get_read_names(self) -> OrderedSet[str]:
|
def get_read_names(self) -> OrderedSet[str]:
|
||||||
return OrderedSet([self.get_name()])
|
return OrderedSet([self.get_name()])
|
||||||
|
|
||||||
|
@cache_on_self_and_args("Buffer")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -4430,6 +4519,7 @@ class NoneAsConstantBuffer(IRNode):
|
||||||
def get_reads(self) -> OrderedSet[Dep]:
|
def get_reads(self) -> OrderedSet[Dep]:
|
||||||
return OrderedSet()
|
return OrderedSet()
|
||||||
|
|
||||||
|
@cache_on_self_and_args("NoneAsConstantBuffer")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -4449,6 +4539,7 @@ class NoneAsConstantBuffer(IRNode):
|
||||||
class ShapeAsConstantBuffer(IRNode):
|
class ShapeAsConstantBuffer(IRNode):
|
||||||
expr: Expr
|
expr: Expr
|
||||||
|
|
||||||
|
@cache_on_self_and_args("ShapeAsConstantBuffer")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -4521,6 +4612,7 @@ class ComputedBuffer(OperationBuffer):
|
||||||
self.data.get_size(),
|
self.data.get_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cache_on_self_and_args("ComputedBuffer")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -4974,6 +5066,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
||||||
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
|
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
|
||||||
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
|
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
|
||||||
|
|
||||||
|
@cache_on_self_and_args("TritonTemplateBuffer")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -5340,6 +5433,7 @@ class InputsKernel(OperationBuffer):
|
||||||
def num_reads(self) -> int:
|
def num_reads(self) -> int:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@cache_on_self_and_args("InputsKernel")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -5514,6 +5608,7 @@ class ConcatKernel(NopKernel):
|
||||||
and not isinstance(src.data, ExternKernelAlloc)
|
and not isinstance(src.data, ExternKernelAlloc)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cache_on_self_and_args("ConcatKernel")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -6430,6 +6525,7 @@ class ExternKernel(InputsKernel):
|
||||||
index = sympy_subs(sympy.expand(index), replacement)
|
index = sympy_subs(sympy.expand(index), replacement)
|
||||||
return index, tuple(new_sizes)
|
return index, tuple(new_sizes)
|
||||||
|
|
||||||
|
@cache_on_self_and_args("ExternKernel")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -6889,6 +6985,7 @@ class UserDefinedTritonKernel(ExternKernel):
|
||||||
original_fxnode_name=self.fx_node.name,
|
original_fxnode_name=self.fx_node.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cache_on_self_and_args("UserDefinedTritonKernel")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -7327,6 +7424,7 @@ class DynamicSelectStorageOffset(ExternKernel):
|
||||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||||
return OrderedSet([self.unbacked_offset_symbol])
|
return OrderedSet([self.unbacked_offset_symbol])
|
||||||
|
|
||||||
|
@cache_on_self_and_args("DynamicSelectStorageOffset")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -7377,6 +7475,7 @@ class DynamicSliceSize(ExternKernel):
|
||||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||||
return OrderedSet([self.unbacked_size_symbol])
|
return OrderedSet([self.unbacked_size_symbol])
|
||||||
|
|
||||||
|
@cache_on_self_and_args("DynamicSliceSize")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -7441,6 +7540,7 @@ class AssertScalar(ExternKernel):
|
||||||
def has_side_effects(self) -> bool:
|
def has_side_effects(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@cache_on_self_and_args("AssertScalar")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -8115,6 +8215,7 @@ class MultiOutput(ExternKernel):
|
||||||
self.indices = indices
|
self.indices = indices
|
||||||
self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
|
self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
|
||||||
|
|
||||||
|
@cache_on_self_and_args("MultiOutput")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -8237,6 +8338,7 @@ class MutableBox(IRNode):
|
||||||
def realize(self) -> Optional[str]:
|
def realize(self) -> Optional[str]:
|
||||||
return self.data.realize()
|
return self.data.realize()
|
||||||
|
|
||||||
|
@cache_on_self_and_args("MutableBox")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
@ -9073,6 +9175,7 @@ class EffectfulKernel(FallbackKernel):
|
||||||
|
|
||||||
|
|
||||||
class NonTensorObj(IRNode):
|
class NonTensorObj(IRNode):
|
||||||
|
@cache_on_self_and_args("NonTensorObj")
|
||||||
def get_free_symbol_uses(
|
def get_free_symbol_uses(
|
||||||
self, unbacked_only: bool = False
|
self, unbacked_only: bool = False
|
||||||
) -> OrderedSet[sympy.Symbol]:
|
) -> OrderedSet[sympy.Symbol]:
|
||||||
|
|
|
||||||
|
|
@ -662,6 +662,7 @@ def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
RV = TypeVar("RV", covariant=True)
|
RV = TypeVar("RV", covariant=True)
|
||||||
|
FN_TYPE = Callable[Concatenate[Any, P], RV]
|
||||||
|
|
||||||
|
|
||||||
class CachedMethod(Protocol, Generic[P, RV]):
|
class CachedMethod(Protocol, Generic[P, RV]):
|
||||||
|
|
@ -709,6 +710,52 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
|
||||||
return cache_on_self(fn)
|
return cache_on_self(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_on_self_and_args(
    class_name: str,
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
    """
    Build a decorator that memoizes a method per instance and per call arguments.

    Results live in a dict stored on the instance under an attribute whose name
    combines ``class_name`` and the method name. Including the class name gives
    every class in an inheritance chain its own cache, which is what makes
    ``super().fn(*args, **kwargs)`` calls from an override safe.

    Args:
        class_name: Name of the class that defines the decorated method.

    Returns:
        A decorator. The decorated method also exposes ``clear_cache(instance)``,
        which drops the cached results of that instance.

    Note:
        Positional and keyword arguments form the cache key, so they must be
        hashable; ``f(x)`` and ``f(arg=x)`` are cached as separate entries.
        A cache hit returns the stored object itself, not a copy.
    """

    def wrapper(
        fn: FN_TYPE[P, RV],
    ) -> FN_TYPE[P, RV]:
        key = f"__{class_name}_{fn.__name__}_cache"

        # wrapper is likely on the hot path, compile a specialized version of it.
        # The generated source is deliberately free of annotations: the exec
        # namespace only contains `fn`, so names such as Any/P/RV would not
        # resolve there unless annotations happen to be evaluated lazily.
        ctx = {"fn": fn}
        exec(
            f"""\
            def inner(self, *args, **kwargs):
                args_kwargs = (args, tuple(sorted(kwargs.items())))

                if not hasattr(self, "{key}"):
                    object.__setattr__(self, "{key}", {{}})

                cache = self.{key}

                try:
                    return cache[args_kwargs]
                except KeyError:
                    pass

                rv = fn(self, *args, **kwargs)

                cache[args_kwargs] = rv
                return rv
            """.lstrip(),
            ctx,
        )
        inner = functools.wraps(fn)(ctx["inner"])

        def clear_cache(self: Any) -> None:
            # Mirror the object.__setattr__ used when the cache is created, so
            # clearing also works on instances that reject normal attribute
            # mutation (e.g. frozen dataclasses).
            if hasattr(self, key):
                object.__delattr__(self, key)

        inner.clear_cache = clear_cache  # type: ignore[attr-defined]
        return inner

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
def aggregate_origins(
|
def aggregate_origins(
|
||||||
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||||
) -> OrderedSet[Node]:
|
) -> OrderedSet[Node]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user