Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "DDE-Free select with unbacked index. (#157605)"
This reverts commit 79d7c754ab.
Reverted https://github.com/pytorch/pytorch/pull/157605 on behalf of https://github.com/laithsakka due to fail pr time benchmarks ([comment](https://github.com/pytorch/pytorch/pull/157605#issuecomment-3084663020))
This commit is contained in:
parent
16b21fa8b2
commit
23550ab735
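For context: the reverted PR taught export and Inductor to handle `select` with a data-dependent ("unbacked") index, i.e. an index produced by `.item()`, without raising a data-dependent error (DDE). A minimal sketch of the user-visible pattern, adapted from the tests deleted below (shapes are illustrative):

    import torch
    from torch.export import export

    class MyModel(torch.nn.Module):
        def forward(self, x, y):
            u0 = y.item()           # unbacked SymInt: value unknown while tracing
            return x.select(0, u0)  # data-dependent index into dim 0

    example_inputs = (torch.randn(3, 3), torch.tensor([0]))
    # With the reverted change this exported cleanly and kept aten.select.int
    # in the decomposed graph; after the revert the index sign guard can fail again.
    traced = export(MyModel(), example_inputs).run_decompositions({})
    print(traced.graph_module.code)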
@@ -15782,28 +15782,6 @@ def forward(self, x, mask):
             ignore_empty_lines=True,
         )

-    def test_unbacked_select_index(self):
-        class MyModel(torch.nn.Module):
-            def forward(self, x, y):
-                u0 = y.item()
-                return x.select(0, u0)
-
-        example_inputs = (
-            torch.randn((3, 3), dtype=torch.bfloat16),
-            torch.tensor([0]),
-        )
-
-        traced = export(MyModel(), example_inputs).run_decompositions({})
-        self.assertExpectedInline(
-            traced.graph_module.code,
-            """\
-def forward(self, x, y):
-    item = torch.ops.aten.item.default(y); y = None
-    select = torch.ops.aten.select.int(x, 0, item); x = item = None
-    return (select,)""",
-            ignore_empty_lines=True,
-        )
-
-
 if __name__ == "__main__":
     run_tests()
@@ -3529,88 +3529,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
             ignore_empty_lines=True,
         )

-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_unbacked_select_index(self):
-        cnt = CompileCounterWithBackend("inductor")
-
-        def func(x, y):
-            u0 = y.item()
-            return (
-                torch.select(x, 0, u0),
-                torch.select(x, 1, u0),
-                torch.select(x, 2, u0),
-            )
-
-        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
-        x = torch.rand(3, 3, 3)
-        zero = torch.tensor([0])
-        pos = torch.tensor([1])
-        # code can handle both negative and positive indices.
-        neg = torch.tensor([-1])
-
-        log_stream, ctx = logs_to_string(
-            "torch._inductor.compile_fx", "post_grad_graphs"
-        )
-        with ctx():
-            self.assertEqual(compiled_func(x, zero), func(x, zero))
-        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
-        self.assertExpectedInline(
-            output,
-            """\
-_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
-select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
-select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense)
-select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None
-return (select, select_1, select_2)""",  # noqa: B950
-            ignore_comments=True,
-            ignore_empty_lines=True,
-        )
-        self.assertEqual(compiled_func(x, pos), func(x, pos))
-        self.assertEqual(compiled_func(x, neg), func(x, neg))
-        self.assertEqual(cnt.frame_count, 1)
-
-        def func2(x, y):
-            u0, u1 = y.tolist()
-            return torch.select(x, 0, u0 + u1)
-
-        compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)(
-            func2
-        )
-        zero = torch.tensor([0, 0])
-        pos = torch.tensor([1, 1])
-        neg = torch.tensor([-1, -1])
-
-        self.assertEqual(compiled_func2(x, pos), func2(x, pos))
-        self.assertEqual(compiled_func2(x, neg), func2(x, neg))
-        self.assertEqual(compiled_func2(x, zero), func2(x, zero))
-        self.assertEqual(cnt.frame_count, 2)
-
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_unbacked_select_index_with_check(self):
-        def func3(x, y):
-            u0 = y.item()
-            # Test that taking the non-unbacked path works fine also.
-            torch._check(u0 >= 0)
-            return (torch.select(x, 1, u0),)
-
-        compiled_func3 = torch.compile(
-            fullgraph=True, backend="inductor", dynamic=True
-        )(func3)
-        x = torch.rand(3, 3, 3)
-        zero = torch.tensor([0])
-        pos = torch.tensor([1])
-        print(compiled_func3(x, pos))
-
-        self.assertEqual(compiled_func3(x, pos), func3(x, pos))
-        self.assertEqual(compiled_func3(x, zero), func3(x, zero))
-
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    @torch._inductor.config.patch("cpp_wrapper", True)
-    def test_unbacked_select_index_cpp_wrapper(self):
-        self.test_unbacked_select_index()
-
 instantiate_parametrized_tests(TestUnbacked)

@@ -54,7 +54,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None)
             )
         },
     )
-
     node.meta["torch_fn"] = (
         f"{node.target.__name__}_0",
        f"{node.target.__class__.__name__}.{node.target.__name__}",
@@ -1447,20 +1447,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.sym))

-    def codegen_dynamic_select_index(self, node):
-        index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
-
-        index_compute_str = (
-            f"{index_cpp_str} < 0 ? {index_cpp_str} + "
-            f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
-        )
-        self.writeline(
-            f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
-            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
-        )
-        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
-        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
-
     def make_buffer_free(self, buffer):
         return (
             ""
@@ -1802,14 +1802,6 @@ class PythonWrapperCodegen(CodeGen):
         arg_name = node.input_name(0)
         self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))

-    def codegen_dynamic_select_index(self, node):
-        index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
-        self.writeline(
-            f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
-        )
-        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
-        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
-
     def codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if len(node.keypath) == 0:
@@ -11,7 +11,6 @@ from unittest.mock import patch
 import sympy

 import torch
-from torch._inductor.utils import get_free_symbols
 from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet
@@ -39,12 +38,6 @@ class Dep(abc.ABC):
     name: str
     index: sympy.Expr

-    @abc.abstractmethod
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        pass
-
     @abc.abstractmethod
     def rename(self, renames: dict[str, str]) -> Self:
         pass
@@ -77,15 +70,6 @@ class MemoryDep(Dep):
     size: tuple[sympy.Expr, ...]
     mode: Optional[str] = None

-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return (
-            get_free_symbols(self.index, unbacked_only)
-            | get_free_symbols(self.size, unbacked_only)
-            | get_free_symbols(self.var_names, unbacked_only)
-        )
-
     def __repr__(self) -> str:
         maybe_mode = ""
         if self.mode is not None:
@@ -323,11 +307,6 @@ class StarDep(Dep):
             return StarDep(renames[self.name], self.mode)
         return self

-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet()
-
     def numbytes_hint(self) -> int:
         try:
             return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
@@ -364,11 +343,6 @@ class WeakDep(Dep):
     # Buffer that is doing the mutation
     mutating_buf: str

-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet()
-
     @property
     def index(self) -> sympy.Expr:
         raise NotImplementedError("WeakDep does not have an index")
@@ -466,15 +440,6 @@ class ReadWrites:
             names.add(dep.name)
         return names

-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        result: OrderedSet[sympy.Symbol] = OrderedSet()
-
-        for dep in self.reads_and_writes():
-            result |= dep.get_free_symbol_uses(unbacked_only)
-        return result
-

 class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
     def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
@@ -341,7 +341,6 @@ class GraphLowering(torch.fx.Interpreter):
             shape_env.deferred_runtime_asserts.copy()
         )
         self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
-
         self.sizevars = SizeVarAllocator(shape_env)
         self.graph_input_names: list[str] = []
         self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
@@ -1822,7 +1821,7 @@ class GraphLowering(torch.fx.Interpreter):

         shape_env = V.graph.sizevars.shape_env

-        # An input can be unbacked symint i.e.: when mark_unabcked is used.
+        # An input can an unbacked symint i.e.: when mark_unabcked is used.
         # in that case add it to new_unbacked_defs.
         if (
             n.op == "placeholder"
@@ -1889,7 +1888,6 @@ class GraphLowering(torch.fx.Interpreter):
             V.fake_mode.shape_env.unbacked_renamings.get(s, s)
             for s in unbacked_bindings.keys()
         )
-
         assert new_unbacked_defs >= renamed_unbacked_bindings, (
             f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
             f"fx node is: {n.format_node()}\n"
@@ -49,7 +49,6 @@ from torch._dynamo.utils import identity
 from torch._export.serde.serialize import GraphModuleSerializer
 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
 from torch._inductor import metrics
-from torch._inductor.utils import get_free_symbols
 from torch._prims_common import (
     compute_required_storage_length,
     is_boolean_dtype,
@@ -63,6 +62,7 @@ from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     free_symbols,
     free_unbacked_symbols,
+    IterateExprs,
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
@@ -304,6 +304,13 @@ def fuse_reindexing(
     return reindex


+def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
+    if unbacked_only:
+        return free_unbacked_symbols(x)
+    else:
+        return free_symbols(x)
+
+
 NHWC_STRIDE_ORDER = [3, 0, 2, 1]
 NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
@@ -4322,13 +4329,6 @@ class ComputedBuffer(OperationBuffer):
         return self.data.get_read_names()

     def get_read_writes(self) -> dependencies.ReadWrites:
-        if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
-            return dependencies.ReadWrites(
-                reads=OrderedSet(),
-                writes=OrderedSet(),
-                index_exprs=OrderedSet(),
-            )
-
         with patch.object(FlexibleLayout, "allow_indexing", True):
             if self.data.get_reduction_type():
                 return extract_read_writes(
@@ -4367,7 +4367,6 @@ class ComputedBuffer(OperationBuffer):
             | get_free_symbols(self.get_stride(), unbacked_only)
             | get_free_symbols(self.get_offset(), unbacked_only)
             | self.data.get_free_symbol_uses(unbacked_only)
-            | self.get_read_writes().get_free_symbol_uses(unbacked_only)
         )

     def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
@@ -6976,50 +6975,6 @@ class DeviceCopy(ExternKernelOut):
         wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])


-class DynamicSelectStorageOffset(ExternKernel):
-    """
-    The result of computing a dynamic selection index is determined as follows: when the index in the
-    select operation is unbacked, the actual index calculation is ambiguous for negative indices
-    (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
-    SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
-    computing the storage offset at runtime with this node.
-    """
-
-    def get_reads(self) -> OrderedSet[Dep]:
-        return OrderedSet()
-
-    def should_allocate(self) -> bool:
-        return False
-
-    def __init__(
-        self,
-        unbacked_offset_symbol: sympy.Symbol,
-        index: sympy.Symbol,
-        base_offset: Union[sympy.Symbol, int],
-        base_dim_stride: Union[sympy.Symbol, int],
-        size: Union[sympy.Symbol, int],
-    ) -> None:
-        super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
-        # This node codegen the following:
-        # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
-        self.unbacked_offset_symbol = unbacked_offset_symbol
-        self.index = index
-        self.base_offset = base_offset
-        self.base_dim_stride = base_dim_stride
-        self.size = size
-
-    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet([self.unbacked_offset_symbol])
-
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return get_free_symbols(self.index, unbacked_only)
-
-    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-        wrapper.codegen_dynamic_select_index(self)
-
-
 class DynamicScalar(ExternKernel):
     """
     The result of a call to aten._local_scalar_dense.
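The deleted node's contract is spelled out in its init comment: the wrapper later emits `unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >= 0 else index + size)`. A plain-Python sketch of that arithmetic (names mirror the class fields; this is an illustration, not PyTorch API):

    def dynamic_select_offset(index: int, base_offset: int, base_dim_stride: int, size: int) -> int:
        # Normalize a possibly-negative index at runtime, then turn it into
        # a storage offset along the selected dimension.
        actual_index = index if index >= 0 else index + size
        return base_offset + base_dim_stride * actual_index

    # Example: contiguous (3, 3) tensor, select(dim=0, index=-1):
    # stride along dim 0 is 3, so offset = 0 + 3 * (-1 + 3) = 6
    assert dynamic_select_offset(-1, 0, 3, 3) == 6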
@@ -40,11 +40,7 @@ from torch._prims_common import (
     Number,
 )
 from torch.fx.experimental.sym_node import magic_methods, method_to_operator
-from torch.fx.experimental.symbolic_shapes import (
-    free_unbacked_symbols,
-    has_free_unbacked_symbols,
-    resolve_unbacked_bindings,
-)
+from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
@@ -994,7 +990,10 @@ def squeeze(x, dim=None):

     new_shape = []
     for d, s in enumerate(x.get_size()):
-        if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
+        if not (
+            d in dims
+            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
+        ):
             new_shape.append(s)

     # squeeze does nothing if the size isn't 1
@@ -1760,60 +1759,8 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):

 @register_lowering(aten.select, type_promotion_kind=None)
 def select(x, dim, idx):
-    idx = sympy.expand(idx)
-    size = sympy.expand(x.get_size()[dim])
-    actual_index = None
-
-    if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
-        actual_index = idx + size
-    elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
-        actual_index = idx
-
-    if actual_index is not None:
-        if has_free_unbacked_symbols(idx):
-            # Inductor could generate incorrect views for tensors with unbacked symbols here;
-            # Squeeze operations are translated to views, resulting in incorrect strides.
-            # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
-            # we use as_strided instead.
-            # Removing this branch will cause test_unbacked_select_index_with_check to fail.
-            new_size = x.get_size()
-            new_stride = x.get_stride()
-            new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
-
-            del new_size[dim]
-            del new_stride[dim]
-            return as_strided(x, new_size, new_stride, new_storage_offset)
-        else:
-            slice_result = slice_(x, dim, actual_index, actual_index + 1)
-            return squeeze(slice_result, dim)
-
-    # Unbacked Semantics:
-    # When the index idx is unbacked (e.g., u0), we compute the index dynamically
-    # during the lowering of the select operation using DynamicSelectStorageOffset.
-
-    unbacked_bindings = resolve_unbacked_bindings(
-        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
-    )
-    assert unbacked_bindings is not None
-    assert len(unbacked_bindings) == 1, unbacked_bindings
-    unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
-
-    new_size = x.get_size()
-    new_stride = x.get_stride()
-    new_storage_offset = unbacked_offset_sym
-    buffer = ir.DynamicSelectStorageOffset(
-        unbacked_offset_sym,
-        idx,
-        x.get_layout().offset,
-        new_stride[dim],
-        x.get_size()[dim],
-    )
-    buffer.name = V.graph.register_buffer(buffer)
-    V.graph.register_operation(buffer)
-
-    del new_size[dim]
-    del new_stride[dim]
-    return as_strided(x, new_size, new_stride, new_storage_offset)
+    idx = View.handle_negative_index(idx, x.get_size()[dim])
+    return squeeze(slice_(x, dim, idx, idx + 1), dim)


 @register_lowering(aten.split, type_promotion_kind=None)
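The restored (pre-PR) lowering expresses select as slice-then-squeeze after normalizing the index with `View.handle_negative_index`. The eager-mode identity it relies on can be checked directly (`narrow` is the eager counterpart of the `slice_` lowering used here):

    import torch

    x = torch.arange(27.0).reshape(3, 3, 3)
    for dim in range(3):
        for idx in (-1, 0, 1):
            start = idx if idx >= 0 else idx + 3  # normalize a negative index
            assert torch.equal(x.select(dim, idx), x.narrow(dim, start, 1).squeeze(dim))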
@@ -3139,6 +3086,8 @@ def long_tensor(data):

 @register_lowering(aten._local_scalar_dense)
 def _local_scalar_dense(data):
+    from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings
+
     # This is interesting! Most lowerings return tensors, so you can just
     # return the buffer you allocated and it will get used (or not used, if
     # it's dead.) But _local_scalar_dense (aka item) returns an int,
@@ -2130,11 +2130,9 @@ class Scheduler:
         self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
         if config._pre_fusion_custom_pass is not None:
             self.nodes = config._pre_fusion_custom_pass(self.nodes)
-
         self.nodes = self.fuse_nodes(self.nodes)
         if config._post_fusion_custom_pass is not None:
             self.nodes = config._post_fusion_custom_pass(self.nodes)
-
         self.merge_loops()
         self.finalize_multi_template_buffers()
         if config.combo_kernels:
@@ -2368,6 +2366,7 @@ class Scheduler:

         for node in self.nodes:
             log.debug("scheduling %s", node.node)
+
             # unbacked symbols don't follow ordinary buffer dependencies, so
             # we track their def/uses separately
             assert node.node is not None
@@ -69,20 +69,13 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
     "inductor_autotune_lookup_table",
 ]

-from torch.fx.experimental.symbolic_shapes import (
-    free_symbols,
-    free_unbacked_symbols,
-    IterateExprs,
-    ShapeEnv,
-)
-
-
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence, ValuesView

     from torch import SymBool, SymFloat, SymInt
     from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
     from torch.fx import GraphModule
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
     from torch.fx.node import Node

     from .codegen.common import WorkspaceArg
@@ -3366,10 +3359,3 @@ def aoti_model_name_from_config() -> str:
     model_name = config.aot_inductor.model_name_for_generated_files
     model_name = "aoti_model" if model_name is None else model_name
     return model_name
-
-
-def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
-    if unbacked_only:
-        return free_unbacked_symbols(x)
-    else:
-        return free_symbols(x)
@@ -5553,6 +5553,39 @@ def meta_zeros(
     )


+@register_meta(aten.select.int)
+def meta_select(self, dim, index):
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    ndim = self.dim()
+    torch._check_index(
+        ndim != 0,
+        lambda: "select() cannot be applied to a 0-dim tensor.",
+    )
+
+    dim = dim if dim >= 0 else dim + ndim
+    size = self.size(dim)
+
+    torch._check_index(
+        not (
+            guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
+        ),
+        lambda: f"select(): index {index} out of range for tensor of size "
+        f"{self.size()} at dimension {dim}",
+    )
+
+    index = index if index >= 0 else index + size
+
+    new_size = list(self.size())
+    new_stride = list(self.stride())
+
+    new_storage_offset = self.storage_offset() + index * new_stride[dim]
+    del new_size[dim]
+    del new_stride[dim]
+
+    return self.as_strided(new_size, new_stride, new_storage_offset)
+
+
 @register_meta(aten.select_scatter.default)
 def meta_select_scatter(self, src, dim, index):
     return utils.clone_preserve_strides(self)
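The restored meta function accepts indices in the range [-size, size). The boundary arithmetic of its `torch._check_index` guard can be verified with plain integers:

    size = 3

    def in_range(index):
        # Mirrors the guard above: reject index < -size and index >= size.
        return not (-index > size or index >= size)

    assert all(in_range(i) for i in (-3, -1, 0, 2))  # valid for size 3
    assert not in_range(-4) and not in_range(3)      # both out of range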
@@ -359,48 +359,6 @@ def unique2(
     return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)


-@register_op_impl(aten.select.int)
-def meta_select(fake_mode, func, self, dim, index):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
-
-    if self.is_sparse:
-        return NotImplemented
-
-    ndim = self.dim()
-    torch._check_index(
-        ndim != 0,
-        lambda: "select() cannot be applied to a 0-dim tensor.",
-    )
-
-    dim = dim if dim >= 0 else dim + ndim
-    size = self.size(dim)
-
-    new_size = list(self.size())
-    new_stride = list(self.stride())
-
-    new_storage_offset = None
-    if guard_or_false(index >= 0):
-        new_storage_offset = self.storage_offset() + index * new_stride[dim]
-    elif guard_or_false(index < 0):
-        new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
-
-    if new_storage_offset is None:
-        if fake_mode.shape_env is None or (
-            not fake_mode.shape_env.allow_scalar_outputs
-            and not fake_mode.allow_scalar_outputs
-        ):
-            raise DataDependentOutputException(func)
-
-        # index is data-dependent, we do not know which index we are accessing it could be index or index+size!
-        # we assign a new data-dependent symbol for the storage offset.
-        new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
-
-    del new_size[dim]
-    del new_stride[dim]
-    assert new_storage_offset is not None
-    return self.as_strided(new_size, new_stride, new_storage_offset)
-
-
 @register_op_impl(aten.unique_dim.default)
 def unique_dim(
     fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
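The deleted fake-tensor impl decided the index sign with `guard_or_false`, which evaluates a symbolic condition only when it is statically decidable and otherwise returns False without installing a guard; an unbacked index therefore falls through both branches and receives a fresh unbacked storage-offset symbol. Roughly (a sketch with the decision procedure injected, not the PyTorch source):

    def select_storage_offset(storage_offset, stride, size, index, guard_or_false, fresh_unbacked):
        if guard_or_false(index >= 0):       # statically known non-negative
            return storage_offset + index * stride
        if guard_or_false(index < 0):        # statically known negative
            return storage_offset + (index + size) * stride
        return fresh_unbacked()              # sign unknown: resolve at runtime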
@@ -1282,7 +1282,6 @@ def compute_unbacked_bindings(
         return None

     fs = shape_env.pending_fresh_unbacked_symbols
-
     pending = set(fs)
     if not pending:
         return None
@@ -4810,7 +4809,6 @@ class ShapeEnv:
         )
         self.counter["create_unbacked_symbol"] += 1
         if not self._ignore_fresh_unbacked_symbols_tls():
-            print(f"adding {symbol}")
             self.pending_fresh_unbacked_symbols.append(symbol)
         self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
         vr = self.var_to_range[symbol] = ValueRanges.unknown()
@@ -461,7 +461,6 @@ def insert_deferred_runtime_asserts(
                     ),
                     keypath[2:],
                 )
-
             return go(
                 graph.call_method(
                     keypath[0].name, (node, keypath[1].idx)
@@ -469,15 +468,6 @@ def insert_deferred_runtime_asserts(
                     keypath[2:],
                 )
             elif isinstance(keypath[0], CallMethodKey):
-                if keypath[0].name == "storage_offset":
-                    return go(
-                        graph.call_function(
-                            torch.ops.aten.sym_storage_offset.default,
-                            (node,),
-                        ),
-                        keypath[1:],
-                    )
-
                 return go(
                     graph.call_method(keypath[0].name, (node,)), keypath[1:]
                 )