Revert "DDE-Free select with unbacked index. (#157605)"

This reverts commit 79d7c754ab.

Reverted https://github.com/pytorch/pytorch/pull/157605 on behalf of https://github.com/laithsakka due to fail pr time benchmarks  ([comment](https://github.com/pytorch/pytorch/pull/157605#issuecomment-3084663020))
This commit is contained in:
PyTorch MergeBot 2025-07-17 16:20:02 +00:00
parent 16b21fa8b2
commit 23550ab735
15 changed files with 53 additions and 349 deletions

View File

@ -15782,28 +15782,6 @@ def forward(self, x, mask):
ignore_empty_lines=True,
)
def test_unbacked_select_index(self):
class MyModel(torch.nn.Module):
def forward(self, x, y):
u0 = y.item()
return x.select(0, u0)
example_inputs = (
torch.randn((3, 3), dtype=torch.bfloat16),
torch.tensor([0]),
)
traced = export(MyModel(), example_inputs).run_decompositions({})
self.assertExpectedInline(
traced.graph_module.code,
"""\
def forward(self, x, y):
item = torch.ops.aten.item.default(y); y = None
select = torch.ops.aten.select.int(x, 0, item); x = item = None
return (select,)""",
ignore_empty_lines=True,
)
if __name__ == "__main__":
run_tests()

View File

