mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Output of nonzero is transposed, fix fake tensor (#144695)"
This reverts commit 693d8c7e94.
Reverted https://github.com/pytorch/pytorch/pull/144695 on behalf of https://github.com/izaitsevfb due to breaking internal tests, see D68461259 ([comment](https://github.com/pytorch/pytorch/pull/144695#issuecomment-2608443589))
This commit is contained in:
parent
de945d78da
commit
f0a210bf5d
|
|
@ -1406,17 +1406,17 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"):
|
|||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
|
||||
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
|
||||
alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
return (alias_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
|||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
|
||||
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None
|
||||
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
|
||||
getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
|
||||
alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
|
||||
return (alias_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
|
|||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1)
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
|
||||
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
return (arg1_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
|
|||
graph_inductor,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[2][1]cpu"):
|
||||
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
|
||||
nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1)
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
|
||||
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
|
||||
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
|
||||
alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1)
|
||||
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
|
||||
foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
|
||||
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
|
||||
slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
|
||||
return (arg0_1, slice_2)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
|
|
|||
|
|
@ -1597,17 +1597,6 @@ class FakeTensorPropTest(TestCase):
|
|||
self.assertIsNot(u0, u1)
|
||||
self.assertTrue(statically_known_true(u0 == u1))
|
||||
|
||||
def test_nonzero_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
fake_mode = FakeTensorMode(shape_env=shape_env)
|
||||
with fake_mode:
|
||||
value = torch.ones(5)
|
||||
fake_r = value.nonzero()
|
||||
|
||||
r = torch.ones(5).nonzero()
|
||||
|
||||
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
|
||||
|
||||
def test_torch_load_with_fake_mode(self):
|
||||
class TheModelClass(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -678,16 +678,14 @@ class InputWriter:
|
|||
return v
|
||||
|
||||
def tensor(self, name, t) -> None:
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
storage = self.storage(
|
||||
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
|
||||
)
|
||||
args = []
|
||||
# NB: this is positional, must come first
|
||||
if not statically_known_true(
|
||||
sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
|
||||
):
|
||||
if _stride_or_default(None, shape=t.shape) != t.stride():
|
||||
args.append(str(tuple(t.stride())))
|
||||
if _dtype_or_default(None) != t.dtype:
|
||||
args.append(f"dtype={t.dtype!r}")
|
||||
|
|
|
|||
|
|
@ -519,8 +519,7 @@ class FakifiedOutWrapper(CompilerWrapper):
|
|||
out_metas: list[torch.Tensor] = field(default_factory=list)
|
||||
# TracingContext.fwd_output_strides
|
||||
# Generated from actually doing compile
|
||||
# NB: an entry is None if it's not a Tensor
|
||||
fwd_output_strides: Optional[list[Optional[list[int]]]] = None
|
||||
fwd_output_strides: Optional[list[list[int]]] = None
|
||||
needs_post_compile: bool = True
|
||||
|
||||
def pre_compile(
|
||||
|
|
@ -551,23 +550,12 @@ class FakifiedOutWrapper(CompilerWrapper):
|
|||
for i in range(len(out)):
|
||||
if not isinstance(out[i], Tensor):
|
||||
continue
|
||||
strides = fwd_output_strides[i]
|
||||
# fwd_output_strides is best effort by Inductor. When an output
|
||||
# Tensor has unbacked SymInts, Inductor may sometimes be unable
|
||||
# to compute what the output stride would be. If Inductor doesn't
|
||||
# have any clear direction on the layout, we don't have to run
|
||||
# as_strided. To repro without this, run:
|
||||
#
|
||||
# python test/distributed/test_dynamo_distributed.py
|
||||
# TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
|
||||
if strides is None:
|
||||
continue
|
||||
if all(
|
||||
statically_known_true(s1 == s2)
|
||||
for s1, s2 in zip(out[i].stride(), strides)
|
||||
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
|
||||
):
|
||||
continue
|
||||
out[i] = out[i].as_strided(out[i].shape, strides)
|
||||
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
|
||||
return out
|
||||
|
||||
# To be called post compile
|
||||
|
|
|
|||
|
|
@ -471,7 +471,7 @@ def nonzero(fake_mode, func, arg):
|
|||
|
||||
arg.nonzero_memo = nnz
|
||||
|
||||
return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)
|
||||
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
|
||||
|
||||
|
||||
@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user