[HOP] Mutation and alias rework (#146658)

This PR reworks how input mutations and aliasing are checked for higher-order operators (cond, while_loop, scan, associative_scan, map, and hints_wrapper). The per-HOP _has_potential_branch_input_mutation / _has_potential_branch_input_alias checks are consolidated into a shared _check_alias_and_mutation helper, Dynamo's HOP variables gain supports_input_mutation / supports_aliasing flags, and the affected error paths now raise torch._dynamo.exc.UncapturedHigherOrderOpError ("<HOP> must be captured completely with torch.compile") instead of UnsupportedAliasMutationException.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146658
Approved by: https://github.com/ydwu4
Thomas Bohnstingl 2025-05-18 08:05:22 +00:00 committed by PyTorch MergeBot
parent 0e805aad7f
commit 68034198e5
20 changed files with 519 additions and 269 deletions
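
The user-visible effect, as exercised by the test changes below: a HOP body may no longer mutate its inputs or return tensors that alias them. A minimal illustration (not part of this commit; the branch functions are placeholders):

    import torch

    def true_fn(x):
        return x.clone()  # returning x directly would alias the input and be rejected

    def false_fn(x):
        return x.sin()

    x = torch.randn(4)
    # Branches that return an input unchanged (or a view of it) now fail with
    # torch._dynamo.exc.UncapturedHigherOrderOpError under torch.compile;
    # cloning breaks the input/output aliasing and is accepted.
    out = torch.compile(lambda t: torch.cond(t.sum() > 0, true_fn, false_fn, (t,)))(x)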

View File

@@ -1,5 +1,4 @@
 from torch import cond  # noqa: F401
-from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 from torch._higher_order_ops.map import (  # noqa: F401
     _stack_pytree,
     _unstack_pytree,

View File

@@ -1873,7 +1873,7 @@ def forward(self, x, y):
                 return x + x

             def false_fn(x):
-                return x[:2]
+                return x[:2].clone()

             return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
@@ -1883,7 +1883,7 @@ def forward(self, x, y):
                 return x + x

             def false_fn(x):
-                return x[:2]
+                return x[:2].clone()

             return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
@@ -1924,7 +1924,8 @@ def forward(self, l_x_):
 def forward(self, l_x_):
     l_x__1 = l_x_
     getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
-    return (getitem,)""",
+    clone = getitem.clone(); getitem = None
+    return (clone,)""",
         )
         # We could successfully export branches that return different sizes
         torch._dynamo.export(mod)(torch.randn(3, 2))
@@ -3302,7 +3303,12 @@ def forward(self, x):
     def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
         def f_branch_return_multiple_tensors(pred, x, y):
-            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
+            return cond(
+                pred,
+                lambda x: (x.clone(), x.clone()),
+                lambda x: (x.clone(), x.clone()),
+                [y],
+            )

         example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
         gm, _ = torch._dynamo.export(
@@ -3324,10 +3330,10 @@ def forward(self, x):
     def test_cond_raise_user_error_on_mismatch_return_length(self):
         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
-            return (x, x)
+            return (x.clone(), x.clone())

         def f_mismatch_return_length(x):
             return cond(torch.tensor(100), true_fn, false_fn, [x])

View File

@@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):
     def test_map_pytree_return(self):
         def _construct_pytree(a):
-            return (a, [[[a]]], a, (a, (a,), a), {"a": a})
+            return (
+                a.clone(),
+                [[[a.clone()]]],
+                a.clone(),
+                (a.clone(), (a.clone(),), a.clone()),
+                {"a": a.clone()},
+            )

         def f(x):
             def inner_f(xs):
@@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
             body_graph,
             """\
 def forward(self, child : torch.Tensor):
-    return (child, child, child, child, child, child, child)""",
+    child_1 = child.clone()
+    child_2 = child.clone()
+    child_3 = child.clone()
+    child_4 = child.clone()
+    child_5 = child.clone()
+    child_6 = child.clone()
+    child_7 = child.clone(); child = None
+    return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
         )

     def test_map_kwargs(self):
@@ -6902,7 +6915,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         def test(pred, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x
@@ -6926,7 +6939,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         def test(pred, mode, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x

View File

@@ -5931,7 +5931,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         from functorch.experimental.control_flow import cond

         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
             return x.sin()

View File

@@ -7604,20 +7604,25 @@ def forward(self, b_a_buffer, x):
         self.assertTrue(torch.allclose(ep.module()(xs), module_out))

     @requires_cuda
+    @testing.expectedFailureCppRuntime
     def test_export_associative_scan_lifted_buffers(self):
         device = torch.device("cuda")
         combine_mode = "pointwise"

+        class A(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
+
+            def forward(self):
+                return self.buffer.cos()
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.register_buffer(
-                    "buf", torch.ones(3, 2, device=device), persistent=False
-                )
+                self.a = A()

             def combine_fn(self, x, y):
-                return x + y * self.buf
+                return (x + y) * self.a()

             def forward(self, x):
                 return associative_scan(

View File

@@ -4572,17 +4572,17 @@ class <lambda>(torch.nn.Module):
     body_graph_0 = self.body_graph_0
     map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
-    getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
-    sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+    getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
+    sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
     add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
     body_graph_1 = self.body_graph_1
     map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
-    getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
-    sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
+    getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
+    sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None
     add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
     return (add_1,)
@@ -4635,9 +4635,9 @@ class <lambda>(torch.nn.Module):
     body_graph_0 = self.body_graph_0
     map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
-    getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
-    sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+    getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
+    sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
     add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
     return (add,)

View File

@@ -6,7 +6,7 @@ import unittest
 import torch
 import torch.utils._pytree as pytree
 from functorch.experimental import control_flow
-from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
+from functorch.experimental.control_flow import cond
 from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.associative_scan import (
     _fake_associative_scan,
@@ -36,7 +36,6 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_CROSSREF,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
-    xfailIfTorchDynamo,
 )
@@ -159,7 +158,7 @@ def get_scan_combine_fn(name, associative=True, parameters=None):
     def RNN(x: torch.Tensor, y: torch.Tensor):
         c_new = y @ parameters[0] + parameters[1]
         h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
-        return h_new, h_new
+        return h_new, h_new.clone()

     def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
         h_new = torch.tanh(x[0] + x[1] + y)
@@ -888,8 +887,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
         )

     def test_cond_autograd_pytree_input(self):
+        # TODO: This is an unexpected behavior for cond
+        # Without this additional multiplication,
+        # the output of the backward graph would alias the
+        # inputs, as the gradients are just 1s and thus get optimized
         def true_fn(x):
-            return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
+            return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]

         def false_fn(x):
             return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
@@ -966,10 +969,10 @@ def forward(self, pred_1):
     @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
     def test_cond_autograd_same_pytree_output(self):
         def true_fn(x):
-            return {"res": [x["t"][0], (x["t"][2][0],)]}
+            return {"res": [x["t"][0].clone(), (x["t"][2][0].clone(),)]}

         def false_fn(x):
-            return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
+            return {"res": [x["t"][1]["b"].clone(), (x["t"][2][0].clone(),)]}

         a = torch.randn(4, requires_grad=True)
         b = torch.randn(4, requires_grad=True)
@@ -1007,9 +1010,7 @@ def forward(self, pred_1):
     cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
     getitem = cond[0]
     getitem_1 = cond[1]; cond = None
-    view = torch.ops.aten.view.default(getitem, [4]); getitem = None
-    view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
-    return {'res': [view, (view_1,)]}""",  # noqa: B950
+    return {'res': [getitem, (getitem_1,)]}""",  # noqa: B950
         )

     @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
@@ -1901,12 +1902,15 @@ def forward(self, pred_1, x_1):
             {
                 "i": x["i"] * y["j"][0][0],
                 "k": torch.tensor(0.0),
-                "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
+                "j": (
+                    [x["j"][1][0]["o"].clone()],
+                    [{"o": torch.sin(x["i"])}],
+                ),
             },
             {
                 "i": x["i"] * y["j"][0][0],
                 "k": torch.tensor(0.0),
-                "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
+                "j": ([x["j"][1][0]["o"].clone()], [{"o": torch.sin(x["i"])}]),
             },
         )
@@ -2395,8 +2399,11 @@ def forward(self, pred_1, x_1):
         init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)

         with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
             torch._dynamo.exc.UncapturedHigherOrderOpError,
-            "Expected init and carry to have same metadata.*",
+            "scan must be captured completely with torch.compile.*",
         ):
             scan_fct(wrong_carry_shape, init, x, dim=dim)
@@ -3311,6 +3318,114 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
     return (carry, out_1)""",  # noqa: B950
         )

+    @requires_cuda
+    def test_scan_input_mutation(self):
+        device = torch.device("cuda")
+
+        def fct_input_mutation(x, y):
+            x.add_(1)
+            return x + y, x + y + 2
+
+        x = torch.randn(3, 2, 2, device=device)
+        init = torch.randn(2, 2, device=device)
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "scan must be captured completely with torch.compile.*",
+        ):
+            scan(fct_input_mutation, init, x, dim=0)
+
+    @requires_cuda
+    def test_scan_input_carry_alias(self):
+        device = torch.device("cuda")
+
+        def fct_input_output_alias(x, y):
+            return (x[0], x[1] + y[1]), (x[1] + y[1] + 1, x[1] + y[1] + 2)
+
+        x = torch.randn(3, 2, 2, device=device)
+        y = torch.randn(3, 2, 2, device=device)
+        inp = (x, y)
+        init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "scan must be captured completely with torch.compile.*",
+        ):
+            scan(fct_input_output_alias, init, inp, dim=0)
+
+    @requires_cuda
+    def test_scan_input_output_alias(self):
+        device = torch.device("cuda")
+
+        def fct_input_output_alias(x, y):
+            return (x[0] + 1, x[1] + y[1]), (x[1], x[1] + y[1] + 2)
+
+        x = torch.randn(3, 2, 2, device=device)
+        y = torch.randn(3, 2, 2, device=device)
+        inp = (x, y)
+        init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "scan must be captured completely with torch.compile.*",
+        ):
+            scan(fct_input_output_alias, init, inp, dim=0)
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @requires_cuda
+    def test_scan_carry_carry_alias(self):
+        device = torch.device("cuda")
+
+        def fct_carry_carry_alias(x, y):
+            c = x[0] + y[1]
+            return (c, c), (x[0] + y[1], x[0] + y[1] + 1)
+
+        x = torch.randn(3, 2, 2, device=device)
+        y = torch.randn(3, 2, 2, device=device)
+        inp = (x, y)
+        init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "scan must be captured completely with torch.compile.*",
+        ):
+            scan(fct_carry_carry_alias, init, inp, dim=0)
+
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @requires_cuda
+    def test_scan_carry_output_alias(self):
+        device = torch.device("cuda")
+
+        def fct_carry_output_alias(x, y):
+            c = x[0] + y[1]
+            return (x[0] + y[1], c), (c, x[0] + y[1] + 1)
+
+        x = torch.randn(3, 2, 2, device=device)
+        y = torch.randn(3, 2, 2, device=device)
+        inp = (x, y)
+        init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "scan must be captured completely with torch.compile.*",
+        ):
+            scan(fct_carry_output_alias, init, inp, dim=0)
+

 class AssociativeScanModels:
     @staticmethod
@@ -4158,7 +4273,7 @@ class GraphModule(torch.nn.Module):
     )
     def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
         def combine_fn(x, y):
-            val = cond(torch.sum(y) > 0.0, lambda y: y + 0.0, lambda y: 1.0 - y, (y,))
+            val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
             return x * val

         inp = torch.randn(3, 10, 1, device=device)
@@ -4775,8 +4890,10 @@ class GraphModule(torch.nn.Module):
         with self.assertRaisesRegex(
             # Should be
-            RuntimeError,
-            "Combine_fn might be modifying the input!",
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "associative_scan must be captured completely with torch.compile.*",
         ):
             associative_scan(fct_input_mutation, x, 0)
@@ -4793,11 +4910,35 @@ class GraphModule(torch.nn.Module):
         with self.assertRaisesRegex(
             # Should be
-            RuntimeError,
-            "Combine_fn might be aliasing the input!",
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "associative_scan must be captured completely with torch.compile.*",
         ):
             associative_scan(fct_input_output_alias, inp, 0)

+    @unittest.skipIf(not SM70OrLater, "triton")
+    @requires_cuda
+    def test_associative_scan_output_output_alias(self):
+        device = torch.device("cuda")
+
+        def fct_output_output_alias(x, y):
+            c = x[0] + y[1]
+            return c, c
+
+        x = torch.randn(3, 2, 2, device=device)
+        y = torch.randn(3, 2, 2, device=device)
+        inp = (x, y)
+
+        with self.assertRaisesRegex(
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "associative_scan must be captured completely with torch.compile.*",
+        ):
+            associative_scan(fct_output_output_alias, inp, 0)
+

 @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
 @skipIfNoDynamoSupport
@@ -5597,8 +5738,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
         graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
         self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

-    # https://github.com/pytorch/pytorch/issues/126988
-    def test_cond_functionalized_input_mutation_on_true_brancte(self):
+    def test_cond_functionalized_input_mutation_on_true_branch(self):
         def true_fn(x):
             view_x = x.view(x.shape)
             view_x.add_(1)
@@ -5632,13 +5772,13 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+            torch._dynamo.exc.TorchRuntimeError,
+            "cond_true might be modifying the input!",
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
             )

-    # https://github.com/pytorch/pytorch/issues/126988
     def test_cond_functionalized_input_mutation_on_false_branch(self):
         def true_fn(x):
             return x.sin().sum()
@@ -5673,16 +5813,16 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+            torch._dynamo.exc.TorchRuntimeError,
+            "cond_false might be modifying the input!",
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
             )

-    # https://github.com/pytorch/pytorch/issues/126988
     def test_cond_functionalized_output_alias_input(self):
         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
             view_x = x.view(x.shape)
@@ -5707,13 +5847,16 @@ def forward(self, x_1):
         # torch.cond triggers the check of the branches because the predicate
         # is a SymBool.
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile.*",
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
             )

-    # https://github.com/pytorch/pytorch/issues/126988
     def test_cond_functionalized_nested_input_mutation(self):
         def true_true_fn(x):
             x.add_(4)
@@ -5735,13 +5878,13 @@ def forward(self, x_1):
         example_inputs = (torch.ones(4, 5),)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+            torch._dynamo.exc.TorchRuntimeError,
+            "cond_true might be modifying the input!",
         ):
             make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
                 *example_inputs
             )

-    # https://github.com/pytorch/pytorch/issues/126988
     def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
         def true_true_fn(x):
             x.add_(4)
@@ -5768,7 +5911,11 @@ def forward(self, x_1):
             f(example_input_func)
             with self.assertRaisesRegex(
-                torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 make_fx(f, tracing_mode="symbolic")(example_input_func)
         finally:
@@ -5786,7 +5933,11 @@ def forward(self, x_1):
             return wrapper

         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile.*",
         ):
             make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
@@ -5807,8 +5958,11 @@ def forward(self, x_1):
             example_input_func = to_fun_old(example_input)
             torch._enable_functionalization(reapply_views=False)
             with self.assertRaisesRegex(
-                torch._dynamo.exc.TorchRuntimeError,
-                "One of torch.cond branch might be aliasing",
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 f(example_input_func)
         finally:
@@ -5838,8 +5992,11 @@ def forward(self, x_1):
             return wrapper

         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError,
-            "One of torch.cond branch might be aliasing",
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile.*",
         ):
             make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
@@ -5974,8 +6131,11 @@ def forward(self, arg0_1):
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError,
-            "Unmatched output spec from torch.cond branches",
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile.*",
         ):
             make_fx(f)(x, torch.tensor(False))
@@ -6147,8 +6307,11 @@ def forward(self, arg0_1):
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError,
-            "Unmatched output spec from torch.cond branches",
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Cond doesn't work unless it is captured completely with torch.compile.*",
         ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
@@ -6431,7 +6594,8 @@ def forward(self, arg0_1):
         example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
+            torch._dynamo.exc.TorchRuntimeError,
+            "map might be modifying the input!",
         ):
             functional_f(*example_inputs)
@@ -6446,7 +6610,7 @@ def forward(self, arg0_1):
         example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
+            torch._dynamo.exc.TorchRuntimeError, "map might be modifying the input!"
         ):
             functional_f(*example_inputs)
@@ -6484,7 +6648,11 @@ def forward(self, arg0_1):
         example_inputs = (torch.ones(3, 2, 4),)
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.TorchRuntimeError, "torch.map is aliasing the input!"
+            # Should be
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing.*"
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "map doesn't work unless it is captured completely with torch.compile.*",
         ):
             functional_f(*example_inputs)
@@ -6843,10 +7011,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             gm.code.strip(),
             """\
 def forward(self, pred_1, x_1):
+    unbind = torch.ops.aten.unbind.int(x_1)
+    getitem = unbind[0]; getitem = None
+    getitem_1 = unbind[1]; unbind = getitem_1 = None
     body_graph_0 = self.body_graph_0
     map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
-    getitem = map_impl[0]; map_impl = None
-    return getitem""",
+    getitem_2 = map_impl[0]; map_impl = None
+    return getitem_2""",
         )
         self.assertExpectedInline(
             gm.body_graph_0.code.strip(),
@@ -7082,7 +7253,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         return torch.cond(
             pred=torch.tensor([True]),
             true_fn=lambda x: x + 100,
-            false_fn=lambda x: x,
+            false_fn=lambda x: x.clone(),
             operands=(x,),
         )
@@ -7096,7 +7267,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         return torch.cond(
             pred=x.sum() < y.sum(),
             true_fn=lambda x, y: x + 100,
-            false_fn=lambda x, y: y,
+            false_fn=lambda x, y: y.clone(),
             operands=(x, y),
         )
@@ -7155,7 +7326,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         return torch.cond(
             pred=torch.tensor([True]),
             true_fn=lambda x: (x + c, x - c),
-            false_fn=lambda x: (x, x),
+            false_fn=lambda x: (x.clone(), x.clone()),
             operands=(x,),
         )
@@ -7165,7 +7336,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         return torch.cond(
             pred=torch.tensor([True]),
             true_fn=lambda x: (x + 1, x - 1),
-            false_fn=lambda x: (x, x),
+            false_fn=lambda x: (x.clone(), x.clone()),
             operands=(x,),
         )
@@ -7361,8 +7532,6 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
             functional_f(example_init, example_inputs), f(example_init, example_inputs)
         )

-    # https://github.com/pytorch/pytorch/issues/126988
-    @xfailIfTorchDynamo
     def test_scan_functionalized_elem_mutation(self):
         def add1(x, y):
             x.add_(4)
@@ -7375,8 +7544,15 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
         example_init = torch.ones(5, 4)
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException,
-            "Combine_fn might be modifying the input!",
+            # TODO: Fix this so that the HOPs show similar errors for functionalization
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
+            # RuntimeError,
+            # "torch.scan might be modifying the input!",
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
+            # torch._dynamo.exc.TorchDynamoException,
+            # "Unexpected exception when running generated GraphModule.*"
+            Exception,
+            ".*",
         ):
             functional_f(example_init, example_inputs)
@@ -7389,13 +7565,19 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException,
-            "Combine_fn might be modifying the input!",
+            # TODO: Fix this so that the HOPs show similar errors for functionalization
+            # Should be
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
+            # RuntimeError,
+            # "torch.scan might be modifying the input!",
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
+            # torch._dynamo.exc.TorchDynamoException,
+            # "Unexpected exception when running generated GraphModule.*"
+            Exception,
+            ".*",
         ):
             functional_f(example_init, example_inputs)

-    # https://github.com/pytorch/pytorch/issues/126988
-    @xfailIfTorchDynamo
     def test_scan_functionalized_elem_alias(self):
         def add(x, y):
             return x, x
@@ -7407,7 +7589,16 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
         example_init = torch.ones(5, 4)
         functional_f = torch.func.functionalize(f)
         with self.assertRaisesRegex(
-            UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!"
+            # TODO: Fix this so that the HOPs show similar errors for functionalization
+            # Should be
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
+            # torch._dynamo.exc.Unsupported,
+            # "Encountered aliasing during higher order op tracing for HOP.*"
+            # This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
+            # torch._dynamo.exc.UncapturedHigherOrderOpError,
+            # "scan must be captured completely with torch.compile.*",
+            Exception,
+            ".*",
         ):
             functional_f(example_init, example_inputs)
@@ -7917,7 +8108,11 @@ class GraphModule(torch.nn.Module):
         x = torch.randn(2, 2)
         for f in ALIAS_FN:
             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 torch.compile(fn)(f, x)
@@ -7933,7 +8128,11 @@ class GraphModule(torch.nn.Module):
         # as a result of auto lifting.
         for view_f in ALIAS_FN[1:]:
             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 torch.compile(fn)(view_f, x)
@@ -7950,12 +8149,20 @@ class GraphModule(torch.nn.Module):
         x = torch.randn(2, 2)
         for f in ALIAS_FN:
             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 torch.compile(fn)(f, x)

             with self.assertRaisesRegex(
-                torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
+                # Should be
+                # torch._dynamo.exc.Unsupported,
+                # "Encountered aliasing during higher order op tracing for HOP.*"
+                torch._dynamo.exc.UncapturedHigherOrderOpError,
+                "Cond doesn't work unless it is captured completely with torch.compile.*",
             ):
                 with torch.inference_mode(inference_mode):
                     torch.compile(fn)(f, x)

View File

@@ -570,7 +570,7 @@ class CondTests(TestCase):
             return torch.cond(p, true_fn, false_fn, [a, b])

         # AssertionError: Output aliasing is currently not supported...
-        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
+        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
             torch.compile(Model())(
                 torch.tensor(True),
                 torch.randn(10, 20),

View File

@@ -2831,11 +2831,13 @@ class SubgraphTracer(fx.Tracer):
         return MutationInfo(False, "")

     def has_aliasing(self):
+        from torch._higher_order_ops.utils import _collect_fake_inputs
+
         input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()

         for node in self.graph.nodes:
             if node.op == "placeholder":
-                example_value = node.meta["example_value"]
+                example_value = _collect_fake_inputs([node])[0]
                 if isinstance(example_value, torch.Tensor):
                     storage = StorageWeakRef(example_value._typed_storage())
                     if storage in input_storages:
@@ -2848,9 +2850,9 @@ class SubgraphTracer(fx.Tracer):
         output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
         out_nodes = self.graph.find_nodes(op="output")[0]
-        for out_node in out_nodes.args[0]:
+        for out_node in pytree.tree_leaves(out_nodes.args[0]):
             if out_node:
-                example_value = out_node.meta["example_value"]
+                example_value = _collect_fake_inputs([out_node])[0]
                 assert not isinstance(example_value, list)
                 if isinstance(example_value, torch.Tensor):
                     storage = StorageWeakRef(example_value._typed_storage())
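
The storage-keyed bookkeeping above is the whole aliasing test: two graph values alias when they share a typed storage, which the dicts keyed by StorageWeakRef detect. A standalone sketch of that idea on ordinary (non-fake) tensors, for illustration only:

    import torch
    from torch.multiprocessing.reductions import StorageWeakRef

    def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
        # Tensors alias each other when they are backed by the same storage.
        return StorageWeakRef(a._typed_storage()) == StorageWeakRef(b._typed_storage())

    x = torch.randn(4)
    print(shares_storage(x, x.view(2, 2)))  # True: a view reuses the storage
    print(shares_storage(x, x.clone()))     # False: clone allocates fresh storage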

View File

@@ -961,6 +961,9 @@ class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable
 class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="Cond doesn't work unless it is captured completely with torch.compile."
     )
@@ -1058,6 +1061,8 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
             "cond",
             source_target=self.value,
             should_flatten_outputs=True,
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )

         if not only_consist_of(ret_val, (TensorVariable,)):
@@ -1184,6 +1189,9 @@ def validate_subgraph_output_types(output: VariableTracker):
 class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="while_loop doesn't work unless it is captured completely with torch.compile."
     )
@@ -1299,6 +1307,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
             # So it's best we always enforce the ordering of carried_inputs the same as outputs
             # with "flatten_manual".
             set_subgraph_inputs="flatten_manual",
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )
         cond_nn_modules = dict(tx.output.nn_modules)
         validate_subgraph_output_types(cond_r)
@@ -1337,6 +1347,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
+            supports_input_mutation=False,
+            supports_aliasing=False,
         )
         validate_subgraph_output_types(body_r)
@@ -1420,6 +1432,9 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
 class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="associative_scan must be captured completely with torch.compile."
     )
@@ -1513,6 +1528,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
             description="associative_scan_combine_fn",
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )

         # Ensure that the output of scan is a flattened list of elements,
@@ -1562,30 +1579,23 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
         )
         combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)

+        combine_freevars_proxy = tuple(combine_lifted_freevars.keys())

-        from torch._higher_order_ops.utils import (
-            _has_potential_branch_input_alias,
-            _has_potential_branch_input_mutation,
-            _maybe_fake_tracing,
-        )
+        # Compute the proxies for the input check
+        proxy_vars_inputcheck = (
+            tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy
+        )
+
+        from torch._higher_order_ops.utils import _maybe_fake_tracing
         from torch._inductor.utils import is_pointwise_use
-        from torch._subclasses.fake_tensor import FakeTensor

         with tx.fake_mode:
-            xs_fake = [
-                first_slice_copy(leaf.proxy.node.meta["example_value"].clone())
-                for leaf in itertools.chain(xs_vars, xs_vars)
-            ]
-            additional_fake = [
-                leaf.proxy.node.meta["example_value"].clone()
-                for leaf in additional_inputs_vars
-            ] + [
+            sub_args_fake = [
                 leaf.node.meta["example_value"].clone()
-                if isinstance(leaf.node.meta["example_value"], FakeTensor)
+                if hasattr(leaf.node.meta["example_value"], "clone")
                 else leaf.node.meta["example_value"]
-                for leaf in combine_lifted_freevars.keys()
+                for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
             ]
-            sub_args_fake = xs_fake + additional_fake

             pre_dispatch = False
             fx = _maybe_fake_tracing(
@@ -1601,15 +1611,6 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
                     "For combine_mode='pointwise', the combine_fn needs to be pointwise"
                 )

-        if _has_potential_branch_input_mutation(
-            combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
-        ):
-            raise RuntimeError("Combine_fn might be modifying the input!")  # noqa: F541
-
-        if _has_potential_branch_input_alias(
-            combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
-        ):
-            raise RuntimeError("Combine_fn might be aliasing the input!")  # noqa: F541
-
         combine_fn_name = tx.output.install_subgraph(
             "associative_scan_combine_fn", combine_gm
         )
@@ -1641,6 +1642,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="scan must be captured completely with torch.compile."
     )
@@ -1752,7 +1756,10 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
             description="scan_combine_fn",
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )
+
         # Ensure that the output of scan is a flattened list of elements,
         # because downstream operations assume that the output of HOPs
         # is flattened
@@ -1848,6 +1855,9 @@ def non_single_tensor_return_unsupported(api, ret):
 class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="map doesn't work unless it is captured completely with torch.compile."
     )
@@ -1910,6 +1920,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )
         body_nn_modules = dict(tx.output.nn_modules)

View File

@@ -434,12 +434,27 @@ def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, additional_inputs):
 @associative_scan_op.py_functionalize_impl
 def associative_scan_functionalize(ctx, combine_fn, xs, additional_inputs):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
     with ctx.redispatch_to_next():
         functional_combine_fn = ctx.functionalize(
             _maybe_run_with_interpreter(combine_fn)
         )
+        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        sample_unwrapped_xs_sliced = [
+            first_slice_copy(inp) for inp in itertools.chain(unwrapped_xs, unwrapped_xs)
+        ]
+        sample_inputs = list(
+            itertools.chain(
+                sample_unwrapped_xs_sliced,
+                unwrapped_additional_inputs,
+            )
+        )
+        _check_alias_and_mutation(
+            combine_fn, sample_inputs, "associative_scan", pre_dispatch
+        )
         ret = associative_scan_op(
             functional_combine_fn,
             unwrapped_xs,

View File

@@ -136,10 +136,10 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
             for ph in subgraph.graph.find_nodes(op="placeholder")
         ]
         (
-            mutated_inp_idx,
             inp_inp_alias,
             inp_out_alias,
             out_out_alias,
+            mutated_inp_idx,
             output,
         ) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args)

View File

@@ -18,8 +18,6 @@ from torch._C._functorch import (
 from torch._dispatch.python import suspend_functionalization
 from torch._functorch.utils import exposed_in
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     _maybe_reenter_make_fx,
     _maybe_run_with_interpreter,
     _set_compilation_env,
@@ -27,7 +25,6 @@ from torch._higher_order_ops.utils import (
     save_tensors_and_symints_for_backward,
     saved_tensors_and_symints,
     unique_graph_id,
-    UnsupportedAliasMutationException,
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
@@ -668,29 +665,18 @@ def _merge_tensors(
 @cond_op.py_functionalize_impl
 def cond_func(ctx, pred, true_fn, false_fn, inputs):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_inputs = ctx.unwrap_tensors(inputs)
     unwrapped_pred = ctx.unwrap_tensors(pred)
     with ctx.redispatch_to_next():
         functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
         functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        for branch in [true_fn, false_fn]:
-            if _has_potential_branch_input_mutation(
-                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    "One of torch.cond branch might be modifying the input! "
-                    "Consider cloning the input before modifying it. "
-                )
-        for branch in [true_fn, false_fn]:
-            if _has_potential_branch_input_alias(
-                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    "One of torch.cond branch might be aliasing the input! "
-                    "If you are returning a view of the input, please make sure "
-                    "to clone it. "
-                )
+        for branch, branch_name in [(true_fn, "cond_true"), (false_fn, "cond_false")]:
+            _check_alias_and_mutation(
+                branch, unwrapped_inputs, branch_name, pre_dispatch
+            )

         cond_return = cond_op(
             unwrapped_pred, functional_true, functional_false, unwrapped_inputs

View File

@@ -420,6 +420,9 @@ def flex_attention_functionalize(
     functional_score_mod = ctx.functionalize(score_mod)
     pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
     with TransformGetItemToIndex():
+        # TODO: So far only the input mutations are checked
+        # In the other HOPs, also aliases are checked which is
+        # omitted here
         mutates = _has_potential_branch_input_mutation(
             score_mod, example_vals, pre_dispatch
         )

View File

@@ -3,12 +3,9 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     autograd_not_implemented,
     reenter_make_fx,
     unique_graph_id,
-    UnsupportedAliasMutationException,
 )
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -91,24 +88,18 @@ def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
 @hints_wrapper.py_functionalize_impl
 def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_args = ctx.unwrap_tensors(args)
     unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
     unwrapped_hints = ctx.unwrap_tensors(hints)
     with ctx.redispatch_to_next():
         functional_body_fn = ctx.functionalize(body_fn)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        if _has_potential_branch_input_mutation(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be aliasing the input!"
-            )
+        _check_alias_and_mutation(
+            body_fn, unwrapped_args, "hints_wrapper", pre_dispatch
+        )
         outputs = hints_wrapper(
             functional_body_fn,
             unwrapped_args,

View File

@@ -7,13 +7,7 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._dispatch.python import suspend_functionalization
-from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
-    _maybe_run_with_interpreter,
-    reenter_make_fx,
-    UnsupportedAliasMutationException,
-)
+from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch._subclasses.functional_tensor import disable_functional_mode
@@ -284,23 +278,15 @@ def map_fake_tensor_mode(mode, f, xs, args):
 @map_impl.py_functionalize_impl
 def map_functionalize(ctx, f, xs, pos_args):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_args = ctx.unwrap_tensors(pos_args)
     wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

     with ctx.redispatch_to_next():
-        with disable_proxy_modes_tracing():
-            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
+        example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        if _has_potential_branch_input_mutation(
-            f, example_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException("torch.map is mutating the input!")
-        if _has_potential_branch_input_alias(
-            f, example_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
+        _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)

         map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
         return ctx.wrap_tensors(map_return)

View File

@ -10,8 +10,6 @@ import torch.utils._pytree as pytree
from torch._C import DispatchKey from torch._C import DispatchKey
from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph
from torch._higher_order_ops.utils import ( from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_compile_and_run_fn, _maybe_compile_and_run_fn,
check_meta_consistency, check_meta_consistency,
first_slice_copy, first_slice_copy,
@ -19,7 +17,6 @@ from torch._higher_order_ops.utils import (
save_tensors_and_symints_for_backward, save_tensors_and_symints_for_backward,
saved_tensors_and_symints, saved_tensors_and_symints,
unique_graph_id, unique_graph_id,
UnsupportedAliasMutationException,
validate_subgraph_args_types, validate_subgraph_args_types,
) )
from torch._ops import HigherOrderOperator from torch._ops import HigherOrderOperator
@@ -861,12 +858,19 @@ def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
 @scan_op.py_functionalize_impl
 def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
+    from torch._higher_order_ops.utils import (
+        _check_alias_and_mutation,
+        _maybe_run_with_interpreter,
+    )
+
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_init = ctx.unwrap_tensors(init)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
     with ctx.redispatch_to_next():
-        functional_combine_fn = ctx.functionalize(combine_fn)
-        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        functional_combine_fn = ctx.functionalize(
+            _maybe_run_with_interpreter(combine_fn)
+        )
         sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs]
         sample_inputs = list(
             itertools.chain(
@@ -875,18 +879,8 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
                 unwrapped_additional_inputs,
             )
         )
-        if _has_potential_branch_input_mutation(
-            combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "Combine_fn might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "Combine_fn might be aliasing the input!"
-            )
+        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        _check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch)
         ret = scan_op(
             functional_combine_fn,
             unwrapped_init,
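The scan kernel now funnels through the same helper, so a combine_fn whose outputs alias its inputs is rejected up front. A hedged sketch of the contract this enforces (the combine_fn signature follows the scan HOP, which takes a carry and a slice and returns the next carry plus an output):

    import torch

    def bad_combine(carry, x):
        # Returning `carry` unchanged aliases an input and trips the check.
        return carry, carry

    def good_combine(carry, x):
        new_carry = carry + x
        # Fresh tensors for both the next carry and the emitted output.
        return new_carry, new_carry.clone()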


@@ -248,34 +248,6 @@ def _set_compilation_env():
         torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
 
 
-def _detect_input_mutation(gm: torch.fx.GraphModule) -> bool:
-    example_inputs = [
-        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
-    ]
-    inp_mutation, _, _, _ = check_input_alias_and_mutation(gm, example_inputs)
-    if len(inp_mutation) > 0:
-        return True
-
-    for _, module in gm.named_children():
-        if isinstance(module, torch.fx.GraphModule):
-            if _detect_input_mutation(module):
-                return True
-
-    return False
-
-
-def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:
-    example_inputs = [
-        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
-    ]
-    _, inp_inp_alias_map, inp_out_alias_map, _ = check_input_alias_and_mutation(
-        gm, example_inputs
-    )
-    if len(inp_out_alias_map) > 0 or len(inp_inp_alias_map) > 0:
-        return True
-
-    return False
-
-
 # The invariant here is that we always trace the branch with fake tensor
 def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
     fake_mode = detect_fake_mode(inputs)
@@ -301,7 +273,7 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
     return gm
 
 
-def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
+def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
     try:
         gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
     except UnsupportedAliasMutationException:
@@ -311,43 +283,113 @@ def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
     except Exception as e:
         raise e
 
-    return _detect_input_mutation(gm) or _detect_input_alias(gm)
+    example_inputs = [
+        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
+    ]
+    (
+        inp_inp_alias_map,
+        inp_out_alias_map,
+        out_out_alias_map,
+        inp_mutation,
+    ) = check_input_alias_and_mutation(gm, example_inputs)
+    return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
 
 
-def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
-    """
-    Dispatch-trace the branch with inputs and check if
-    producing graph has mutable op on the input. This is
-    bit restrictive as the branch must be traceable.
-    """
-    try:
-        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
-    except UnsupportedAliasMutationException:
-        # this can happen when nested cond_op is
-        # functionalized
-        return True
-    except Exception as e:
-        raise e
-
-    return _detect_input_mutation(gm)
+def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
+    if any(len(a) > 0 for a in aliases):
+        # TODO: Investigate here further which node is exactly aliasing
+        raise RuntimeError(
+            f"{name} where aliases appear. "
+            + f"In particular, these inputs \
+{set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} "  # noqa: C401
+            + "get aliased. Please ensure that this doesn't happen."
+        )
+    if len(input_mutations):
+        # TODO: Investigate here further which node is exactly mutating the inputs
+        raise RuntimeError(
+            f"{name} where the inputs are mutated. "
+            + f"In particular, these nodes are mutating the inputs \
+{set(el for el in input_mutations)}."  # noqa: C401
+            + "Please ensure that this doesn't happen."
+        )
 
 
-def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
-    """
-    Dispatch-trace the branch with inputs and check if
-    producing graph has output aliasing the branch input. This is
-    bit restrictive as the branch must be traceable.
-    """
-    try:
-        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
-    except UnsupportedAliasMutationException:
-        # this can happen when nested cond_op is
-        # functionalized
-        return True
-    except Exception as e:
-        raise e
-
-    return _detect_input_alias(gm)
+def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
+    (
+        _,
+        _,
+        _,
+    ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
+
+    return len(inp_mutation) > 0
+
+
+def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
+    (
+        inp_inp_alias_map,
+        inp_out_alias_map,
+        out_out_alias_map,
+    ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
+    return (
+        any(
+            (
+                len(inp_inp_alias_map) > 0,
+                len(inp_out_alias_map) > 0,
+                len(out_out_alias_map) > 0,
+            )
+        ),
+        len(inp_mutation) > 0,
+    )
+
+
+def _collect_fake_inputs(inputs):
+    from torch._subclasses.fake_tensor import FakeTensor
+
+    # Get the example values of the inputs.
+    inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
+    for inp in inputs:
+        if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
+            inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
+            if hasattr(inp, "meta"):
+                val = inp.meta["example_value"]
+                if isinstance(val, torch.Tensor):
+                    if torch._C._functorch.is_batchedtensor(
+                        val
+                    ) or torch._C._functorch.is_functionaltensor(val):
+                        # This case is for batched or functional tensors
+                        # Unwrap the tensors
+                        while torch._C._functorch.is_batchedtensor(
+                            val
+                        ) or torch._C._functorch.is_functionaltensor(val):
+                            val = torch._C._functorch.get_unwrapped(val)
+                        assert isinstance(val, FakeTensor)
+                        inputs_fake.append(val)
+                    else:
+                        # This is the standard case of a TensorVariable
+                        assert isinstance(val, FakeTensor)
+                        inputs_fake.append(val)
+                else:
+                    # This case is for SymInts and other non-Tensor elements
+                    assert not isinstance(val, torch.Tensor)
+                    inputs_fake.append(val)
+        else:
+            # This case is for ints
+            assert isinstance(inp, int)
+            inputs_fake.append(inp)
+
+    return inputs_fake
+
+
+def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
+    aliases, inp_mutation = has_potential_input_alias_or_mutation(
+        graph_module, inputs_fake, pre_dispatch=pre_dispatch
+    )
+    if aliases:
+        raise RuntimeError(
+            f"{name} might be aliasing the input or the output!"
+        )  # noqa: F541
+    if inp_mutation:
+        raise RuntimeError(f"{name} might be modifying the input!")  # noqa: F541
 
 
 def unique_graph_id(proxy_mode, prefix):
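All the call sites above reduce to this one pattern; the sketch below shows the intended use of the reworked helpers. Direct use of these private utilities is for illustration only, and the inputs here are real tensors rather than the fake tensors the HOP kernels actually pass:

    import torch
    from torch._higher_order_ops.utils import (
        _check_alias_and_mutation,
        has_potential_input_alias_or_mutation,
    )

    def aliasing_fn(x):
        return x.view(-1)  # output shares storage with the input

    inputs = [torch.randn(2, 2)]

    # Boolean summary: (any_alias, any_input_mutation)
    aliases, mutates = has_potential_input_alias_or_mutation(aliasing_fn, inputs)

    # One-stop check used by map/scan/while_loop; raises RuntimeError on either:
    # _check_alias_and_mutation(aliasing_fn, inputs, "aliasing_fn", pre_dispatch=False)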
@@ -699,27 +741,29 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
     ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
 
 
+# TODO: Return more detailed information as to which node
+# causes a mutation or an alias. This may require per-operator tensor version checking
 def check_input_alias_and_mutation(
     gm: torch.fx.GraphModule,
     fake_args: list[FakeTensor],
-) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]:
+) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
     (
-        mutated_inputs,
         inp_inp_alias_map,
         inp_out_alias_map,
         out_out_alias_map,
+        mutated_inputs,
     ) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1]
-    return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map
+    return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
 
 
 def check_input_alias_and_mutation_return_ouputs(
     gm: torch.fx.GraphModule,
     fake_args: list[FakeTensor],
 ) -> tuple[
+    dict[int, int],
+    dict[int, int],
+    dict[int, int],
     list[int],
-    dict[int, int],
-    dict[int, int],
-    dict[int, int],
     Union[tuple[Any, ...], list[Any]],
 ]:
     # We want to disable active functional, proxy and fake modes if any.
@@ -825,10 +869,10 @@ def check_input_alias_and_mutation_return_ouputs(
         if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map
     }
     return (
-        mutated_inputs,
         inp_inp_alias_map,
         inp_out_alias_map,
         out_out_alias_map,
+        mutated_inputs,
         outputs,
     )
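Callers that unpacked the old tuple positionally must move `mutated_inputs` from the first slot to the last. A sketch of the new order, building the fake inputs the same way the deleted `_detect_*` helpers did (the traced function is a hypothetical example):

    import torch
    from torch._higher_order_ops.utils import check_input_alias_and_mutation
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(x):
        return x.view(-1), x + 1  # first output aliases the input

    gm = make_fx(fn, tracing_mode="fake")(torch.randn(2, 2))
    fake_args = [
        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
    ]

    (
        inp_inp_alias_map,  # aliasing among inputs
        inp_out_alias_map,  # inputs aliased by outputs
        out_out_alias_map,  # aliasing among outputs
        mutated_inputs,     # indices of mutated inputs
    ) = check_input_alias_and_mutation(gm, fake_args)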


@@ -6,14 +6,11 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
     check_meta_consistency,
     reenter_make_fx,
-    UnsupportedAliasMutationException,
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
@@ -400,6 +397,8 @@ def while_loop_fake_tensor_mode(
 @while_loop_op.py_functionalize_impl
 def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
     unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
@@ -411,19 +410,7 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
         (cond_fn, "cond_fn"),
         (body_fn, "body_fn"),
     ]:
-        if _has_potential_branch_input_mutation(
-            fn, unwrapped_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                f"torch.while_loop's {fn_name} might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            fn, unwrapped_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                f"torch.while_loop's {fn_name} might be aliasing the input!"
-            )
+        _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
     ret = while_loop_op(
         functional_cond_fn,
         functional_body_fn,
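For while_loop the shared helper is applied once per sub-function, with `fn_name` threaded through as the `name` argument, so the error identifies whether cond_fn or body_fn misbehaved. An illustrative loop body that would now be rejected (the call is left commented out, as the HOP only runs under compilation):

    import torch
    from torch._higher_order_ops.while_loop import while_loop

    def cond_fn(it, x):
        return it < 3

    def body_fn(it, x):
        x.mul_(2)  # mutates a carried input
        return it + 1, x

    # while_loop(cond_fn, body_fn, (torch.tensor(0), torch.randn(3)))
    # -> RuntimeError("body_fn might be modifying the input!") via the new check.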


@@ -7898,7 +7898,7 @@ class WhileLoop(ExternKernel):
         # Handling input mutations
         mutated_idxs = check_input_alias_and_mutation(
             body_fn.graph.module, fake_all_inputs
-        )[0]
+        )[3]
         mutated_idx_set = OrderedSet(mutated_idxs)
         mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set]
         real_outputs = {
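This Inductor change is purely a consequence of the reordered return: `mutated_inputs` moved from slot 0 to slot 3 of `check_input_alias_and_mutation`, so the lowering indexes `[3]` instead of `[0]` to keep reading the mutated-input indices.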