[HOP] Mutation and alias rework (#146658)
This PR reworks the way input mutations and the various forms of aliasing are checked for higher-order ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146658
Approved by: https://github.com/ydwu4
This commit is contained in:
parent 0e805aad7f
commit 68034198e5
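For context, the checks this PR reworks boil down to two questions about a traced subgraph: did it mutate any input in place, and does any output share storage with an input (or another output)? Below is a minimal stand-alone sketch of that idea using only standard tensor properties; the helper name and structure are illustrative, not the code in this diff.

    import torch

    def detect_alias_and_mutation(fn, inputs):
        # Probe on clones so the check itself leaves the caller's tensors untouched.
        probe = [t.clone() for t in inputs]
        versions_before = [t._version for t in probe]
        outputs = fn(*probe)
        outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
        # An input counts as mutated if its version counter advanced during the call.
        mutated = [
            i for i, (t, v) in enumerate(zip(probe, versions_before)) if t._version != v
        ]
        # An output aliases an input if both sit on the same storage.
        input_storages = {t.untyped_storage().data_ptr(): i for i, t in enumerate(probe)}
        aliases = {
            o: input_storages[out.untyped_storage().data_ptr()]
            for o, out in enumerate(outputs)
            if isinstance(out, torch.Tensor)
            and out.untyped_storage().data_ptr() in input_storages
        }
        return mutated, aliases

    x = torch.randn(3)
    print(detect_alias_and_mutation(lambda t: t.add_(1), [x]))  # ([0], {0: 0})  mutation plus alias
    print(detect_alias_and_mutation(lambda t: t.view(3), [x]))  # ([], {0: 0})   alias only
    print(detect_alias_and_mutation(lambda t: t.clone(), [x]))  # ([], {})       clean

The utilities touched by this diff (`check_input_alias_and_mutation` and the new `_check_alias_and_mutation` wrapper) apply the same idea to fake tensors during tracing and additionally distinguish input-input, input-output and output-output aliasing.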
|
|
@@ -1,5 +1,4 @@
 from torch import cond # noqa: F401
-from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
 from torch._higher_order_ops.map import ( # noqa: F401
     _stack_pytree,
     _unstack_pytree,
|
|
|||
|
|
@ -1873,7 +1873,7 @@ def forward(self, x, y):
|
|||
return x + x
|
||||
|
||||
def false_fn(x):
|
||||
return x[:2]
|
||||
return x[:2].clone()
|
||||
|
||||
return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
|
||||
|
||||
|
|
@ -1883,7 +1883,7 @@ def forward(self, x, y):
|
|||
return x + x
|
||||
|
||||
def false_fn(x):
|
||||
return x[:2]
|
||||
return x[:2].clone()
|
||||
|
||||
return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
|
||||
|
||||
|
|
@ -1924,7 +1924,8 @@ def forward(self, l_x_):
|
|||
def forward(self, l_x_):
|
||||
l_x__1 = l_x_
|
||||
getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
|
||||
return (getitem,)""",
|
||||
clone = getitem.clone(); getitem = None
|
||||
return (clone,)""",
|
||||
)
|
||||
# We could successfully export branches that return different sizes
|
||||
torch._dynamo.export(mod)(torch.randn(3, 2))
|
||||
|
|
@ -3302,7 +3303,12 @@ def forward(self, x):
|
|||
|
||||
def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
|
||||
def f_branch_return_multiple_tensors(pred, x, y):
|
||||
return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
|
||||
return cond(
|
||||
pred,
|
||||
lambda x: (x.clone(), x.clone()),
|
||||
lambda x: (x.clone(), x.clone()),
|
||||
[y],
|
||||
)
|
||||
|
||||
example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
|
||||
gm, _ = torch._dynamo.export(
|
||||
|
|
@ -3324,10 +3330,10 @@ def forward(self, x):
|
|||
|
||||
def test_cond_raise_user_error_on_mismatch_return_length(self):
|
||||
def true_fn(x):
|
||||
return x
|
||||
return x.clone()
|
||||
|
||||
def false_fn(x):
|
||||
return (x, x)
|
||||
return (x.clone(), x.clone())
|
||||
|
||||
def f_mismatch_return_length(x):
|
||||
return cond(torch.tensor(100), true_fn, false_fn, [x])
|
||||
|
|
|
|||
|
|
@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):
|
|||
|
||||
def test_map_pytree_return(self):
|
||||
def _construct_pytree(a):
|
||||
return (a, [[[a]]], a, (a, (a,), a), {"a": a})
|
||||
return (
|
||||
a.clone(),
|
||||
[[[a.clone()]]],
|
||||
a.clone(),
|
||||
(a.clone(), (a.clone(),), a.clone()),
|
||||
{"a": a.clone()},
|
||||
)
|
||||
|
||||
def f(x):
|
||||
def inner_f(xs):
|
||||
|
|
@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
|
|||
body_graph,
|
||||
"""\
|
||||
def forward(self, child : torch.Tensor):
|
||||
return (child, child, child, child, child, child, child)""",
|
||||
child_1 = child.clone()
|
||||
child_2 = child.clone()
|
||||
child_3 = child.clone()
|
||||
child_4 = child.clone()
|
||||
child_5 = child.clone()
|
||||
child_6 = child.clone()
|
||||
child_7 = child.clone(); child = None
|
||||
return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
|
||||
)
|
||||
|
||||
def test_map_kwargs(self):
|
||||
|
|
@ -6902,7 +6915,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def test(pred, x):
|
||||
def true_fn(x):
|
||||
return x
|
||||
return x.clone()
|
||||
|
||||
def false_fn(x):
|
||||
return -x
|
||||
|
|
@ -6926,7 +6939,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def test(pred, mode, x):
|
||||
def true_fn(x):
|
||||
return x
|
||||
return x.clone()
|
||||
|
||||
def false_fn(x):
|
||||
return -x
|
||||
|
|
|
|||
|
|
@ -5931,7 +5931,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
from functorch.experimental.control_flow import cond
|
||||
|
||||
def true_fn(x):
|
||||
return x
|
||||
return x.clone()
|
||||
|
||||
def false_fn(x):
|
||||
return x.sin()
|
||||
|
|
|
|||
|
|
@ -7604,20 +7604,25 @@ def forward(self, b_a_buffer, x):
|
|||
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
|
||||
|
||||
@requires_cuda
|
||||
@testing.expectedFailureCppRuntime
|
||||
def test_export_associative_scan_lifted_buffers(self):
|
||||
device = torch.device("cuda")
|
||||
combine_mode = "pointwise"
|
||||
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
|
||||
|
||||
def forward(self):
|
||||
return self.buffer.cos()
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"buf", torch.ones(3, 2, device=device), persistent=False
|
||||
)
|
||||
self.a = A()
|
||||
|
||||
def combine_fn(self, x, y):
|
||||
return x + y * self.buf
|
||||
return (x + y) * self.a()
|
||||
|
||||
def forward(self, x):
|
||||
return associative_scan(
|
||||
|
|
|
|||
|
|
@ -4572,17 +4572,17 @@ class <lambda>(torch.nn.Module):
|
|||
|
||||
body_graph_0 = self.body_graph_0
|
||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
|
||||
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
|
||||
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
|
||||
|
||||
body_graph_1 = self.body_graph_1
|
||||
map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
|
||||
getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
|
||||
getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
|
||||
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None
|
||||
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
|
||||
return (add_1,)
|
||||
|
|
@ -4635,9 +4635,9 @@ class <lambda>(torch.nn.Module):
|
|||
|
||||
body_graph_0 = self.body_graph_0
|
||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
|
||||
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
||||
return (add,)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import unittest
|
|||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from functorch.experimental import control_flow
|
||||
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
|
||||
from functorch.experimental.control_flow import cond
|
||||
from torch._dynamo.testing import normalize_gm
|
||||
from torch._higher_order_ops.associative_scan import (
|
||||
_fake_associative_scan,
|
||||
|
|
@ -36,7 +36,6 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_CROSSREF,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TestCase,
|
||||
xfailIfTorchDynamo,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -159,7 +158,7 @@ def get_scan_combine_fn(name, associative=True, parameters=None):
|
|||
def RNN(x: torch.Tensor, y: torch.Tensor):
|
||||
c_new = y @ parameters[0] + parameters[1]
|
||||
h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
|
||||
return h_new, h_new
|
||||
return h_new, h_new.clone()
|
||||
|
||||
def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
|
||||
h_new = torch.tanh(x[0] + x[1] + y)
|
||||
|
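A compact restatement of why the `.clone()` above is needed (the combine function below is our own toy example, not from the diff): scan treats the first returned value as the carry and the second as the per-step output, and after this PR the two must not share storage.

    import torch

    def combine(carry, x):
        h = torch.tanh(carry + x)
        return h, h.clone()  # returning (h, h) would make the output alias the carry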
|
@ -888,8 +887,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
|
|||
)
|
||||
|
||||
def test_cond_autograd_pytree_input(self):
|
||||
# TODO: This is an unexpected behavior for cond
|
||||
# Without this additional multiplication,
|
||||
# the output of the backward graph would alias the
|
||||
# inputs, as the gradients are just 1s and thus get optimized
|
||||
def true_fn(x):
|
||||
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
|
||||
return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]
|
||||
|
||||
def false_fn(x):
|
||||
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
||||
|
|
@ -966,10 +969,10 @@ def forward(self, pred_1):
|
|||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_same_pytree_output(self):
|
||||
def true_fn(x):
|
||||
return {"res": [x["t"][0], (x["t"][2][0],)]}
|
||||
return {"res": [x["t"][0].clone(), (x["t"][2][0].clone(),)]}
|
||||
|
||||
def false_fn(x):
|
||||
return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
|
||||
return {"res": [x["t"][1]["b"].clone(), (x["t"][2][0].clone(),)]}
|
||||
|
||||
a = torch.randn(4, requires_grad=True)
|
||||
b = torch.randn(4, requires_grad=True)
|
||||
|
|
@ -1007,9 +1010,7 @@ def forward(self, pred_1):
|
|||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
||||
getitem = cond[0]
|
||||
getitem_1 = cond[1]; cond = None
|
||||
view = torch.ops.aten.view.default(getitem, [4]); getitem = None
|
||||
view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
|
||||
return {'res': [view, (view_1,)]}""", # noqa: B950
|
||||
return {'res': [getitem, (getitem_1,)]}""", # noqa: B950
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
|
|
@ -1901,12 +1902,15 @@ def forward(self, pred_1, x_1):
|
|||
{
|
||||
"i": x["i"] * y["j"][0][0],
|
||||
"k": torch.tensor(0.0),
|
||||
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
||||
"j": (
|
||||
[x["j"][1][0]["o"].clone()],
|
||||
[{"o": torch.sin(x["i"])}],
|
||||
),
|
||||
},
|
||||
{
|
||||
"i": x["i"] * y["j"][0][0],
|
||||
"k": torch.tensor(0.0),
|
||||
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
||||
"j": ([x["j"][1][0]["o"].clone()], [{"o": torch.sin(x["i"])}]),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -2395,8 +2399,11 @@ def forward(self, pred_1, x_1):
|
|||
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected init and carry to have same metadata.*",
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan_fct(wrong_carry_shape, init, x, dim=dim)
|
||||
|
||||
|
|
@ -3311,6 +3318,114 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
|
|||
return (carry, out_1)""", # noqa: B950
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_scan_input_mutation(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_input_mutation(x, y):
|
||||
x.add_(1)
|
||||
return x + y, x + y + 2
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
init = torch.randn(2, 2, device=device)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan(fct_input_mutation, init, x, dim=0)
|
||||
|
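As a side note, the fix a user would apply to a combine function like the one above is the same clone-instead-of-mutate pattern used throughout this diff; a hedged user-side sketch, not part of the change:

    def fct_no_mutation(x, y):
        x = x + 1                # out-of-place, unlike x.add_(1)
        return x + y, x + y + 2  # same carry/output structure, no input mutation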
||||
@requires_cuda
|
||||
def test_scan_input_carry_alias(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_input_output_alias(x, y):
|
||||
return (x[0], x[1] + y[1]), (x[1] + y[1] + 1, x[1] + y[1] + 2)
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan(fct_input_output_alias, init, inp, dim=0)
|
||||
|
||||
@requires_cuda
|
||||
def test_scan_input_output_alias(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_input_output_alias(x, y):
|
||||
return (x[0] + 1, x[1] + y[1]), (x[1], x[1] + y[1] + 2)
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan(fct_input_output_alias, init, inp, dim=0)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@requires_cuda
|
||||
def test_scan_carry_carry_alias(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_carry_carry_alias(x, y):
|
||||
c = x[0] + y[1]
|
||||
return (c, c), (x[0] + y[1], x[0] + y[1] + 1)
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan(fct_carry_carry_alias, init, inp, dim=0)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@requires_cuda
|
||||
def test_scan_carry_output_alias(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_carry_output_alias(x, y):
|
||||
c = x[0] + y[1]
|
||||
return (x[0] + y[1], c), (c, x[0] + y[1] + 1)
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
scan(fct_carry_output_alias, init, inp, dim=0)
|
||||
|
||||
|
||||
class AssociativeScanModels:
|
||||
@staticmethod
|
||||
|
|
@ -4158,7 +4273,7 @@ class GraphModule(torch.nn.Module):
|
|||
)
|
||||
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
||||
def combine_fn(x, y):
|
||||
val = cond(torch.sum(y) > 0.0, lambda y: y + 0.0, lambda y: 1.0 - y, (y,))
|
||||
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
|
||||
return x * val
|
||||
|
||||
inp = torch.randn(3, 10, 1, device=device)
|
||||
|
|
@ -4775,8 +4890,10 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
RuntimeError,
|
||||
"Combine_fn might be modifying the input!",
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"associative_scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
associative_scan(fct_input_mutation, x, 0)
|
||||
|
||||
|
|
@ -4793,11 +4910,35 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
RuntimeError,
|
||||
"Combine_fn might be aliasing the input!",
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"associative_scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
associative_scan(fct_input_output_alias, inp, 0)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@requires_cuda
|
||||
def test_associative_scan_output_output_alias(self):
|
||||
device = torch.device("cuda")
|
||||
|
||||
def fct_output_output_alias(x, y):
|
||||
c = x[0] + y[1]
|
||||
return c, c
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"associative_scan must be captured completely with torch.compile.*",
|
||||
):
|
||||
associative_scan(fct_output_output_alias, inp, 0)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
||||
@skipIfNoDynamoSupport
|
||||
|
|
@ -5597,8 +5738,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|||
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_input_mutation_on_true_brancte(self):
|
||||
def test_cond_functionalized_input_mutation_on_true_branch(self):
|
||||
def true_fn(x):
|
||||
view_x = x.view(x.shape)
|
||||
view_x.add_(1)
|
||||
|
|
@ -5632,13 +5772,13 @@ def forward(self, x_1):
|
|||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"cond_true might be modifying the input!",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
||||
def true_fn(x):
|
||||
return x.sin().sum()
|
||||
|
|
@ -5673,16 +5813,16 @@ def forward(self, x_1):
|
|||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"cond_false might be modifying the input!",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_output_alias_input(self):
|
||||
def true_fn(x):
|
||||
return x
|
||||
return x.clone()
|
||||
|
||||
def false_fn(x):
|
||||
view_x = x.view(x.shape)
|
||||
|
|
@ -5707,13 +5847,16 @@ def forward(self, x_1):
|
|||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_nested_input_mutation(self):
|
||||
def true_true_fn(x):
|
||||
x.add_(4)
|
||||
|
|
@ -5735,13 +5878,13 @@ def forward(self, x_1):
|
|||
|
||||
example_inputs = (torch.ones(4, 5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"cond_true might be modifying the input!",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
||||
def true_true_fn(x):
|
||||
x.add_(4)
|
||||
|
|
@ -5768,7 +5911,11 @@ def forward(self, x_1):
|
|||
f(example_input_func)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
||||
finally:
|
||||
|
|
@ -5786,7 +5933,11 @@ def forward(self, x_1):
|
|||
return wrapper
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
||||
|
||||
|
|
@ -5807,8 +5958,11 @@ def forward(self, x_1):
|
|||
example_input_func = to_fun_old(example_input)
|
||||
torch._enable_functionalization(reapply_views=False)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"One of torch.cond branch might be aliasing",
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
f(example_input_func)
|
||||
finally:
|
||||
|
|
@ -5838,8 +5992,11 @@ def forward(self, x_1):
|
|||
return wrapper
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"One of torch.cond branch might be aliasing",
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
||||
|
||||
|
|
@ -5974,8 +6131,11 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Unmatched output spec from torch.cond branches",
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(f)(x, torch.tensor(False))
|
||||
|
||||
|
|
@ -6147,8 +6307,11 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Unmatched output spec from torch.cond branches",
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
|
|
@ -6431,7 +6594,8 @@ def forward(self, arg0_1):
|
|||
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"map might be modifying the input!",
|
||||
):
|
||||
functional_f(*example_inputs)
|
||||
|
||||
|
|
@ -6446,7 +6610,7 @@ def forward(self, arg0_1):
|
|||
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
|
||||
torch._dynamo.exc.TorchRuntimeError, "map might be modifying the input!"
|
||||
):
|
||||
functional_f(*example_inputs)
|
||||
|
||||
|
|
@ -6484,7 +6648,11 @@ def forward(self, arg0_1):
|
|||
example_inputs = (torch.ones(3, 2, 4),)
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is aliasing the input!"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"map doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
functional_f(*example_inputs)
|
||||
|
||||
|
|
@ -6843,10 +7011,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
unbind = torch.ops.aten.unbind.int(x_1)
|
||||
getitem = unbind[0]; getitem = None
|
||||
getitem_1 = unbind[1]; unbind = getitem_1 = None
|
||||
body_graph_0 = self.body_graph_0
|
||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
|
||||
getitem = map_impl[0]; map_impl = None
|
||||
return getitem""",
|
||||
getitem_2 = map_impl[0]; map_impl = None
|
||||
return getitem_2""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.body_graph_0.code.strip(),
|
||||
|
|
@ -7082,7 +7253,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
return torch.cond(
|
||||
pred=torch.tensor([True]),
|
||||
true_fn=lambda x: x + 100,
|
||||
false_fn=lambda x: x,
|
||||
false_fn=lambda x: x.clone(),
|
||||
operands=(x,),
|
||||
)
|
||||
|
||||
|
|
@ -7096,7 +7267,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
return torch.cond(
|
||||
pred=x.sum() < y.sum(),
|
||||
true_fn=lambda x, y: x + 100,
|
||||
false_fn=lambda x, y: y,
|
||||
false_fn=lambda x, y: y.clone(),
|
||||
operands=(x, y),
|
||||
)
|
||||
|
||||
|
|
@ -7155,7 +7326,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
return torch.cond(
|
||||
pred=torch.tensor([True]),
|
||||
true_fn=lambda x: (x + c, x - c),
|
||||
false_fn=lambda x: (x, x),
|
||||
false_fn=lambda x: (x.clone(), x.clone()),
|
||||
operands=(x,),
|
||||
)
|
||||
|
||||
|
|
@ -7165,7 +7336,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
return torch.cond(
|
||||
pred=torch.tensor([True]),
|
||||
true_fn=lambda x: (x + 1, x - 1),
|
||||
false_fn=lambda x: (x, x),
|
||||
false_fn=lambda x: (x.clone(), x.clone()),
|
||||
operands=(x,),
|
||||
)
|
||||
|
||||
|
|
@ -7361,8 +7532,6 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
|||
functional_f(example_init, example_inputs), f(example_init, example_inputs)
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_scan_functionalized_elem_mutation(self):
|
||||
def add1(x, y):
|
||||
x.add_(4)
|
||||
|
|
@ -7375,8 +7544,15 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
|||
example_init = torch.ones(5, 4)
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException,
|
||||
"Combine_fn might be modifying the input!",
|
||||
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||
# RuntimeError,
|
||||
# "torch.scan might be modifying the input!",
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||
# torch._dynamo.exc.TorchDynamoException,
|
||||
# "Unexpected exception when running generated GraphModule.*"
|
||||
Exception,
|
||||
".*",
|
||||
):
|
||||
functional_f(example_init, example_inputs)
|
||||
|
||||
|
|
@ -7389,13 +7565,19 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
|||
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException,
|
||||
"Combine_fn might be modifying the input!",
|
||||
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||
# Should be
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||
# RuntimeError,
|
||||
# "torch.scan might be modifying the input!",
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||
# torch._dynamo.exc.TorchDynamoException,
|
||||
# "Unexpected exception when running generated GraphModule.*"
|
||||
Exception,
|
||||
".*",
|
||||
):
|
||||
functional_f(example_init, example_inputs)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_scan_functionalized_elem_alias(self):
|
||||
def add(x, y):
|
||||
return x, x
|
||||
|
|
@ -7407,7 +7589,16 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
|||
example_init = torch.ones(5, 4)
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!"
|
||||
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||
# Should be
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
# "scan must be captured completely with torch.compile.*",
|
||||
Exception,
|
||||
".*",
|
||||
):
|
||||
functional_f(example_init, example_inputs)
|
||||
|
||||
|
|
@ -7917,7 +8108,11 @@ class GraphModule(torch.nn.Module):
|
|||
x = torch.randn(2, 2)
|
||||
for f in ALIAS_FN:
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
torch.compile(fn)(f, x)
|
||||
|
||||
|
|
@ -7933,7 +8128,11 @@ class GraphModule(torch.nn.Module):
|
|||
# as a result of auto lifting.
|
||||
for view_f in ALIAS_FN[1:]:
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
torch.compile(fn)(view_f, x)
|
||||
|
||||
|
|
@ -7950,12 +8149,20 @@ class GraphModule(torch.nn.Module):
|
|||
x = torch.randn(2, 2)
|
||||
for f in ALIAS_FN:
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
torch.compile(fn)(f, x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
|
||||
# Should be
|
||||
# torch._dynamo.exc.Unsupported,
|
||||
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||
):
|
||||
with torch.inference_mode(inference_mode):
|
||||
torch.compile(fn)(f, x)
|
||||
|
|
|
|||
|
|
@ -570,7 +570,7 @@ class CondTests(TestCase):
|
|||
return torch.cond(p, true_fn, false_fn, [a, b])
|
||||
|
||||
# AssertionError: Output aliasing is currently not supported...
|
||||
with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
torch.compile(Model())(
|
||||
torch.tensor(True),
|
||||
torch.randn(10, 20),
|
||||
|
|
|
|||
|
|
@@ -2831,11 +2831,13 @@ class SubgraphTracer(fx.Tracer):
         return MutationInfo(False, "")

     def has_aliasing(self):
+        from torch._higher_order_ops.utils import _collect_fake_inputs
+
         input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()

         for node in self.graph.nodes:
             if node.op == "placeholder":
-                example_value = node.meta["example_value"]
+                example_value = _collect_fake_inputs([node])[0]
                 if isinstance(example_value, torch.Tensor):
                     storage = StorageWeakRef(example_value._typed_storage())
                     if storage in input_storages:
@@ -2848,9 +2850,9 @@ class SubgraphTracer(fx.Tracer):

         output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
         out_nodes = self.graph.find_nodes(op="output")[0]
-        for out_node in out_nodes.args[0]:
+        for out_node in pytree.tree_leaves(out_nodes.args[0]):
             if out_node:
-                example_value = out_node.meta["example_value"]
+                example_value = _collect_fake_inputs([out_node])[0]
                 assert not isinstance(example_value, list)
                 if isinstance(example_value, torch.Tensor):
                     storage = StorageWeakRef(example_value._typed_storage())
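The aliasing check above keys a dictionary on weak storage references. A small stand-alone illustration of that mechanism, using real tensors instead of the fake values Dynamo collects; the helper name is ours and the `StorageWeakRef` import path is our assumption:

    import torch
    from torch.multiprocessing.reductions import StorageWeakRef

    def first_aliasing_pair(tensors):
        seen: dict[StorageWeakRef, int] = {}
        for i, t in enumerate(tensors):
            key = StorageWeakRef(t._typed_storage())
            if key in seen:
                return seen[key], i  # two entries backed by the same storage
            seen[key] = i
        return None

    x = torch.randn(4)
    print(first_aliasing_pair([x, x.view(2, 2)]))  # (0, 1): the view aliases x
    print(first_aliasing_pair([x, x.clone()]))     # None: clone owns fresh storage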
@@ -961,6 +961,9 @@ class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable


 class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
     @raise_hard_error_if_graph_break(
         reason="Cond doesn't work unless it is captured completely with torch.compile."
     )
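The same two class attributes are added to the while_loop, associative_scan, scan and map variable classes further down, and are then forwarded to the subgraph-speculation call. A schematic of that flag-forwarding pattern; only the two flag names come from the diff, the classes and the trace function below are illustrative stand-ins, not Dynamo internals:

    def trace_subgraph(fn, *args, supports_input_mutation=False, supports_aliasing=False):
        print(f"tracing {fn.__name__}: mutation={supports_input_mutation}, aliasing={supports_aliasing}")
        return fn(*args)

    class HOPVariable:
        supports_input_mutation = False
        supports_aliasing = False

        def speculate(self, fn, *args):
            # Each HOP variable forwards its class-level capabilities to the shared
            # tracing helper, which rejects mutation/aliasing accordingly.
            return trace_subgraph(
                fn,
                *args,
                supports_input_mutation=self.supports_input_mutation,
                supports_aliasing=self.supports_aliasing,
            )

    class CondLikeVariable(HOPVariable):
        pass  # cond keeps the defaults: no input mutation, no aliasing allowed

    print(CondLikeVariable().speculate(lambda x: x + 1, 41))  # 42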
|
@ -1058,6 +1061,8 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
"cond",
|
||||
source_target=self.value,
|
||||
should_flatten_outputs=True,
|
||||
supports_input_mutation=self.supports_input_mutation,
|
||||
supports_aliasing=self.supports_aliasing,
|
||||
)
|
||||
|
||||
if not only_consist_of(ret_val, (TensorVariable,)):
|
||||
|
|
@ -1184,6 +1189,9 @@ def validate_subgraph_output_types(output: VariableTracker):
|
|||
|
||||
|
||||
class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
supports_input_mutation = False
|
||||
supports_aliasing = False
|
||||
|
||||
@raise_hard_error_if_graph_break(
|
||||
reason="while_loop doesn't work unless it is captured completely with torch.compile."
|
||||
)
|
||||
|
|
@ -1299,6 +1307,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
# So it's best we always enforce the ordering of carried_inputs the same as outputs
|
||||
# with "flatten_manual".
|
||||
set_subgraph_inputs="flatten_manual",
|
||||
supports_input_mutation=self.supports_input_mutation,
|
||||
supports_aliasing=self.supports_aliasing,
|
||||
)
|
||||
cond_nn_modules = dict(tx.output.nn_modules)
|
||||
validate_subgraph_output_types(cond_r)
|
||||
|
|
@ -1337,6 +1347,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
source_target=self.value,
|
||||
set_subgraph_inputs="flatten_manual",
|
||||
should_flatten_outputs=True,
|
||||
supports_input_mutation=False,
|
||||
supports_aliasing=False,
|
||||
)
|
||||
validate_subgraph_output_types(body_r)
|
||||
|
||||
|
|
@ -1420,6 +1432,9 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
|
||||
|
||||
class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
supports_input_mutation = False
|
||||
supports_aliasing = False
|
||||
|
||||
@raise_hard_error_if_graph_break(
|
||||
reason="associative_scan must be captured completely with torch.compile."
|
||||
)
|
||||
|
|
@ -1513,6 +1528,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
description="associative_scan_combine_fn",
|
||||
source_target=self.value,
|
||||
set_subgraph_inputs="flatten_manual",
|
||||
supports_input_mutation=self.supports_input_mutation,
|
||||
supports_aliasing=self.supports_aliasing,
|
||||
)
|
||||
|
||||
# Ensure that the output of scan is a flattened list of elements,
|
||||
|
|
@ -1562,30 +1579,23 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
)
|
||||
|
||||
combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
|
||||
combine_freevars_proxy = tuple(combine_lifted_freevars.keys())
|
||||
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_fake_tracing,
|
||||
# Compute the proxies for the input check
|
||||
proxy_vars_inputcheck = (
|
||||
tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy
|
||||
)
|
||||
|
||||
from torch._higher_order_ops.utils import _maybe_fake_tracing
|
||||
from torch._inductor.utils import is_pointwise_use
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
with tx.fake_mode:
|
||||
xs_fake = [
|
||||
first_slice_copy(leaf.proxy.node.meta["example_value"].clone())
|
||||
for leaf in itertools.chain(xs_vars, xs_vars)
|
||||
]
|
||||
additional_fake = [
|
||||
leaf.proxy.node.meta["example_value"].clone()
|
||||
for leaf in additional_inputs_vars
|
||||
] + [
|
||||
sub_args_fake = [
|
||||
leaf.node.meta["example_value"].clone()
|
||||
if isinstance(leaf.node.meta["example_value"], FakeTensor)
|
||||
if hasattr(leaf.node.meta["example_value"], "clone")
|
||||
else leaf.node.meta["example_value"]
|
||||
for leaf in combine_lifted_freevars.keys()
|
||||
for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
|
||||
]
|
||||
sub_args_fake = xs_fake + additional_fake
|
||||
pre_dispatch = False
|
||||
|
||||
fx = _maybe_fake_tracing(
|
||||
|
|
@ -1601,15 +1611,6 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
"For combine_mode='pointwise', the combine_fn needs to be pointwise"
|
||||
)
|
||||
|
||||
if _has_potential_branch_input_mutation(
|
||||
combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise RuntimeError("Combine_fn might be modifying the input!") # noqa: F541
|
||||
if _has_potential_branch_input_alias(
|
||||
combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise RuntimeError("Combine_fn might be aliasing the input!") # noqa: F541
|
||||
|
||||
combine_fn_name = tx.output.install_subgraph(
|
||||
"associative_scan_combine_fn", combine_gm
|
||||
)
|
||||
|
|
@ -1641,6 +1642,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
|
||||
|
||||
class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
supports_input_mutation = False
|
||||
supports_aliasing = False
|
||||
|
||||
@raise_hard_error_if_graph_break(
|
||||
reason="scan must be captured completely with torch.compile."
|
||||
)
|
||||
|
|
@ -1752,7 +1756,10 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
description="scan_combine_fn",
|
||||
source_target=self.value,
|
||||
set_subgraph_inputs="flatten_manual",
|
||||
supports_input_mutation=self.supports_input_mutation,
|
||||
supports_aliasing=self.supports_aliasing,
|
||||
)
|
||||
|
||||
# Ensure that the output of scan is a flattened list of elements,
|
||||
# because downstream operations assume that the output of HOPs
|
||||
# is flattened
|
||||
|
|
@ -1848,6 +1855,9 @@ def non_single_tensor_return_unsupported(api, ret):
|
|||
|
||||
|
||||
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
supports_input_mutation = False
|
||||
supports_aliasing = False
|
||||
|
||||
@raise_hard_error_if_graph_break(
|
||||
reason="map doesn't work unless it is captured completely with torch.compile."
|
||||
)
|
||||
|
|
@ -1910,6 +1920,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
source_target=self.value,
|
||||
set_subgraph_inputs="flatten_manual",
|
||||
should_flatten_outputs=True,
|
||||
supports_input_mutation=self.supports_input_mutation,
|
||||
supports_aliasing=self.supports_aliasing,
|
||||
)
|
||||
|
||||
body_nn_modules = dict(tx.output.nn_modules)
|
||||
|
|
|
|||
|
|
@ -434,12 +434,27 @@ def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, additional_inputs):
|
|||
|
||||
@associative_scan_op.py_functionalize_impl
|
||||
def associative_scan_functionalize(ctx, combine_fn, xs, additional_inputs):
|
||||
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||
|
||||
unwrapped_xs = ctx.unwrap_tensors(xs)
|
||||
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
||||
with ctx.redispatch_to_next():
|
||||
functional_combine_fn = ctx.functionalize(
|
||||
_maybe_run_with_interpreter(combine_fn)
|
||||
)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
sample_unwrapped_xs_sliced = [
|
||||
first_slice_copy(inp) for inp in itertools.chain(unwrapped_xs, unwrapped_xs)
|
||||
]
|
||||
sample_inputs = list(
|
||||
itertools.chain(
|
||||
sample_unwrapped_xs_sliced,
|
||||
unwrapped_additional_inputs,
|
||||
)
|
||||
)
|
||||
_check_alias_and_mutation(
|
||||
combine_fn, sample_inputs, "associative_scan", pre_dispatch
|
||||
)
|
||||
ret = associative_scan_op(
|
||||
functional_combine_fn,
|
||||
unwrapped_xs,
|
||||
|
|
|
|||
|
|
@ -136,10 +136,10 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
|||
for ph in subgraph.graph.find_nodes(op="placeholder")
|
||||
]
|
||||
(
|
||||
mutated_inp_idx,
|
||||
inp_inp_alias,
|
||||
inp_out_alias,
|
||||
out_out_alias,
|
||||
mutated_inp_idx,
|
||||
output,
|
||||
) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ from torch._C._functorch import (
|
|||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.utils import exposed_in
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_reenter_make_fx,
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
|
|
@ -27,7 +25,6 @@ from torch._higher_order_ops.utils import (
|
|||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
|
@ -668,29 +665,18 @@ def _merge_tensors(
|
|||
|
||||
@cond_op.py_functionalize_impl
|
||||
def cond_func(ctx, pred, true_fn, false_fn, inputs):
|
||||
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||
|
||||
unwrapped_inputs = ctx.unwrap_tensors(inputs)
|
||||
unwrapped_pred = ctx.unwrap_tensors(pred)
|
||||
with ctx.redispatch_to_next():
|
||||
functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
|
||||
functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
for branch in [true_fn, false_fn]:
|
||||
if _has_potential_branch_input_mutation(
|
||||
branch, unwrapped_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch might be modifying the input! "
|
||||
"Consider cloning the input before modifying it. "
|
||||
)
|
||||
for branch in [true_fn, false_fn]:
|
||||
if _has_potential_branch_input_alias(
|
||||
branch, unwrapped_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"One of torch.cond branch might be aliasing the input! "
|
||||
"If you are returning a view of the input, please make sure "
|
||||
"to clone it. "
|
||||
)
|
||||
for branch, branch_name in [(true_fn, "cond_true"), (false_fn, "cond_false")]:
|
||||
_check_alias_and_mutation(
|
||||
branch, unwrapped_inputs, branch_name, pre_dispatch
|
||||
)
|
||||
|
||||
cond_return = cond_op(
|
||||
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
|
||||
|
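The branch checks above now funnel through `_check_alias_and_mutation`, but the user-facing rule is unchanged: a cond branch must not return an alias of its input. A hedged user-side illustration with our own toy functions, following the clone pattern the tests in this diff rely on:

    import torch

    def true_fn(x):
        return x + 1

    def false_fn_aliasing(x):
        # Returns a view of the input; tracing rejects this branch as aliasing.
        return x.view(x.shape)

    def false_fn_ok(x):
        # Functional alternative: fresh storage, same metadata as true_fn's output.
        return x.clone()

    def f(pred, x):
        return torch.cond(pred, true_fn, false_fn_ok, (x,))

    print(torch.compile(f)(torch.tensor(False), torch.randn(3)))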
|
|
|||
|
|
@ -420,6 +420,9 @@ def flex_attention_functionalize(
|
|||
functional_score_mod = ctx.functionalize(score_mod)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
with TransformGetItemToIndex():
|
||||
# TODO: So far only the input mutations are checked
|
||||
# In the other HOPs, also aliases are checked which is
|
||||
# omitted here
|
||||
mutates = _has_potential_branch_input_mutation(
|
||||
score_mod, example_vals, pre_dispatch
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,9 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
|
@ -91,24 +88,18 @@ def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
|
|||
|
||||
@hints_wrapper.py_functionalize_impl
|
||||
def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
|
||||
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||
|
||||
unwrapped_args = ctx.unwrap_tensors(args)
|
||||
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
||||
unwrapped_hints = ctx.unwrap_tensors(hints)
|
||||
with ctx.redispatch_to_next():
|
||||
functional_body_fn = ctx.functionalize(body_fn)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
if _has_potential_branch_input_mutation(
|
||||
body_fn, unwrapped_args, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"body_fn of hints_wrapper might be modifying the input!"
|
||||
)
|
||||
if _has_potential_branch_input_alias(
|
||||
body_fn, unwrapped_args, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"body_fn of hints_wrapper might be aliasing the input!"
|
||||
)
|
||||
_check_alias_and_mutation(
|
||||
body_fn, unwrapped_args, "hints_wrapper", pre_dispatch
|
||||
)
|
||||
|
||||
outputs = hints_wrapper(
|
||||
functional_body_fn,
|
||||
unwrapped_args,
|
||||
|
|
|
|||
|
|
@ -7,13 +7,7 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
|
|
@ -284,23 +278,15 @@ def map_fake_tensor_mode(mode, f, xs, args):
|
|||
|
||||
@map_impl.py_functionalize_impl
|
||||
def map_functionalize(ctx, f, xs, pos_args):
|
||||
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||
|
||||
unwrapped_xs = ctx.unwrap_tensors(xs)
|
||||
unwrapped_args = ctx.unwrap_tensors(pos_args)
|
||||
wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))
|
||||
|
||||
with ctx.redispatch_to_next():
|
||||
with disable_proxy_modes_tracing():
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
if _has_potential_branch_input_mutation(
|
||||
f, example_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException("torch.map is mutating the input!")
|
||||
|
||||
if _has_potential_branch_input_alias(
|
||||
f, example_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
|
||||
|
||||
_check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)
|
||||
map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
|
||||
return ctx.wrap_tensors(map_return)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ import torch.utils._pytree as pytree
|
|||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_compile_and_run_fn,
|
||||
check_meta_consistency,
|
||||
first_slice_copy,
|
||||
|
|
@ -19,7 +17,6 @@ from torch._higher_order_ops.utils import (
|
|||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
|
@ -861,12 +858,19 @@ def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
|
|||
|
||||
@scan_op.py_functionalize_impl
|
||||
def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
|
||||
from torch._higher_order_ops.utils import (
|
||||
_check_alias_and_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
)
|
||||
|
||||
unwrapped_xs = ctx.unwrap_tensors(xs)
|
||||
unwrapped_init = ctx.unwrap_tensors(init)
|
||||
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
||||
|
||||
with ctx.redispatch_to_next():
|
||||
functional_combine_fn = ctx.functionalize(combine_fn)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
functional_combine_fn = ctx.functionalize(
|
||||
_maybe_run_with_interpreter(combine_fn)
|
||||
)
|
||||
sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs]
|
||||
sample_inputs = list(
|
||||
itertools.chain(
|
||||
|
|
@ -875,18 +879,8 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
|
|||
unwrapped_additional_inputs,
|
||||
)
|
||||
)
|
||||
if _has_potential_branch_input_mutation(
|
||||
combine_fn, sample_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"Combine_fn might be modifying the input!"
|
||||
)
|
||||
if _has_potential_branch_input_alias(
|
||||
combine_fn, sample_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
"Combine_fn might be aliasing the input!"
|
||||
)
|
||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||
_check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch)
|
||||
ret = scan_op(
|
||||
functional_combine_fn,
|
||||
unwrapped_init,
|
||||
|
|
|
|||
|
|
@ -248,34 +248,6 @@ def _set_compilation_env():
|
|||
torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
|
||||
|
||||
|
||||
def _detect_input_mutation(gm: torch.fx.GraphModule) -> bool:
|
||||
example_inputs = [
|
||||
ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
|
||||
]
|
||||
inp_mutation, _, _, _ = check_input_alias_and_mutation(gm, example_inputs)
|
||||
if len(inp_mutation) > 0:
|
||||
return True
|
||||
|
||||
for _, module in gm.named_children():
|
||||
if isinstance(module, torch.fx.GraphModule):
|
||||
if _detect_input_mutation(module):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:
|
||||
example_inputs = [
|
||||
ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
|
||||
]
|
||||
_, inp_inp_alias_map, inp_out_alias_map, _ = check_input_alias_and_mutation(
|
||||
gm, example_inputs
|
||||
)
|
||||
if len(inp_out_alias_map) > 0 or len(inp_inp_alias_map) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# The invariant here is that we always trace the branch with fake tensor
|
||||
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
||||
fake_mode = detect_fake_mode(inputs)
|
||||
|
|
@ -301,7 +273,7 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
|||
return gm
|
||||
|
||||
|
||||
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
||||
def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
||||
try:
|
||||
gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
|
||||
except UnsupportedAliasMutationException:
|
||||
|
|
@ -311,43 +283,113 @@ def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return _detect_input_mutation(gm) or _detect_input_alias(gm)
|
||||
example_inputs = [
|
||||
ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
|
||||
]
|
||||
(
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
inp_mutation,
|
||||
) = check_input_alias_and_mutation(gm, example_inputs)
|
||||
return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
|
||||
|
||||
|
||||
def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has mutable op on the input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return _detect_input_mutation(gm)
|
||||
def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
|
||||
if any(len(a) > 0 for a in aliases):
|
||||
# TODO: Investigate here further which node is exactly aliasing
|
||||
raise RuntimeError(
|
||||
f"{name} where aliases appear. "
|
||||
+ f"In particular, these inputs \
|
||||
{set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} " # noqa: C401
|
||||
+ "get aliased. Please ensure that this doesn't happen."
|
||||
)
|
||||
if len(input_mutations):
|
||||
# TODO: Investigate here further which node is exactly mutating the inputs
|
||||
raise RuntimeError(
|
||||
f"{name} where the inputs are mutated. "
|
||||
+ f"In particular, these nodes are mutating the inputs \
|
||||
{set(el for el in input_mutations)}." # noqa: C401
|
||||
+ "Please ensure that this doesn't happen."
|
||||
)
|
||||
|
||||
|
||||
def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
|
||||
"""
|
||||
Dispatch-trace the branch with inputs and check if
|
||||
producing graph has output aliasing the branch input. This is
|
||||
bit restrictive as the branch must be traceable.
|
||||
"""
|
||||
try:
|
||||
gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
|
||||
except UnsupportedAliasMutationException:
|
||||
# this can happen when nested cond_op is
|
||||
# functionalized
|
||||
return True
|
||||
except Exception as e:
|
||||
raise e
|
||||
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
|
||||
return _detect_input_alias(gm)
|
||||
return len(inp_mutation) > 0
|
||||
|
||||
|
||||
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
||||
(
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
return (
|
||||
any(
|
||||
(
|
||||
len(inp_inp_alias_map) > 0,
|
||||
len(inp_out_alias_map) > 0,
|
||||
len(out_out_alias_map) > 0,
|
||||
)
|
||||
),
|
||||
len(inp_mutation) > 0,
|
||||
)
|
||||
|
||||
|
||||
def _collect_fake_inputs(inputs):
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
# Get the example values of the inputs.
|
||||
inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
|
||||
for inp in inputs:
|
||||
if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
|
||||
inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
|
||||
if hasattr(inp, "meta"):
|
||||
val = inp.meta["example_value"]
|
||||
if isinstance(val, torch.Tensor):
|
||||
if torch._C._functorch.is_batchedtensor(
|
||||
val
|
||||
) or torch._C._functorch.is_functionaltensor(val):
|
||||
# This case is for batched or functional tensors
|
||||
# Unwrap the tensors
|
||||
while torch._C._functorch.is_batchedtensor(
|
||||
val
|
||||
) or torch._C._functorch.is_functionaltensor(val):
|
||||
val = torch._C._functorch.get_unwrapped(val)
|
||||
assert isinstance(val, FakeTensor)
|
||||
inputs_fake.append(val)
|
||||
else:
|
||||
# This is the standard case of a TensorVariable
|
||||
assert isinstance(val, FakeTensor)
|
||||
inputs_fake.append(val)
|
||||
else:
|
||||
# This case is for SymInts and other non-Tensor elements
|
||||
assert not isinstance(val, torch.Tensor)
|
||||
inputs_fake.append(val)
|
||||
else:
|
||||
# This case is for ints
|
||||
assert isinstance(inp, int)
|
||||
inputs_fake.append(inp)
|
||||
|
||||
return inputs_fake
|
||||
|
||||
|
||||
def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
|
||||
aliases, inp_mutation = has_potential_input_alias_or_mutation(
|
||||
graph_module, inputs_fake, pre_dispatch=pre_dispatch
|
||||
)
|
||||
if aliases:
|
||||
raise RuntimeError(
|
||||
f"{name} might be aliasing the input or the output!"
|
||||
) # noqa: F541
|
||||
if inp_mutation:
|
||||
raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541
|
||||
|
||||
|
||||
def unique_graph_id(proxy_mode, prefix):
|
||||
|
|
@ -699,27 +741,29 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
|
|||
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
||||
|
||||
|
||||
# TODO: Return a more detailed information as to which node
|
||||
# causes a mutation or an alias. This may requires a per operator tensor version checking
|
||||
def check_input_alias_and_mutation(
|
||||
gm: torch.fx.GraphModule,
|
||||
fake_args: list[FakeTensor],
|
||||
) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]:
|
||||
) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
|
||||
(
|
||||
mutated_inputs,
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
mutated_inputs,
|
||||
) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1]
|
||||
return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map
|
||||
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
|
||||
|
||||
|
||||
def check_input_alias_and_mutation_return_ouputs(
|
||||
gm: torch.fx.GraphModule,
|
||||
fake_args: list[FakeTensor],
|
||||
) -> tuple[
|
||||
dict[int, int],
|
||||
dict[int, int],
|
||||
dict[int, int],
|
||||
list[int],
|
||||
dict[int, int],
|
||||
dict[int, int],
|
||||
dict[int, int],
|
||||
Union[tuple[Any, ...], list[Any]],
|
||||
]:
|
||||
# We want to disable active functional, proxy and fake modes if any.
|
||||
|
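Since `check_input_alias_and_mutation` now returns the three alias maps first and the mutated-input indices last, callers that only need the mutation information read position 3, which is why the Inductor `WhileLoop` lowering near the end of this diff changes `[0]` to `[3]`. A hedged convenience wrapper to make the new ordering explicit; the wrapper itself is not part of this change:

    from torch._higher_order_ops.utils import check_input_alias_and_mutation

    def mutated_input_indices(gm, fake_args):
        # New return order in this PR:
        # (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs)
        _, _, _, mutated_inputs = check_input_alias_and_mutation(gm, fake_args)
        return mutated_inputs  # same as check_input_alias_and_mutation(gm, fake_args)[3]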
|
@ -825,10 +869,10 @@ def check_input_alias_and_mutation_return_ouputs(
|
|||
if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map
|
||||
}
|
||||
return (
|
||||
mutated_inputs,
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
mutated_inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,11 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_maybe_run_with_interpreter,
|
||||
_set_compilation_env,
|
||||
autograd_not_implemented,
|
||||
check_meta_consistency,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
|
|
@ -400,6 +397,8 @@ def while_loop_fake_tensor_mode(
|
|||
|
||||
@while_loop_op.py_functionalize_impl
|
||||
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
|
||||
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||
|
||||
unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
|
||||
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
||||
unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
|
||||
|
|
@ -411,19 +410,7 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
|
|||
(cond_fn, "cond_fn"),
|
||||
(body_fn, "body_fn"),
|
||||
]:
|
||||
if _has_potential_branch_input_mutation(
|
||||
fn, unwrapped_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
f"torch.while_loop's {fn_name} might be modifying the input!"
|
||||
)
|
||||
|
||||
if _has_potential_branch_input_alias(
|
||||
fn, unwrapped_inputs, pre_dispatch=pre_dispatch
|
||||
):
|
||||
raise UnsupportedAliasMutationException(
|
||||
f"torch.while_loop's {fn_name} might be aliasing the input!"
|
||||
)
|
||||
_check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
|
||||
ret = while_loop_op(
|
||||
functional_cond_fn,
|
||||
functional_body_fn,
|
||||
|
|
|
|||
|
|
@@ -7898,7 +7898,7 @@ class WhileLoop(ExternKernel):
         # Handling input mutations
         mutated_idxs = check_input_alias_and_mutation(
             body_fn.graph.module, fake_all_inputs
-        )[0]
+        )[3]
         mutated_idx_set = OrderedSet(mutated_idxs)
         mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set]
         real_outputs = {