@ -3529,88 +3529,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
ignore_empty_lines=True,
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index(self):
cnt = CompileCounterWithBackend("inductor")
def func(x, y):
u0 = y.item()
return (
torch.select(x, 0, u0),
torch.select(x, 1, u0),
torch.select(x, 2, u0),
)
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.rand(3, 3, 3)
zero = torch.tensor([0])
pos = torch.tensor([1])
# code can handle both negative and positive indices.
neg = torch.tensor([-1])
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
self.assertEqual(compiled_func(x, zero), func(x, zero))
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
output,
"""\
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense)
select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None
return (select, select_1, select_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
self.assertEqual(compiled_func(x, pos), func(x, pos))
self.assertEqual(compiled_func(x, neg), func(x, neg))
self.assertEqual(cnt.frame_count, 1)
def func2(x, y):
u0, u1 = y.tolist()
return torch.select(x, 0, u0 + u1)
compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)(
func2
)
zero = torch.tensor([0, 0])
pos = torch.tensor([1, 1])
neg = torch.tensor([-1, -1])
self.assertEqual(compiled_func2(x, pos), func2(x, pos))
self.assertEqual(compiled_func2(x, neg), func2(x, neg))
self.assertEqual(compiled_func2(x, zero), func2(x, zero))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index_with_check(self):
def func3(x, y):
u0 = y.item()
# Test that taking the non-unbacked path works fine also.
torch._check(u0 >= 0)
return (torch.select(x, 1, u0),)
compiled_func3 = torch.compile(
fullgraph=True, backend="inductor", dynamic=True
)(func3)
x = torch.rand(3, 3, 3)
zero = torch.tensor([0])
pos = torch.tensor([1])
print(compiled_func3(x, pos))
self.assertEqual(compiled_func3(x, pos), func3(x, pos))
self.assertEqual(compiled_func3(x, zero), func3(x, zero))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_select_index_cpp_wrapper(self):
self.test_unbacked_select_index()
instantiate_parametrized_tests(TestUnbacked)

View File

@ -54,7 +54,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None)
)
},
)
node.meta["torch_fn"] = (
f"{node.target.__name__}_0",
f"{node.target.__class__.__name__}.{node.target.__name__}",

View File

@ -1447,20 +1447,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.sym))
def codegen_dynamic_select_index(self, node):
index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
index_compute_str = (
f"{index_cpp_str} < 0 ? {index_cpp_str} + "
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
)
self.writeline(
f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
def make_buffer_free(self, buffer):
return (
""

View File

@ -1802,14 +1802,6 @@ class PythonWrapperCodegen(CodeGen):
arg_name = node.input_name(0)
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
def codegen_dynamic_select_index(self, node):
index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
self.writeline(
f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if len(node.keypath) == 0:

View File

@ -11,7 +11,6 @@ from unittest.mock import patch
import sympy
import torch
from torch._inductor.utils import get_free_symbols
from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
@ -39,12 +38,6 @@ class Dep(abc.ABC):
name: str
index: sympy.Expr
@abc.abstractmethod
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
pass
@abc.abstractmethod
def rename(self, renames: dict[str, str]) -> Self:
pass
@ -77,15 +70,6 @@ class MemoryDep(Dep):
size: tuple[sympy.Expr, ...]
mode: Optional[str] = None
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return (
get_free_symbols(self.index, unbacked_only)
| get_free_symbols(self.size, unbacked_only)
| get_free_symbols(self.var_names, unbacked_only)
)
def __repr__(self) -> str:
maybe_mode = ""
if self.mode is not None:
@ -323,11 +307,6 @@ class StarDep(Dep):
return StarDep(renames[self.name], self.mode)
return self
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def numbytes_hint(self) -> int:
try:
return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
@ -364,11 +343,6 @@ class WeakDep(Dep):
# Buffer that is doing the mutation
mutating_buf: str
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
@property
def index(self) -> sympy.Expr:
raise NotImplementedError("WeakDep does not have an index")
@ -466,15 +440,6 @@ class ReadWrites:
names.add(dep.name)
return names
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
result: OrderedSet[sympy.Symbol] = OrderedSet()
for dep in self.reads_and_writes():
result |= dep.get_free_symbol_uses(unbacked_only)
return result
class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:

View File

@ -341,7 +341,6 @@ class GraphLowering(torch.fx.Interpreter):
shape_env.deferred_runtime_asserts.copy()
)
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: list[str] = []
self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
@ -1822,7 +1821,7 @@ class GraphLowering(torch.fx.Interpreter):
shape_env = V.graph.sizevars.shape_env
# An input can be unbacked symint i.e.: when mark_unabcked is used.
# An input can an unbacked symint i.e.: when mark_unabcked is used.
# in that case add it to new_unbacked_defs.
if (
n.op == "placeholder"
@ -1889,7 +1888,6 @@ class GraphLowering(torch.fx.Interpreter):
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
for s in unbacked_bindings.keys()
)
assert new_unbacked_defs >= renamed_unbacked_bindings, (
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
f"fx node is: {n.format_node()}\n"

View File

@ -49,7 +49,6 @@ from torch._dynamo.utils import identity
from torch._export.serde.serialize import GraphModuleSerializer
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
from torch._inductor import metrics
from torch._inductor.utils import get_free_symbols
from torch._prims_common import (
compute_required_storage_length,
is_boolean_dtype,
@ -63,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
free_symbols,
free_unbacked_symbols,
IterateExprs,
rebind_unbacked,
resolve_unbacked_bindings,
ShapeEnv,
@ -304,6 +304,13 @@ def fuse_reindexing(
return reindex
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
if unbacked_only:
return free_unbacked_symbols(x)
else:
return free_symbols(x)
NHWC_STRIDE_ORDER = [3, 0, 2, 1]
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
@ -4322,13 +4329,6 @@ class ComputedBuffer(OperationBuffer):
return self.data.get_read_names()
def get_read_writes(self) -> dependencies.ReadWrites:
if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
return dependencies.ReadWrites(
reads=OrderedSet(),
writes=OrderedSet(),
index_exprs=OrderedSet(),
)
with patch.object(FlexibleLayout, "allow_indexing", True):
if self.data.get_reduction_type():
return extract_read_writes(
@ -4367,7 +4367,6 @@ class ComputedBuffer(OperationBuffer):
| get_free_symbols(self.get_stride(), unbacked_only)
| get_free_symbols(self.get_offset(), unbacked_only)
| self.data.get_free_symbol_uses(unbacked_only)
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
)
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
@ -6976,50 +6975,6 @@ class DeviceCopy(ExternKernelOut):
wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
class DynamicSelectStorageOffset(ExternKernel):
"""
The result of computing a dynamic selection index is determined as follows: when the index in the
select operation is unbacked, the actual index calculation is ambiguous for negative indices
(index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
computing the storage offset at runtime with this node.
"""
def get_reads(self) -> OrderedSet[Dep]:
return OrderedSet()
def should_allocate(self) -> bool:
return False
def __init__(
self,
unbacked_offset_symbol: sympy.Symbol,
index: sympy.Symbol,
base_offset: Union[sympy.Symbol, int],
base_dim_stride: Union[sympy.Symbol, int],
size: Union[sympy.Symbol, int],
) -> None:
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
# This node codegen the following:
# unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
self.unbacked_offset_symbol = unbacked_offset_symbol
self.index = index
self.base_offset = base_offset
self.base_dim_stride = base_dim_stride
self.size = size
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_offset_symbol])
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return get_free_symbols(self.index, unbacked_only)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
wrapper.codegen_dynamic_select_index(self)
class DynamicScalar(ExternKernel):
"""
The result of a call to aten._local_scalar_dense.

View File

@ -40,11 +40,7 @@ from torch._prims_common import (
Number,
)
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
free_unbacked_symbols,
has_free_unbacked_symbols,
resolve_unbacked_bindings,
)
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
@ -994,7 +990,10 @@ def squeeze(x, dim=None):
new_shape = []
for d, s in enumerate(x.get_size()):
if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
if not (
d in dims
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
):
new_shape.append(s)
# squeeze does nothing if the size isn't 1
@ -1760,60 +1759,8 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
@register_lowering(aten.select, type_promotion_kind=None)
def select(x, dim, idx):
idx = sympy.expand(idx)
size = sympy.expand(x.get_size()[dim])
actual_index = None
if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
actual_index = idx + size
elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
actual_index = idx
if actual_index is not None:
if has_free_unbacked_symbols(idx):
# Inductor could generate incorrect views for tensors with unbacked symbols here;
# Squeeze operations are translated to views, resulting in incorrect strides.
# Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
# we use as_strided instead.
# Removing this branch will cause test_unbacked_select_index_with_check to fail.
new_size = x.get_size()
new_stride = x.get_stride()
new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
del new_size[dim]
del new_stride[dim]
return as_strided(x, new_size, new_stride, new_storage_offset)
else:
slice_result = slice_(x, dim, actual_index, actual_index + 1)
return squeeze(slice_result, dim)
# Unbacked Semantics:
# When the index idx is unbacked (e.g., u0), we compute the index dynamically
# during the lowering of the select operation using DynamicSelectStorageOffset.
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
)
assert unbacked_bindings is not None
assert len(unbacked_bindings) == 1, unbacked_bindings
unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
new_size = x.get_size()
new_stride = x.get_stride()
new_storage_offset = unbacked_offset_sym
buffer = ir.DynamicSelectStorageOffset(
unbacked_offset_sym,
idx,
x.get_layout().offset,
new_stride[dim],
x.get_size()[dim],
)
buffer.name = V.graph.register_buffer(buffer)
V.graph.register_operation(buffer)
del new_size[dim]
del new_stride[dim]
return as_strided(x, new_size, new_stride, new_storage_offset)
idx = View.handle_negative_index(idx, x.get_size()[dim])
return squeeze(slice_(x, dim, idx, idx + 1), dim)
@register_lowering(aten.split, type_promotion_kind=None)
@ -3139,6 +3086,8 @@ def long_tensor(data):
@register_lowering(aten._local_scalar_dense)
def _local_scalar_dense(data):
from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings
# This is interesting! Most lowerings return tensors, so you can just
# return the buffer you allocated and it will get used (or not used, if
# it's dead.) But _local_scalar_dense (aka item) returns an int,

