mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Revert "DDE-Free select with unbacked index. (#157605)"
This reverts commit 79d7c754ab.
Reverted https://github.com/pytorch/pytorch/pull/157605 on behalf of https://github.com/laithsakka due to fail pr time benchmarks ([comment](https://github.com/pytorch/pytorch/pull/157605#issuecomment-3084663020))
This commit is contained in:
parent
16b21fa8b2
commit
23550ab735
|
|
@ -15782,28 +15782,6 @@ def forward(self, x, mask):
|
|||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
def test_unbacked_select_index(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
u0 = y.item()
|
||||
return x.select(0, u0)
|
||||
|
||||
example_inputs = (
|
||||
torch.randn((3, 3), dtype=torch.bfloat16),
|
||||
torch.tensor([0]),
|
||||
)
|
||||
|
||||
traced = export(MyModel(), example_inputs).run_decompositions({})
|
||||
self.assertExpectedInline(
|
||||
traced.graph_module.code,
|
||||
"""\
|
||||
def forward(self, x, y):
|
||||
item = torch.ops.aten.item.default(y); y = None
|
||||
select = torch.ops.aten.select.int(x, 0, item); x = item = None
|
||||
return (select,)""",
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -3529,88 +3529,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_select_index(self):
|
||||
cnt = CompileCounterWithBackend("inductor")
|
||||
|
||||
def func(x, y):
|
||||
u0 = y.item()
|
||||
return (
|
||||
torch.select(x, 0, u0),
|
||||
torch.select(x, 1, u0),
|
||||
torch.select(x, 2, u0),
|
||||
)
|
||||
|
||||
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
|
||||
x = torch.rand(3, 3, 3)
|
||||
zero = torch.tensor([0])
|
||||
pos = torch.tensor([1])
|
||||
# code can handle both negative and positive indices.
|
||||
neg = torch.tensor([-1])
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
with ctx():
|
||||
self.assertEqual(compiled_func(x, zero), func(x, zero))
|
||||
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
||||
self.assertExpectedInline(
|
||||
output,
|
||||
"""\
|
||||
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
|
||||
select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
|
||||
select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense)
|
||||
select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None
|
||||
return (select, select_1, select_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
self.assertEqual(compiled_func(x, pos), func(x, pos))
|
||||
self.assertEqual(compiled_func(x, neg), func(x, neg))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
def func2(x, y):
|
||||
u0, u1 = y.tolist()
|
||||
return torch.select(x, 0, u0 + u1)
|
||||
|
||||
compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)(
|
||||
func2
|
||||
)
|
||||
zero = torch.tensor([0, 0])
|
||||
pos = torch.tensor([1, 1])
|
||||
neg = torch.tensor([-1, -1])
|
||||
|
||||
self.assertEqual(compiled_func2(x, pos), func2(x, pos))
|
||||
self.assertEqual(compiled_func2(x, neg), func2(x, neg))
|
||||
self.assertEqual(compiled_func2(x, zero), func2(x, zero))
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_select_index_with_check(self):
|
||||
def func3(x, y):
|
||||
u0 = y.item()
|
||||
# Test that taking the non-unbacked path works fine also.
|
||||
torch._check(u0 >= 0)
|
||||
return (torch.select(x, 1, u0),)
|
||||
|
||||
compiled_func3 = torch.compile(
|
||||
fullgraph=True, backend="inductor", dynamic=True
|
||||
)(func3)
|
||||
x = torch.rand(3, 3, 3)
|
||||
zero = torch.tensor([0])
|
||||
pos = torch.tensor([1])
|
||||
print(compiled_func3(x, pos))
|
||||
|
||||
self.assertEqual(compiled_func3(x, pos), func3(x, pos))
|
||||
self.assertEqual(compiled_func3(x, zero), func3(x, zero))
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_unbacked_select_index_cpp_wrapper(self):
|
||||
self.test_unbacked_select_index()
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None)
|
|||
)
|
||||
},
|
||||
)
|
||||
|
||||
node.meta["torch_fn"] = (
|
||||
f"{node.target.__name__}_0",
|
||||
f"{node.target.__class__.__name__}.{node.target.__name__}",
|
||||
|
|
|
|||
|
|
@ -1447,20 +1447,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.sym))
|
||||
|
||||
def codegen_dynamic_select_index(self, node):
|
||||
index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
|
||||
|
||||
index_compute_str = (
|
||||
f"{index_cpp_str} < 0 ? {index_cpp_str} + "
|
||||
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
|
||||
)
|
||||
self.writeline(
|
||||
f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
|
||||
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
|
||||
)
|
||||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
|
||||
|
||||
def make_buffer_free(self, buffer):
|
||||
return (
|
||||
""
|
||||
|
|
|
|||
|
|
@ -1802,14 +1802,6 @@ class PythonWrapperCodegen(CodeGen):
|
|||
arg_name = node.input_name(0)
|
||||
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
|
||||
|
||||
def codegen_dynamic_select_index(self, node):
|
||||
index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
|
||||
self.writeline(
|
||||
f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
|
||||
)
|
||||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
|
||||
|
||||
def codegen_dynamic_scalar(self, node):
|
||||
(data,) = (t.codegen_reference() for t in node.inputs)
|
||||
if len(node.keypath) == 0:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from unittest.mock import patch
|
|||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.utils import get_free_symbols
|
||||
from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
|
@ -39,12 +38,6 @@ class Dep(abc.ABC):
|
|||
name: str
|
||||
index: sympy.Expr
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def rename(self, renames: dict[str, str]) -> Self:
|
||||
pass
|
||||
|
|
@ -77,15 +70,6 @@ class MemoryDep(Dep):
|
|||
size: tuple[sympy.Expr, ...]
|
||||
mode: Optional[str] = None
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return (
|
||||
get_free_symbols(self.index, unbacked_only)
|
||||
| get_free_symbols(self.size, unbacked_only)
|
||||
| get_free_symbols(self.var_names, unbacked_only)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
maybe_mode = ""
|
||||
if self.mode is not None:
|
||||
|
|
@ -323,11 +307,6 @@ class StarDep(Dep):
|
|||
return StarDep(renames[self.name], self.mode)
|
||||
return self
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def numbytes_hint(self) -> int:
|
||||
try:
|
||||
return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
|
||||
|
|
@ -364,11 +343,6 @@ class WeakDep(Dep):
|
|||
# Buffer that is doing the mutation
|
||||
mutating_buf: str
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
@property
|
||||
def index(self) -> sympy.Expr:
|
||||
raise NotImplementedError("WeakDep does not have an index")
|
||||
|
|
@ -466,15 +440,6 @@ class ReadWrites:
|
|||
names.add(dep.name)
|
||||
return names
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
result: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
|
||||
for dep in self.reads_and_writes():
|
||||
result |= dep.get_free_symbol_uses(unbacked_only)
|
||||
return result
|
||||
|
||||
|
||||
class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
|
||||
def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
|
||||
|
|
|
|||
|
|
@ -341,7 +341,6 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
shape_env.deferred_runtime_asserts.copy()
|
||||
)
|
||||
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
|
||||
|
||||
self.sizevars = SizeVarAllocator(shape_env)
|
||||
self.graph_input_names: list[str] = []
|
||||
self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
|
||||
|
|
@ -1822,7 +1821,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
|
||||
shape_env = V.graph.sizevars.shape_env
|
||||
|
||||
# An input can be unbacked symint i.e.: when mark_unabcked is used.
|
||||
# An input can an unbacked symint i.e.: when mark_unabcked is used.
|
||||
# in that case add it to new_unbacked_defs.
|
||||
if (
|
||||
n.op == "placeholder"
|
||||
|
|
@ -1889,7 +1888,6 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
|
||||
for s in unbacked_bindings.keys()
|
||||
)
|
||||
|
||||
assert new_unbacked_defs >= renamed_unbacked_bindings, (
|
||||
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
|
||||
f"fx node is: {n.format_node()}\n"
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ from torch._dynamo.utils import identity
|
|||
from torch._export.serde.serialize import GraphModuleSerializer
|
||||
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
|
||||
from torch._inductor import metrics
|
||||
from torch._inductor.utils import get_free_symbols
|
||||
from torch._prims_common import (
|
||||
compute_required_storage_length,
|
||||
is_boolean_dtype,
|
||||
|
|
@ -63,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
compute_unbacked_bindings,
|
||||
free_symbols,
|
||||
free_unbacked_symbols,
|
||||
IterateExprs,
|
||||
rebind_unbacked,
|
||||
resolve_unbacked_bindings,
|
||||
ShapeEnv,
|
||||
|
|
@ -304,6 +304,13 @@ def fuse_reindexing(
|
|||
return reindex
|
||||
|
||||
|
||||
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_only:
|
||||
return free_unbacked_symbols(x)
|
||||
else:
|
||||
return free_symbols(x)
|
||||
|
||||
|
||||
NHWC_STRIDE_ORDER = [3, 0, 2, 1]
|
||||
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
|
||||
|
||||
|
|
@ -4322,13 +4329,6 @@ class ComputedBuffer(OperationBuffer):
|
|||
return self.data.get_read_names()
|
||||
|
||||
def get_read_writes(self) -> dependencies.ReadWrites:
|
||||
if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
|
||||
return dependencies.ReadWrites(
|
||||
reads=OrderedSet(),
|
||||
writes=OrderedSet(),
|
||||
index_exprs=OrderedSet(),
|
||||
)
|
||||
|
||||
with patch.object(FlexibleLayout, "allow_indexing", True):
|
||||
if self.data.get_reduction_type():
|
||||
return extract_read_writes(
|
||||
|
|
@ -4367,7 +4367,6 @@ class ComputedBuffer(OperationBuffer):
|
|||
| get_free_symbols(self.get_stride(), unbacked_only)
|
||||
| get_free_symbols(self.get_offset(), unbacked_only)
|
||||
| self.data.get_free_symbol_uses(unbacked_only)
|
||||
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
||||
)
|
||||
|
||||
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
|
||||
|
|
@ -6976,50 +6975,6 @@ class DeviceCopy(ExternKernelOut):
|
|||
wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
|
||||
|
||||
|
||||
class DynamicSelectStorageOffset(ExternKernel):
|
||||
"""
|
||||
The result of computing a dynamic selection index is determined as follows: when the index in the
|
||||
select operation is unbacked, the actual index calculation is ambiguous for negative indices
|
||||
(index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
|
||||
SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
|
||||
computing the storage offset at runtime with this node.
|
||||
"""
|
||||
|
||||
def get_reads(self) -> OrderedSet[Dep]:
|
||||
return OrderedSet()
|
||||
|
||||
def should_allocate(self) -> bool:
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unbacked_offset_symbol: sympy.Symbol,
|
||||
index: sympy.Symbol,
|
||||
base_offset: Union[sympy.Symbol, int],
|
||||
base_dim_stride: Union[sympy.Symbol, int],
|
||||
size: Union[sympy.Symbol, int],
|
||||
) -> None:
|
||||
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
|
||||
# This node codegen the following:
|
||||
# unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
|
||||
self.unbacked_offset_symbol = unbacked_offset_symbol
|
||||
self.index = index
|
||||
self.base_offset = base_offset
|
||||
self.base_dim_stride = base_dim_stride
|
||||
self.size = size
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.unbacked_offset_symbol])
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return get_free_symbols(self.index, unbacked_only)
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
wrapper.codegen_dynamic_select_index(self)
|
||||
|
||||
|
||||
class DynamicScalar(ExternKernel):
|
||||
"""
|
||||
The result of a call to aten._local_scalar_dense.
|
||||
|
|
|
|||
|
|
@ -40,11 +40,7 @@ from torch._prims_common import (
|
|||
Number,
|
||||
)
|
||||
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
free_unbacked_symbols,
|
||||
has_free_unbacked_symbols,
|
||||
resolve_unbacked_bindings,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
|
||||
|
||||
|
|
@ -994,7 +990,10 @@ def squeeze(x, dim=None):
|
|||
|
||||
new_shape = []
|
||||
for d, s in enumerate(x.get_size()):
|
||||
if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
|
||||
if not (
|
||||
d in dims
|
||||
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
|
||||
):
|
||||
new_shape.append(s)
|
||||
|
||||
# squeeze does nothing if the size isn't 1
|
||||
|
|
@ -1760,60 +1759,8 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
|
|||
|
||||
@register_lowering(aten.select, type_promotion_kind=None)
|
||||
def select(x, dim, idx):
|
||||
idx = sympy.expand(idx)
|
||||
size = sympy.expand(x.get_size()[dim])
|
||||
actual_index = None
|
||||
|
||||
if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
|
||||
actual_index = idx + size
|
||||
elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
|
||||
actual_index = idx
|
||||
|
||||
if actual_index is not None:
|
||||
if has_free_unbacked_symbols(idx):
|
||||
# Inductor could generate incorrect views for tensors with unbacked symbols here;
|
||||
# Squeeze operations are translated to views, resulting in incorrect strides.
|
||||
# Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
|
||||
# we use as_strided instead.
|
||||
# Removing this branch will cause test_unbacked_select_index_with_check to fail.
|
||||
new_size = x.get_size()
|
||||
new_stride = x.get_stride()
|
||||
new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
|
||||
|
||||
del new_size[dim]
|
||||
del new_stride[dim]
|
||||
return as_strided(x, new_size, new_stride, new_storage_offset)
|
||||
else:
|
||||
slice_result = slice_(x, dim, actual_index, actual_index + 1)
|
||||
return squeeze(slice_result, dim)
|
||||
|
||||
# Unbacked Semantics:
|
||||
# When the index idx is unbacked (e.g., u0), we compute the index dynamically
|
||||
# during the lowering of the select operation using DynamicSelectStorageOffset.
|
||||
|
||||
unbacked_bindings = resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
|
||||
)
|
||||
assert unbacked_bindings is not None
|
||||
assert len(unbacked_bindings) == 1, unbacked_bindings
|
||||
unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
|
||||
|
||||
new_size = x.get_size()
|
||||
new_stride = x.get_stride()
|
||||
new_storage_offset = unbacked_offset_sym
|
||||
buffer = ir.DynamicSelectStorageOffset(
|
||||
unbacked_offset_sym,
|
||||
idx,
|
||||
x.get_layout().offset,
|
||||
new_stride[dim],
|
||||
x.get_size()[dim],
|
||||
)
|
||||
buffer.name = V.graph.register_buffer(buffer)
|
||||
V.graph.register_operation(buffer)
|
||||
|
||||
del new_size[dim]
|
||||
del new_stride[dim]
|
||||
return as_strided(x, new_size, new_stride, new_storage_offset)
|
||||
idx = View.handle_negative_index(idx, x.get_size()[dim])
|
||||
return squeeze(slice_(x, dim, idx, idx + 1), dim)
|
||||
|
||||
|
||||
@register_lowering(aten.split, type_promotion_kind=None)
|
||||
|
|
@ -3139,6 +3086,8 @@ def long_tensor(data):
|
|||
|
||||
@register_lowering(aten._local_scalar_dense)
|
||||
def _local_scalar_dense(data):
|
||||
from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings
|
||||
|
||||
# This is interesting! Most lowerings return tensors, so you can just
|
||||
# return the buffer you allocated and it will get used (or not used, if
|
||||
# it's dead.) But _local_scalar_dense (aka item) returns an int,
|
||||
|
|
|
|||
|
|
@ -2130,11 +2130,9 @@ class Scheduler:
|
|||
self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
|
||||
if config._pre_fusion_custom_pass is not None:
|
||||
self.nodes = config._pre_fusion_custom_pass(self.nodes)
|
||||
|
||||
self.nodes = self.fuse_nodes(self.nodes)
|
||||
if config._post_fusion_custom_pass is not None:
|
||||
self.nodes = config._post_fusion_custom_pass(self.nodes)
|
||||
|
||||
self.merge_loops()
|
||||
self.finalize_multi_template_buffers()
|
||||
if config.combo_kernels:
|
||||
|
|
@ -2368,6 +2366,7 @@ class Scheduler:
|
|||
|
||||
for node in self.nodes:
|
||||
log.debug("scheduling %s", node.node)
|
||||
|
||||
# unbacked symbols don't follow ordinary buffer dependencies, so
|
||||
# we track their def/uses separately
|
||||
assert node.node is not None
|
||||
|
|
|
|||
|
|
@ -69,20 +69,13 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
|
|||
"inductor_autotune_lookup_table",
|
||||
]
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
free_symbols,
|
||||
free_unbacked_symbols,
|
||||
IterateExprs,
|
||||
ShapeEnv,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence, ValuesView
|
||||
|
||||
from torch import SymBool, SymFloat, SymInt
|
||||
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .codegen.common import WorkspaceArg
|
||||
|
|
@ -3366,10 +3359,3 @@ def aoti_model_name_from_config() -> str:
|
|||
model_name = config.aot_inductor.model_name_for_generated_files
|
||||
model_name = "aoti_model" if model_name is None else model_name
|
||||
return model_name
|
||||
|
||||
|
||||
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_only:
|
||||
return free_unbacked_symbols(x)
|
||||
else:
|
||||
return free_symbols(x)
|
||||
|
|
|
|||
|
|
@ -5553,6 +5553,39 @@ def meta_zeros(
|
|||
)
|
||||
|
||||
|
||||
@register_meta(aten.select.int)
|
||||
def meta_select(self, dim, index):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
ndim = self.dim()
|
||||
torch._check_index(
|
||||
ndim != 0,
|
||||
lambda: "select() cannot be applied to a 0-dim tensor.",
|
||||
)
|
||||
|
||||
dim = dim if dim >= 0 else dim + ndim
|
||||
size = self.size(dim)
|
||||
|
||||
torch._check_index(
|
||||
not (
|
||||
guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
|
||||
),
|
||||
lambda: f"select(): index {index} out of range for tensor of size "
|
||||
f"{self.size()} at dimension {dim}",
|
||||
)
|
||||
|
||||
index = index if index >= 0 else index + size
|
||||
|
||||
new_size = list(self.size())
|
||||
new_stride = list(self.stride())
|
||||
|
||||
new_storage_offset = self.storage_offset() + index * new_stride[dim]
|
||||
del new_size[dim]
|
||||
del new_stride[dim]
|
||||
|
||||
return self.as_strided(new_size, new_stride, new_storage_offset)
|
||||
|
||||
|
||||
@register_meta(aten.select_scatter.default)
|
||||
def meta_select_scatter(self, src, dim, index):
|
||||
return utils.clone_preserve_strides(self)
|
||||
|
|
|
|||
|
|
@ -359,48 +359,6 @@ def unique2(
|
|||
return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
|
||||
|
||||
|
||||
@register_op_impl(aten.select.int)
|
||||
def meta_select(fake_mode, func, self, dim, index):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if self.is_sparse:
|
||||
return NotImplemented
|
||||
|
||||
ndim = self.dim()
|
||||
torch._check_index(
|
||||
ndim != 0,
|
||||
lambda: "select() cannot be applied to a 0-dim tensor.",
|
||||
)
|
||||
|
||||
dim = dim if dim >= 0 else dim + ndim
|
||||
size = self.size(dim)
|
||||
|
||||
new_size = list(self.size())
|
||||
new_stride = list(self.stride())
|
||||
|
||||
new_storage_offset = None
|
||||
if guard_or_false(index >= 0):
|
||||
new_storage_offset = self.storage_offset() + index * new_stride[dim]
|
||||
elif guard_or_false(index < 0):
|
||||
new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
|
||||
|
||||
if new_storage_offset is None:
|
||||
if fake_mode.shape_env is None or (
|
||||
not fake_mode.shape_env.allow_scalar_outputs
|
||||
and not fake_mode.allow_scalar_outputs
|
||||
):
|
||||
raise DataDependentOutputException(func)
|
||||
|
||||
# index is data-dependent, we do not know which index we are accessing it could be index or index+size!
|
||||
# we assign a new data-dependent symbol for the storage offset.
|
||||
new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
|
||||
|
||||
del new_size[dim]
|
||||
del new_stride[dim]
|
||||
assert new_storage_offset is not None
|
||||
return self.as_strided(new_size, new_stride, new_storage_offset)
|
||||
|
||||
|
||||
@register_op_impl(aten.unique_dim.default)
|
||||
def unique_dim(
|
||||
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
|
||||
|
|
|
|||
|
|
@ -1282,7 +1282,6 @@ def compute_unbacked_bindings(
|
|||
return None
|
||||
|
||||
fs = shape_env.pending_fresh_unbacked_symbols
|
||||
|
||||
pending = set(fs)
|
||||
if not pending:
|
||||
return None
|
||||
|
|
@ -4810,7 +4809,6 @@ class ShapeEnv:
|
|||
)
|
||||
self.counter["create_unbacked_symbol"] += 1
|
||||
if not self._ignore_fresh_unbacked_symbols_tls():
|
||||
print(f"adding {symbol}")
|
||||
self.pending_fresh_unbacked_symbols.append(symbol)
|
||||
self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
|
||||
vr = self.var_to_range[symbol] = ValueRanges.unknown()
|
||||
|
|
|
|||
|
|
@ -461,7 +461,6 @@ def insert_deferred_runtime_asserts(
|
|||
),
|
||||
keypath[2:],
|
||||
)
|
||||
|
||||
return go(
|
||||
graph.call_method(
|
||||
keypath[0].name, (node, keypath[1].idx)
|
||||
|
|
@ -469,15 +468,6 @@ def insert_deferred_runtime_asserts(
|
|||
keypath[2:],
|
||||
)
|
||||
elif isinstance(keypath[0], CallMethodKey):
|
||||
if keypath[0].name == "storage_offset":
|
||||
return go(
|
||||
graph.call_function(
|
||||
torch.ops.aten.sym_storage_offset.default,
|
||||
(node,),
|
||||
),
|
||||
keypath[1:],
|
||||
)
|
||||
|
||||
return go(
|
||||
graph.call_method(keypath[0].name, (node,)), keypath[1:]
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user