Revert "Output of nonzero is transposed, fix fake tensor (#144695)"

This reverts commit 693d8c7e94.

Reverted https://github.com/pytorch/pytorch/pull/144695 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see D68461259 ([comment](https://github.com/pytorch/pytorch/pull/144695#issuecomment-2608443589))
PyTorch MergeBot 2025-01-22 23:04:50 +00:00
parent de945d78da
commit f0a210bf5d
5 changed files with 22 additions and 47 deletions
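
For context: per its title, the reverted PR made the fake tensor implementation of nonzero report the transposed strides that the eager kernel produces; this revert restores the previous behavior, where the fake output gets default contiguous strides. A minimal sketch of the kind of check the test removed below performed (a hedged illustration, not part of this commit):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Fake result: nonzero under FakeTensorMode has an unbacked size u0 in dim 0.
    with FakeTensorMode(shape_env=ShapeEnv()):
        fake_r = torch.ones(5).nonzero()
    # Real eager result for comparison.
    r = torch.ones(5).nonzero()
    # The removed test asserted that the fake and eager results agree here.
    print(fake_r.T.is_contiguous(), r.T.is_contiguous())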


@@ -1406,17 +1406,17 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
return (alias_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg1_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1)
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
return (arg0_1, slice_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,


@@ -1597,17 +1597,6 @@ class FakeTensorPropTest(TestCase):
self.assertIsNot(u0, u1)
self.assertTrue(statically_known_true(u0 == u1))
-def test_nonzero_stride(self):
-shape_env = ShapeEnv()
-fake_mode = FakeTensorMode(shape_env=shape_env)
-with fake_mode:
-value = torch.ones(5)
-fake_r = value.nonzero()
-r = torch.ones(5).nonzero()
-self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
def test_torch_load_with_fake_mode(self):
class TheModelClass(torch.nn.Module):
def __init__(self) -> None:


@@ -678,16 +678,14 @@ class InputWriter:
return v
def tensor(self, name, t) -> None:
-from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
+from torch.fx.experimental.symbolic_shapes import statically_known_true
storage = self.storage(
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
)
args = []
# NB: this is positional, must come first
-if not statically_known_true(
-sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
-):
+if _stride_or_default(None, shape=t.shape) != t.stride():
args.append(str(tuple(t.stride())))
if _dtype_or_default(None) != t.dtype:
args.append(f"dtype={t.dtype!r}")


@@ -519,8 +519,7 @@ class FakifiedOutWrapper(CompilerWrapper):
out_metas: list[torch.Tensor] = field(default_factory=list)
# TracingContext.fwd_output_strides
# Generated from actually doing compile
-# NB: an entry is None if it's not a Tensor
-fwd_output_strides: Optional[list[Optional[list[int]]]] = None
+fwd_output_strides: Optional[list[list[int]]] = None
needs_post_compile: bool = True
def pre_compile(
@@ -551,23 +550,12 @@ class FakifiedOutWrapper(CompilerWrapper):
for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
-strides = fwd_output_strides[i]
-# fwd_output_strides is best effort by Inductor. When an output
-# Tensor has unbacked SymInts, Inductor may sometimes be unable
-# to compute what the output stride would be. If Inductor doesn't
-# have any clear direction on the layout, we don't have to run
-# as_strided. To repro without this, run:
-#
-# python test/distributed/test_dynamo_distributed.py
-# TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
-if strides is None:
-continue
if all(
statically_known_true(s1 == s2)
-for s1, s2 in zip(out[i].stride(), strides)
+for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
):
continue
-out[i] = out[i].as_strided(out[i].shape, strides)
+out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
return out
# To be called post compile
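
For background on the restride loop above (not part of the diff): when Inductor's computed output strides differ from the strides of the traced fake outputs, this wrapper calls as_strided so callers see the compiled layout; the lines deleted by this revert additionally skipped entries whose strides Inductor could not determine (None for outputs with unbacked SymInts, per the removed comment). A minimal illustration of as_strided reinterpreting a tensor's layout without copying:

    import torch

    t = torch.arange(6.0).reshape(2, 3)   # contiguous, stride (3, 1)
    v = t.as_strided(t.shape, (1, 2))     # same storage, viewed column-major
    assert v.stride() == (1, 2) and v.data_ptr() == t.data_ptr()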


@@ -471,7 +471,7 @@ def nonzero(fake_mode, func, arg):
arg.nonzero_memo = nnz
-return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)
+return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
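
For context on the final hunk: the deleted line allocated the fake nonzero result with the transposed layout (strides (1, nnz)) that the reverted PR's title refers to, while the restored line uses new_empty and therefore the default contiguous layout. A small illustration of the two layouts for a hypothetical result with nnz = 5 and dim = 2:

    import torch

    nnz, dim = 5, 2
    transposed = torch.empty_strided((nnz, dim), (1, nnz), dtype=torch.int64)
    contiguous = torch.empty((nnz, dim), dtype=torch.int64)
    print(transposed.stride(), transposed.T.is_contiguous())  # (1, 5) True
    print(contiguous.stride(), contiguous.T.is_contiguous())  # (2, 1) False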