Revert "[export][cond] support merging constant ints as unbacked symint (#152742)"
This reverts commit a805911d15.
Reverted https://github.com/pytorch/pytorch/pull/152742 on behalf of https://github.com/ydwu4 due to breaking trunk ([comment](https://github.com/pytorch/pytorch/pull/152742#issuecomment-2874410372))
This commit is contained in:
parent a87e810980
commit 641e4bee67
@@ -3121,14 +3121,14 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
         )

         pred = torch.tensor(True)
-        for pytree_in in [("string",), (1.0,)]:
+        for pytree_in in [(1,), ("string",), (1.0,)]:
             with self.assertRaisesRegex(
                 RuntimeError,
                 r"Expect operands to be a tuple of possibly nested dict/list/tuple",
             ):
                 fn(pred, pytree_in)

-        for pytree_in in [("string",), (1.0,)]:
+        for pytree_in in [(1,), ("string",), (1.0,)]:
             with self.assertRaisesRegex(
                 torch._dynamo.exc.UncapturedHigherOrderOpError,
                 r"Cond doesn't work unless it is captured completely with torch.compile",
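Note: the two loops above are the restored pre-PR behavior. A minimal sketch of what they assert (illustrative; the exact exception type depends on whether fn is compiled, as the two loops above show):

    import torch

    def fn(pred, operands):
        # operands must be a tuple of (possibly nested) dict/list/tuple of tensors;
        # after this revert a plain Python int leaf such as (1,) is rejected again.
        return torch.cond(pred, lambda x: x, lambda x: x, operands)

    pred = torch.tensor(True)
    for pytree_in in [(1,), ("string",), (1.0,)]:
        try:
            fn(pred, pytree_in)
        except Exception as e:  # RuntimeError per the test above
            print(type(e).__name__, e)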
@@ -1316,98 +1316,6 @@ graph():
             M()(torch.randn(7))
         torch.export.export(M(), (torch.randn(7),), strict=strict)

-    def test_cond_branches_return_constant_int(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
-                return x[idx]
-
-        args = (torch.randn(3, 3),)
-        m = M()
-        ep = export(M(), args)
-        if self._testMethodName == "test_cond_branches_return_constant_int":
-            self.assertExpectedInline(
-                normalize_gm(ep.module().print_readable(print_output=False)),
-                """\
-class GraphModule(torch.nn.Module):
-    def forward(self, x):
-        x: "f32[3, 3]";
-
-        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
-        sum_1: "f32[]" = torch.ops.aten.sum.default(x)
-        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
-
-        true_graph_0 = self.true_graph_0
-        false_graph_0 = self.false_graph_0
-        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
-
-        getitem_1: "Sym(u0)" = cond[0]; cond = None
-
-        ge_1: "Sym(u0 >= 0)" = getitem_1 >= 0
-        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
-        le_1: "Sym(u0 <= 1)" = getitem_1 <= 1
-        _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 1 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
-
-        select: "f32[3]" = torch.ops.aten.select.int(x, 0, getitem_1); x = getitem_1 = None
-        return pytree.tree_unflatten((select,), self._out_spec)
-
-    class true_graph_0(torch.nn.Module):
-        def forward(self):
-            return (0,)
-
-    class false_graph_0(torch.nn.Module):
-        def forward(self):
-            return (1,)
-""", # noqa: B950
-            )
-        self.assertEqual(m(*args), ep.module()(*args))
-
-    def test_cond_branches_return_same_int(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 0, tuple())
-                return x[idx]
-
-        args = (torch.randn(3, 3),)
-        m = M()
-        ep = export(M(), args)
-        # Ideally, we could remove the cond at the front end directly
-        # since it's not used anyway. But we can only do this early
-        # optimization if all the outputs are the same constants, which
-        # will complicate the output check, so just keep it in the graph
-        # and let downstream DCE it.
-        if self._testMethodName == "test_cond_branches_return_same_int":
-            self.assertExpectedInline(
-                normalize_gm(ep.module().print_readable(print_output=False)),
-                """\
-class GraphModule(torch.nn.Module):
-    def forward(self, x):
-        x: "f32[3, 3]";
-
-        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
-        sum_1: "f32[]" = torch.ops.aten.sum.default(x)
-        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
-
-        true_graph_0 = self.true_graph_0
-        false_graph_0 = self.false_graph_0
-        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
-        getitem = cond[0]; cond = getitem = None
-
-        select: "f32[3]" = torch.ops.aten.select.int(x, 0, 0); x = None
-        return pytree.tree_unflatten((select,), self._out_spec)
-
-    class true_graph_0(torch.nn.Module):
-        def forward(self):
-            return (0,)
-
-    class false_graph_0(torch.nn.Module):
-        def forward(self):
-            return (0,)
-""", # noqa: B950
-            )
-
-        self.assertEqual(m(*args), ep.module()(*args))
-
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_cond_contains_unbacked_no_escape(self):
         class M(torch.nn.Module):
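Note: the two deleted tests covered branches that return plain Python ints
(e.g. idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())), which the
reverted PR exported as an unbacked SymInt. A minimal sketch of the pattern
that remains supported after the revert, with branches returning tensors of
the same type and shape (illustrative, not part of the diff):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            idx = torch.cond(
                x.sum() > 3,
                lambda: torch.tensor(0),  # both branches return 0-d int64 tensors
                lambda: torch.tensor(1),
                tuple(),
            )
            return x[idx]

    x = torch.randn(3, 3)
    # The deleted tests additionally exported M with torch.export.export.
    print(M()(x).shape)  # torch.Size([3])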
@@ -8130,10 +8130,10 @@ class GraphModule(torch.nn.Module):
         _ = self._check_export_ret_graph_str(model, args, dynamic_shapes)

     @skipIfTorchDynamo(
-        "Skip because _merge_output is not intended for dynamo to compile"
+        "Skip because _merge_tensors is not intended for dynamo to compile"
     )
-    def test_merge_output(self):
-        from torch._higher_order_ops.cond import _merge_output
+    def test_merge_tensors(self):
+        from torch._higher_order_ops.cond import _merge_tensors
         from torch._subclasses.fake_tensor import FakeTensorMode
         from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -8178,7 +8178,7 @@ class GraphModule(torch.nn.Module):
         with fake_mode:
             t1 = torch.empty_strided(size1, stride1)
             t2 = torch.empty_strided(size2, stride2)
-        out = _merge_output(t1, t2, fake_mode)
+        out = _merge_tensors(t1, t2, fake_mode)
         self.assertEqual(str(tuple(out.size())), merged_size)
         self.assertEqual(str(tuple(out.stride())), merged_stride)

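Note: a minimal sketch of what the renamed test exercises, assuming the
post-revert signature _merge_tensors(tensor, tensor, FakeTensorMode) shown in
the hunk above (internal API; illustrative shapes only):

    import torch
    from torch._higher_order_ops.cond import _merge_tensors
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
    with fake_mode:
        t1 = torch.empty_strided((3, 4), (4, 1))  # stand-in for one branch's output
        t2 = torch.empty_strided((3, 5), (5, 1))  # stand-in for the other branch's output
    out = _merge_tensors(t1, t2, fake_mode)
    # Dimensions that differ across branches are replaced by fresh unbacked
    # symints bounded by the two concrete values, e.g. size (3, u0) with u0 in [4, 5].
    print(tuple(out.size()), tuple(out.stride()))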
@@ -1059,17 +1059,10 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
                 should_flatten_outputs=True,
             )

-            if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
+            if not only_consist_of(ret_val, (TensorVariable,)):
                 unimplemented(
-                    "Expected branches to return a possibly nested pytree of tensors "
-                    "or constant ints but it consists of others.",
+                    "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
                 )
-            for ret in ret_val.unpack_var_sequence(tx):
-                if isinstance(ret, ConstantVariable) and ret.python_type() is not int:
-                    unimplemented(
-                        "Expected branches to return a possibly nested pytree of tensors "
-                        f"or constant ints but it consists of others {ret.python_type()}.",
-                    )
             return ret_val, ret_treespec, ret_graph, ret_lifted_freevars

         (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
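Note: with the TensorVariable-only check restored, compiling a cond whose
branches return plain Python ints is rejected again. Illustrative sketch (the
exact exception surfaced depends on the capture settings):

    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def f(pred, x):
        # int-returning branches hit the unimplemented() path restored above
        return x[torch.cond(pred, lambda: 0, lambda: 1, ())]

    try:
        f(torch.tensor(True), torch.randn(3, 3))
    except Exception as e:
        print(f"{type(e).__name__}: {e}")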
@@ -1626,13 +1626,11 @@ class GraphModuleDeserializer(metaclass=Final):
             self.module,
             self.serialized_name_to_node,
             self.serialized_name_to_meta,
-            self.unbacked_symbols
         )
         self.graph = torch.fx.Graph()
         self.module = torch.nn.Module()
         self.serialized_name_to_node = {}
         self.serialized_name_to_meta = {}
-        self.unbacked_symbols: set[sympy.Symbol] = set()
         try:
             yield
         finally:
@@ -1641,7 +1639,6 @@ class GraphModuleDeserializer(metaclass=Final):
                 self.module,
                 self.serialized_name_to_node,
                 self.serialized_name_to_meta,
-                self.unbacked_symbols
             ) = saved

     def deserialize_extension_operator(self, serialized_target: str):
@@ -2172,7 +2169,7 @@ class GraphModuleDeserializer(metaclass=Final):
             self.symbol_name_to_range = {}
         # we also need to bump unbacked sym[float,int] counters in the
         # shape env to accommodate unbacked symbols in the exported program
-        self.unbacked_symbols = set()
+        self.unbacked_symbols: set[sympy.Symbol] = set()
         count_unbacked_symfloat, count_unbacked_symint = -1, -1
         unbacked_symfloat_prefix, unbacked_symint_prefix = (
             prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
@@ -2410,33 +2407,26 @@ class GraphModuleDeserializer(metaclass=Final):
         # Check single value return
         if len(serialized_node.outputs) == 0:
             return

-        if (
-            len(serialized_node.outputs) == 1
-            and "torch.ops.higher_order" in serialized_node.target
-            and not getattr(serialized_node, "is_hop_single_tensor_return", True)
-        ):
-            def _deserialize_hop_with_single_return(serialized_node, fx_node):
-                meta_val: list[Any] = []
-                arg = None
-                if serialized_node.outputs[0].type == "as_tensor":
-                    arg = serialized_node.outputs[0].as_tensor
-                elif isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
-                    arg = serialized_node.outputs[0].value
-                deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
-                assert arg is not None
-                self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
-                fx_node.meta["val"] = tuple(meta_val)
-                self.serialized_name_to_node[fx_node.name] = fx_node
-                return
-
-            return _deserialize_hop_with_single_return(serialized_node, fx_node)
-
-
         if (
             len(serialized_node.outputs) == 1
             and serialized_node.outputs[0].type == "as_tensor"
         ):
+            # If it is a HOP node and it returns a tuple containing a single element
+            # we manually insert a getitem node to ensure the graph is consistent
+            # For BC, getattr() will return True if `is_single_tensor_return` doesn't exist
+            # as prior to adding this field, it is guaranteed to have a single tensor return
+            # when the serialized_node has length=1 outputs and of type `as_tensor`.
+            if (
+                "torch.ops.higher_order" in serialized_node.target
+                and not getattr(serialized_node, "is_hop_single_tensor_return", True)
+            ):
+                meta_val: list[Any] = []
+                arg = serialized_node.outputs[0].as_tensor
+                deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
+                self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
+                fx_node.meta["val"] = tuple(meta_val)
+                self.serialized_name_to_node[fx_node.name] = fx_node
+                return
+
             self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
             return
@@ -41,7 +41,6 @@ __all__ = [
     "while_loop",
     "invoke_subgraph",
     "scan",
-    "map",
     "flex_attention",
     "flex_attention_backward",
     "hints_wrapper",
@@ -102,9 +102,7 @@ def cond(
         false_fn (Callable): A callable function (a -> b) that is within the
             scope that is being traced. The true branch and false branch must
             have consistent input and outputs, meaning the inputs have to be
-            the same, and the outputs have to be the same type and shape. Int
-            output is also allowed. We'll make the output dynamic by turning it
-            into a symint.
+            the same, and the outputs have to be the same type and shape.

         operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the
             true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
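Note: a minimal usage sketch consistent with the docstring above after the
revert, where both branches take the same inputs and return outputs of the
same type and shape (illustrative):

    import torch

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    x = torch.randn(4)
    out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))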
@@ -434,7 +432,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):

         merged_outs = []
         for true_out, false_out in zip(flat_true_outs, flat_false_outs):
-            merged_outs.append(_merge_output(true_out, false_out, mode))
+            merged_outs.append(_merge_tensors(true_out, false_out, mode))
         return pytree.tree_unflatten(merged_outs, true_out_spec)


@@ -456,10 +454,8 @@ def check_tensor_meta_match(
     )


-def _merge_output(
-    a: Optional[Union[torch.Tensor, int]],
-    b: Optional[Union[torch.Tensor, int]],
-    mode: FakeTensorMode,
+def _merge_tensors(
+    a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
 ):
     from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr

@@ -467,28 +463,6 @@ def _merge_output(
         assert a is None and b is None, (a, b)
         return None

-    def min_max(s0, s1):
-        def _bound(s0, lower_bound: bool):
-            if isinstance(s0, int):
-                return s0
-            r = mode.shape_env.var_to_range.get(  # type: ignore[union-attr]
-                s0.node.expr,
-                torch.utils._sympy.value_ranges.ValueRanges.unknown(),
-            )
-            return r.lower if lower_bound else r.upper
-
-        return min(_bound(s0, True), _bound(s1, True)), max(
-            _bound(s0, False), _bound(s1, False)
-        )
-
-    if type(a) is int and type(b) is int:
-        if a == b:
-            return a
-        assert mode.shape_env is not None
-        merged_out = mode.shape_env.create_unbacked_symint()
-        mode.shape_env.constrain_symbol_range(merged_out.node.expr, *min_max(a, b))
-        return merged_out
-
     assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))

     # Note: we don't check size, stride because
@@ -530,6 +504,21 @@ def _merge_output(
         if SymIntEqByExpr(s0) == SymIntEqByExpr(s1):
             merged_size.append(s0)
         else:
+
+            def min_max(s0, s1):
+                def _bound(s0, lower_bound: bool):
+                    if isinstance(s0, int):
+                        return s0
+                    r = mode.shape_env.var_to_range.get(  # type: ignore[union-attr]
+                        s0.node.expr,
+                        torch.utils._sympy.value_ranges.ValueRanges.unknown(),
+                    )
+                    return r.lower if lower_bound else r.upper
+
+                return min(_bound(s0, True), _bound(s1, True)), max(
+                    _bound(s0, False), _bound(s1, False)
+                )
+
             assert mode.shape_env is not None
             new_size = mode.shape_env.create_unbacked_symint()
             mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))