View File

@ -2130,11 +2130,9 @@ class Scheduler:
self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
if config._post_fusion_custom_pass is not None:
self.nodes = config._post_fusion_custom_pass(self.nodes)
self.merge_loops()
self.finalize_multi_template_buffers()
if config.combo_kernels:
@ -2368,6 +2366,7 @@ class Scheduler:
for node in self.nodes:
log.debug("scheduling %s", node.node)
# unbacked symbols don't follow ordinary buffer dependencies, so
# we track their def/uses separately
assert node.node is not None

View File

@ -69,20 +69,13 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
"inductor_autotune_lookup_table",
]
from torch.fx.experimental.symbolic_shapes import (
free_symbols,
free_unbacked_symbols,
IterateExprs,
ShapeEnv,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence, ValuesView
from torch import SymBool, SymFloat, SymInt
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.node import Node
from .codegen.common import WorkspaceArg
@ -3366,10 +3359,3 @@ def aoti_model_name_from_config() -> str:
model_name = config.aot_inductor.model_name_for_generated_files
model_name = "aoti_model" if model_name is None else model_name
return model_name
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
if unbacked_only:
return free_unbacked_symbols(x)
else:
return free_symbols(x)

View File

@ -5553,6 +5553,39 @@ def meta_zeros(
)
@register_meta(aten.select.int)
def meta_select(self, dim, index):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
ndim = self.dim()
torch._check_index(
ndim != 0,
lambda: "select() cannot be applied to a 0-dim tensor.",
)
dim = dim if dim >= 0 else dim + ndim
size = self.size(dim)
torch._check_index(
not (
guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
),
lambda: f"select(): index {index} out of range for tensor of size "
f"{self.size()} at dimension {dim}",
)
index = index if index >= 0 else index + size
new_size = list(self.size())
new_stride = list(self.stride())
new_storage_offset = self.storage_offset() + index * new_stride[dim]
del new_size[dim]
del new_stride[dim]
return self.as_strided(new_size, new_stride, new_storage_offset)
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
return utils.clone_preserve_strides(self)

View File

@ -359,48 +359,6 @@ def unique2(
return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
@register_op_impl(aten.select.int)
def meta_select(fake_mode, func, self, dim, index):
from torch.fx.experimental.symbolic_shapes import guard_or_false
if self.is_sparse:
return NotImplemented
ndim = self.dim()
torch._check_index(
ndim != 0,
lambda: "select() cannot be applied to a 0-dim tensor.",
)
dim = dim if dim >= 0 else dim + ndim
size = self.size(dim)
new_size = list(self.size())
new_stride = list(self.stride())
new_storage_offset = None
if guard_or_false(index >= 0):
new_storage_offset = self.storage_offset() + index * new_stride[dim]
elif guard_or_false(index < 0):
new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
if new_storage_offset is None:
if fake_mode.shape_env is None or (
not fake_mode.shape_env.allow_scalar_outputs
and not fake_mode.allow_scalar_outputs
):
raise DataDependentOutputException(func)
# index is data-dependent, we do not know which index we are accessing it could be index or index+size!
# we assign a new data-dependent symbol for the storage offset.
new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
del new_size[dim]
del new_stride[dim]
assert new_storage_offset is not None
return self.as_strided(new_size, new_stride, new_storage_offset)
@register_op_impl(aten.unique_dim.default)
def unique_dim(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False

View File

@ -1282,7 +1282,6 @@ def compute_unbacked_bindings(
return None
fs = shape_env.pending_fresh_unbacked_symbols
pending = set(fs)
if not pending:
return None
@ -4810,7 +4809,6 @@ class ShapeEnv:
)
self.counter["create_unbacked_symbol"] += 1
if not self._ignore_fresh_unbacked_symbols_tls():
print(f"adding {symbol}")
self.pending_fresh_unbacked_symbols.append(symbol)
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
vr = self.var_to_range[symbol] = ValueRanges.unknown()

View File

@ -461,7 +461,6 @@ def insert_deferred_runtime_asserts(
),
keypath[2:],
)
return go(
graph.call_method(
keypath[0].name, (node, keypath[1].idx)
@ -469,15 +468,6 @@ def insert_deferred_runtime_asserts(
keypath[2:],
)
elif isinstance(keypath[0], CallMethodKey):
if keypath[0].name == "storage_offset":
return go(
graph.call_function(
torch.ops.aten.sym_storage_offset.default,
(node,),
),
keypath[1:],
)
return go(
graph.call_method(keypath[0].name, (node,)), keypath[1:]
)