From 23550ab735eee1b9cc90609788dc64ccfb242af2 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Thu, 17 Jul 2025 16:20:02 +0000
Subject: [PATCH] Revert "DDE-Free select with unbacked index. (#157605)"

This reverts commit 79d7c754ab8ae0e5c3a614521632d2cfbfa0fdba.

Reverted https://github.com/pytorch/pytorch/pull/157605 on behalf of https://github.com/laithsakka due to fail pr time benchmarks ([comment](https://github.com/pytorch/pytorch/pull/157605#issuecomment-3084663020))
---
 test/export/test_export.py                  | 22 ------
 test/test_dynamic_shapes.py                 | 82 ---------------------
 torch/_export/passes/_node_metadata_hook.py |  1 -
 torch/_inductor/codegen/cpp_wrapper_cpu.py  | 14 ----
 torch/_inductor/codegen/wrapper.py          |  8 --
 torch/_inductor/dependencies.py             | 35 ---------
 torch/_inductor/graph.py                    |  4 +-
 torch/_inductor/ir.py                       | 61 ++-------------
 torch/_inductor/lowering.py                 | 69 +++--------------
 torch/_inductor/scheduler.py                |  3 +-
 torch/_inductor/utils.py                    | 16 +---
 torch/_meta_registrations.py                | 33 +++++++++
 torch/_subclasses/fake_impls.py             | 42 -----------
 torch/fx/experimental/symbolic_shapes.py    |  2 -
 torch/fx/passes/runtime_assert.py           | 10 ---
 15 files changed, 53 insertions(+), 349 deletions(-)

diff --git a/test/export/test_export.py b/test/export/test_export.py
index d1cecb55329..dea00055696 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -15782,28 +15782,6 @@ def forward(self, x, mask):
             ignore_empty_lines=True,
         )
 
-    def test_unbacked_select_index(self):
-        class MyModel(torch.nn.Module):
-            def forward(self, x, y):
-                u0 = y.item()
-                return x.select(0, u0)
-
-        example_inputs = (
-            torch.randn((3, 3), dtype=torch.bfloat16),
-            torch.tensor([0]),
-        )
-
-        traced = export(MyModel(), example_inputs).run_decompositions({})
-        self.assertExpectedInline(
-            traced.graph_module.code,
-            """\
-def forward(self, x, y):
-    item = torch.ops.aten.item.default(y);  y = None
-    select = torch.ops.aten.select.int(x, 0, item);  x = item = None
-    return (select,)""",
-            ignore_empty_lines=True,
-        )
-
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 59c08f71671..0f299cd6b6c 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -3529,88 +3529,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
             ignore_empty_lines=True,
         )
 
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_unbacked_select_index(self):
-        cnt = CompileCounterWithBackend("inductor")
-
-        def func(x, y):
-            u0 = y.item()
-            return (
-                torch.select(x, 0, u0),
-                torch.select(x, 1, u0),
-                torch.select(x, 2, u0),
-            )
-
-        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
-        x = torch.rand(3, 3, 3)
-        zero = torch.tensor([0])
-        pos = torch.tensor([1])
-        # code can handle both negative and positive indices.
- neg = torch.tensor([-1]) - - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - self.assertEqual(compiled_func(x, zero), func(x, zero)) - output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() - self.assertExpectedInline( - output, - """\ - _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense) - select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense) - select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None - return (select, select_1, select_2)""", # noqa: B950 - ignore_comments=True, - ignore_empty_lines=True, - ) - self.assertEqual(compiled_func(x, pos), func(x, pos)) - self.assertEqual(compiled_func(x, neg), func(x, neg)) - self.assertEqual(cnt.frame_count, 1) - - def func2(x, y): - u0, u1 = y.tolist() - return torch.select(x, 0, u0 + u1) - - compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)( - func2 - ) - zero = torch.tensor([0, 0]) - pos = torch.tensor([1, 1]) - neg = torch.tensor([-1, -1]) - - self.assertEqual(compiled_func2(x, pos), func2(x, pos)) - self.assertEqual(compiled_func2(x, neg), func2(x, neg)) - self.assertEqual(compiled_func2(x, zero), func2(x, zero)) - self.assertEqual(cnt.frame_count, 2) - - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_unbacked_select_index_with_check(self): - def func3(x, y): - u0 = y.item() - # Test that taking the non-unbacked path works fine also. - torch._check(u0 >= 0) - return (torch.select(x, 1, u0),) - - compiled_func3 = torch.compile( - fullgraph=True, backend="inductor", dynamic=True - )(func3) - x = torch.rand(3, 3, 3) - zero = torch.tensor([0]) - pos = torch.tensor([1]) - print(compiled_func3(x, pos)) - - self.assertEqual(compiled_func3(x, pos), func3(x, pos)) - self.assertEqual(compiled_func3(x, zero), func3(x, zero)) - - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - @torch._inductor.config.patch("cpp_wrapper", True) - def test_unbacked_select_index_cpp_wrapper(self): - self.test_unbacked_select_index() - instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index b1195cf4212..41005e50097 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -54,7 +54,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) ) }, ) - node.meta["torch_fn"] = ( f"{node.target.__name__}_0", f"{node.target.__class__.__name__}.{node.target.__name__}", diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 8a7f1b2aaa0..cbca6d9fe5d 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1447,20 +1447,6 @@ class CppWrapperCpu(PythonWrapperCodegen): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.sym)) - def codegen_dynamic_select_index(self, node): - index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int) - - index_compute_str = ( - f"{index_cpp_str} < 0 ? 
{index_cpp_str} + " - f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}" - ) - self.writeline( - f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + " - f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});" - ) - # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again - self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) - def make_buffer_free(self, buffer): return ( "" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e601cbb8ed8..0b8ba86c3c1 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1802,14 +1802,6 @@ class PythonWrapperCodegen(CodeGen): arg_name = node.input_name(0) self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) - def codegen_dynamic_select_index(self, node): - index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}" - self.writeline( - f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})" - ) - # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again - self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) - def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if len(node.keypath) == 0: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index f948a7a534c..9de52061c64 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -11,7 +11,6 @@ from unittest.mock import patch import sympy import torch -from torch._inductor.utils import get_free_symbols from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -39,12 +38,6 @@ class Dep(abc.ABC): name: str index: sympy.Expr - @abc.abstractmethod - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - pass - @abc.abstractmethod def rename(self, renames: dict[str, str]) -> Self: pass @@ -77,15 +70,6 @@ class MemoryDep(Dep): size: tuple[sympy.Expr, ...] 
     mode: Optional[str] = None
 
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return (
-            get_free_symbols(self.index, unbacked_only)
-            | get_free_symbols(self.size, unbacked_only)
-            | get_free_symbols(self.var_names, unbacked_only)
-        )
-
     def __repr__(self) -> str:
         maybe_mode = ""
         if self.mode is not None:
@@ -323,11 +307,6 @@ class StarDep(Dep):
             return StarDep(renames[self.name], self.mode)
         return self
 
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet()
-
     def numbytes_hint(self) -> int:
         try:
             return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
@@ -364,11 +343,6 @@ class WeakDep(Dep):
     # Buffer that is doing the mutation
     mutating_buf: str
 
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet()
-
     @property
     def index(self) -> sympy.Expr:
         raise NotImplementedError("WeakDep does not have an index")
@@ -466,15 +440,6 @@ class ReadWrites:
                 names.add(dep.name)
         return names
 
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        result: OrderedSet[sympy.Symbol] = OrderedSet()
-
-        for dep in self.reads_and_writes():
-            result |= dep.get_free_symbol_uses(unbacked_only)
-        return result
-
 
 class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
     def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 660b01b6923..ac299d5b0c2 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -341,7 +341,6 @@ class GraphLowering(torch.fx.Interpreter):
                 shape_env.deferred_runtime_asserts.copy()
             )
             self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
-
         self.sizevars = SizeVarAllocator(shape_env)
         self.graph_input_names: list[str] = []
         self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
@@ -1822,7 +1821,7 @@ class GraphLowering(torch.fx.Interpreter):
 
         shape_env = V.graph.sizevars.shape_env
 
-        # An input can be unbacked symint i.e.: when mark_unabcked is used.
+        # An input can be an unbacked symint, i.e. when mark_unbacked is used;
         # in that case add it to new_unbacked_defs.
         if (
             n.op == "placeholder"
             and (symbol := get_placeholder_expr(n)) in shape_env.var_to_range
         ):
@@ -1889,7 +1888,6 @@ class GraphLowering(torch.fx.Interpreter):
                 V.fake_mode.shape_env.unbacked_renamings.get(s, s)
                 for s in unbacked_bindings.keys()
             )
-
             assert new_unbacked_defs >= renamed_unbacked_bindings, (
                 f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
                 f"fx node is: {n.format_node()}\n"
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 25f57a503df..d6dd82aa52f 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -49,7 +49,6 @@
 from torch._dynamo.utils import identity
 from torch._export.serde.serialize import GraphModuleSerializer
 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
 from torch._inductor import metrics
-from torch._inductor.utils import get_free_symbols
 from torch._prims_common import (
     compute_required_storage_length,
     is_boolean_dtype,
@@ -63,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     free_symbols,
     free_unbacked_symbols,
+    IterateExprs,
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
@@ -304,6 +304,13 @@ def fuse_reindexing(
     return reindex
 
 
+def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
+    if unbacked_only:
+        return free_unbacked_symbols(x)
+    else:
+        return free_symbols(x)
+
+
 NHWC_STRIDE_ORDER = [3, 0, 2, 1]
 NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
 
@@ -4322,13 +4329,6 @@ class ComputedBuffer(OperationBuffer):
         return self.data.get_read_names()
 
     def get_read_writes(self) -> dependencies.ReadWrites:
-        if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
-            return dependencies.ReadWrites(
-                reads=OrderedSet(),
-                writes=OrderedSet(),
-                index_exprs=OrderedSet(),
-            )
-
         with patch.object(FlexibleLayout, "allow_indexing", True):
             if self.data.get_reduction_type():
                 return extract_read_writes(
@@ -4367,7 +4367,6 @@ class ComputedBuffer(OperationBuffer):
             | get_free_symbols(self.get_stride(), unbacked_only)
             | get_free_symbols(self.get_offset(), unbacked_only)
             | self.data.get_free_symbol_uses(unbacked_only)
-            | self.get_read_writes().get_free_symbol_uses(unbacked_only)
         )
 
     def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
@@ -6976,50 +6975,6 @@ class DeviceCopy(ExternKernelOut):
         wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
 
 
-class DynamicSelectStorageOffset(ExternKernel):
-    """
-    When the index of a select operation is unbacked, the effective index is ambiguous at
-    compile time: a negative index resolves to index + size, while a non-negative index is
-    used as is. To resolve the ambiguity, we allocate an unbacked SymInt to represent the
-    storage offset, decompose the select into a call to as_strided, and compute the storage
-    offset at runtime with this node.
-    """
-
-    def get_reads(self) -> OrderedSet[Dep]:
-        return OrderedSet()
-
-    def should_allocate(self) -> bool:
-        return False
-
-    def __init__(
-        self,
-        unbacked_offset_symbol: sympy.Symbol,
-        index: sympy.Symbol,
-        base_offset: Union[sympy.Symbol, int],
-        base_dim_stride: Union[sympy.Symbol, int],
-        size: Union[sympy.Symbol, int],
-    ) -> None:
-        super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
-        # This node codegens the following:
-        # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >= 0 else index + size)
-        self.unbacked_offset_symbol = unbacked_offset_symbol
-        self.index = index
-        self.base_offset = base_offset
-        self.base_dim_stride = base_dim_stride
-        self.size = size
-
-    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet([self.unbacked_offset_symbol])
-
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return get_free_symbols(self.index, unbacked_only)
-
-    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-        wrapper.codegen_dynamic_select_index(self)
-
-
 class DynamicScalar(ExternKernel):
     """
     The result of a call to aten._local_scalar_dense.
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index f6b08499e4d..c4c8f70003c 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -40,11 +40,7 @@ from torch._prims_common import (
     Number,
 )
 from torch.fx.experimental.sym_node import magic_methods, method_to_operator
-from torch.fx.experimental.symbolic_shapes import (
-    free_unbacked_symbols,
-    has_free_unbacked_symbols,
-    resolve_unbacked_bindings,
-)
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
 
@@ -994,7 +990,10 @@ def squeeze(x, dim=None):
 
     new_shape = []
     for d, s in enumerate(x.get_size()):
-        if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
+        if not (
+            d in dims
+            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
+        ):
             new_shape.append(s)
 
     # squeeze does nothing if the size isn't 1
@@ -1760,60 +1759,8 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
 
 @register_lowering(aten.select, type_promotion_kind=None)
 def select(x, dim, idx):
-    idx = sympy.expand(idx)
-    size = sympy.expand(x.get_size()[dim])
-    actual_index = None
-
-    if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
-        actual_index = idx + size
-    elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
-        actual_index = idx
-
-    if actual_index is not None:
-        if has_free_unbacked_symbols(idx):
-            # Inductor could generate incorrect views for tensors with unbacked symbols here;
-            # squeeze operations are translated to views, resulting in incorrect strides.
-            # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve
-            # this, we use as_strided instead.
-            # Removing this branch will cause test_unbacked_select_index_with_check to fail.
- new_size = x.get_size() - new_stride = x.get_stride() - new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index - - del new_size[dim] - del new_stride[dim] - return as_strided(x, new_size, new_stride, new_storage_offset) - else: - slice_result = slice_(x, dim, actual_index, actual_index + 1) - return squeeze(slice_result, dim) - - # Unbacked Semantics: - # When the index idx is unbacked (e.g., u0), we compute the index dynamically - # during the lowering of the select operation using DynamicSelectStorageOffset. - - unbacked_bindings = resolve_unbacked_bindings( - V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] - ) - assert unbacked_bindings is not None - assert len(unbacked_bindings) == 1, unbacked_bindings - unbacked_offset_sym, _ = next(iter(unbacked_bindings.items())) - - new_size = x.get_size() - new_stride = x.get_stride() - new_storage_offset = unbacked_offset_sym - buffer = ir.DynamicSelectStorageOffset( - unbacked_offset_sym, - idx, - x.get_layout().offset, - new_stride[dim], - x.get_size()[dim], - ) - buffer.name = V.graph.register_buffer(buffer) - V.graph.register_operation(buffer) - - del new_size[dim] - del new_stride[dim] - return as_strided(x, new_size, new_stride, new_storage_offset) + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) @register_lowering(aten.split, type_promotion_kind=None) @@ -3139,6 +3086,8 @@ def long_tensor(data): @register_lowering(aten._local_scalar_dense) def _local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + # This is interesting! Most lowerings return tensors, so you can just # return the buffer you allocated and it will get used (or not used, if # it's dead.) 
But _local_scalar_dense (aka item) returns an int, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a4507990400..34f15869085 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2130,11 +2130,9 @@ class Scheduler: self.logged_slow_fusion = OrderedSet[tuple[str, str]]() if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) - self.nodes = self.fuse_nodes(self.nodes) if config._post_fusion_custom_pass is not None: self.nodes = config._post_fusion_custom_pass(self.nodes) - self.merge_loops() self.finalize_multi_template_buffers() if config.combo_kernels: @@ -2368,6 +2366,7 @@ class Scheduler: for node in self.nodes: log.debug("scheduling %s", node.node) + # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately assert node.node is not None diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 7b3f495382f..5f9ce0b814e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -69,20 +69,13 @@ OPTIMUS_EXCLUDE_POST_GRAD = [ "inductor_autotune_lookup_table", ] -from torch.fx.experimental.symbolic_shapes import ( - free_symbols, - free_unbacked_symbols, - IterateExprs, - ShapeEnv, -) - - if TYPE_CHECKING: from collections.abc import Iterable, Sequence, ValuesView from torch import SymBool, SymFloat, SymInt from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.fx import GraphModule + from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.node import Node from .codegen.common import WorkspaceArg @@ -3366,10 +3359,3 @@ def aoti_model_name_from_config() -> str: model_name = config.aot_inductor.model_name_for_generated_files model_name = "aoti_model" if model_name is None else model_name return model_name - - -def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: - if unbacked_only: - return free_unbacked_symbols(x) - else: - return free_symbols(x) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 2933a37c37f..ae87e0e17fb 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5553,6 +5553,39 @@ def meta_zeros( ) +@register_meta(aten.select.int) +def meta_select(self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + torch._check_index( + not ( + guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) + ), + lambda: f"select(): index {index} out of range for tensor of size " + f"{self.size()} at dimension {dim}", + ) + + index = index if index >= 0 else index + size + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = self.storage_offset() + index * new_stride[dim] + del new_size[dim] + del new_stride[dim] + + return self.as_strided(new_size, new_stride, new_storage_offset) + + @register_meta(aten.select_scatter.default) def meta_select_scatter(self, src, dim, index): return utils.clone_preserve_strides(self) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index e2e24cb59bc..e802d9a4389 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -359,48 +359,6 @@ def unique2( return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) 
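# Editor's aside (illustrative, not part of the patch): the meta_select
# re-added to torch/_meta_registrations.py above reduces aten.select.int to a
# single as_strided call. Below is a minimal eager-mode sketch of that
# reduction; `select_via_as_strided` is a hypothetical helper that exists only
# for this example, not a PyTorch API.
import torch

def select_via_as_strided(t: torch.Tensor, dim: int, index: int) -> torch.Tensor:
    dim = dim if dim >= 0 else dim + t.dim()
    # Normalize a negative index, exactly as the meta function does.
    index = index if index >= 0 else index + t.size(dim)
    new_size = list(t.size())
    new_stride = list(t.stride())
    # select never copies: only the storage offset moves.
    offset = t.storage_offset() + index * new_stride[dim]
    del new_size[dim]
    del new_stride[dim]
    return t.as_strided(new_size, new_stride, offset)

x = torch.arange(27.0).reshape(3, 3, 3)
assert torch.equal(select_via_as_strided(x, 1, -1), x.select(1, -1))
# End of editor's aside.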
-@register_op_impl(aten.select.int)
-def meta_select(fake_mode, func, self, dim, index):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
-
-    if self.is_sparse:
-        return NotImplemented
-
-    ndim = self.dim()
-    torch._check_index(
-        ndim != 0,
-        lambda: "select() cannot be applied to a 0-dim tensor.",
-    )
-
-    dim = dim if dim >= 0 else dim + ndim
-    size = self.size(dim)
-
-    new_size = list(self.size())
-    new_stride = list(self.stride())
-
-    new_storage_offset = None
-    if guard_or_false(index >= 0):
-        new_storage_offset = self.storage_offset() + index * new_stride[dim]
-    elif guard_or_false(index < 0):
-        new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
-
-    if new_storage_offset is None:
-        if fake_mode.shape_env is None or (
-            not fake_mode.shape_env.allow_scalar_outputs
-            and not fake_mode.allow_scalar_outputs
-        ):
-            raise DataDependentOutputException(func)
-
-        # The index is data-dependent: we do not know which element we are accessing;
-        # it could be index or index + size! We assign a new data-dependent symbol
-        # for the storage offset.
-        new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
-
-    del new_size[dim]
-    del new_stride[dim]
-    assert new_storage_offset is not None
-    return self.as_strided(new_size, new_stride, new_storage_offset)
-
-
 @register_op_impl(aten.unique_dim.default)
 def unique_dim(
     fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 4814e2daefe..e38e5f777d6 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1282,7 +1282,6 @@ def compute_unbacked_bindings(
             return None
 
     fs = shape_env.pending_fresh_unbacked_symbols
-
     pending = set(fs)
     if not pending:
         return None
@@ -4810,7 +4809,6 @@ class ShapeEnv:
         )
         self.counter["create_unbacked_symbol"] += 1
         if not self._ignore_fresh_unbacked_symbols_tls():
-            print(f"adding {symbol}")
             self.pending_fresh_unbacked_symbols.append(symbol)
         self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
         vr = self.var_to_range[symbol] = ValueRanges.unknown()
diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py
index bb71a25971d..38c64c527af 100644
--- a/torch/fx/passes/runtime_assert.py
+++ b/torch/fx/passes/runtime_assert.py
@@ -461,7 +461,6 @@ def insert_deferred_runtime_asserts(
                     ),
                     keypath[2:],
                 )
-
             return go(
                 graph.call_method(
                     keypath[0].name, (node, keypath[1].idx)
                 ),
                 keypath[2:],
             )
         elif isinstance(keypath[0], CallMethodKey):
-            if keypath[0].name == "storage_offset":
-                return go(
-                    graph.call_function(
-                        torch.ops.aten.sym_storage_offset.default,
-                        (node,),
-                    ),
-                    keypath[1:],
-                )
-
             return go(
                 graph.call_method(keypath[0].name, (node,)), keypath[1:]
             )
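
Editor's note (illustrative, not part of the patch): for context on what is
being reverted, the DynamicSelectStorageOffset node removed above deferred the
select offset computation to runtime; per codegen_dynamic_select_index, the
generated wrapper line was essentially
`offset = base_offset + base_dim_stride * (index + size if index < 0 else index)`.
Below is a minimal sketch of that arithmetic; `dynamic_select_offset` is a
hypothetical name used only for this example, not an Inductor API.

import torch

def dynamic_select_offset(base_offset: int, dim_stride: int, dim_size: int, index: int) -> int:
    # Same branch the reverted wrapper codegen emitted: a negative index wraps
    # around by dim_size before it scales the stride.
    return base_offset + dim_stride * (index + dim_size if index < 0 else index)

x = torch.arange(9.0).reshape(3, 3)
for idx in (-1, 0, 2):
    off = dynamic_select_offset(x.storage_offset(), x.stride(0), x.size(0), idx)
    # select(0, idx) is as_strided with the selected dimension dropped.
    view = x.as_strided((x.size(1),), (x.stride(1),), off)
    assert torch.equal(view, x.select(0, idx))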