Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

[HOP] Mutation and alias rework (#146658)

This PR reworks the way input mutations and the various kinds of aliasing are checked for higher-order ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146658
Approved by: https://github.com/ydwu4
This commit is contained in:
parent 0e805aad7f
commit 68034198e5
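For illustration only (this snippet is not part of the diff), a minimal sketch of the rule the reworked checks enforce for torch.cond and the other HOPs: branch functions must neither mutate their inputs nor return tensors that alias them, which is why many tests below change "return x" to "return x.clone()".

import torch

def ok_branch(t):
    # Returns fresh storage, so it passes the aliasing check.
    return t.clone()

def aliasing_branch(t):
    # Returns the input itself; under torch.compile this is rejected
    # (the tests below expect torch._dynamo.exc.UncapturedHigherOrderOpError).
    return t

def f(pred, x):
    # Using aliasing_branch in place of ok_branch would trip the new check.
    return torch.cond(pred, ok_branch, lambda t: t.sin(), (x,))

pred = torch.tensor(True)
x = torch.randn(4)
out = torch.compile(f)(pred, x)  # accepted: no aliasing, no input mutation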
@@ -1,5 +1,4 @@
 from torch import cond # noqa: F401
-from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
 from torch._higher_order_ops.map import ( # noqa: F401
 _stack_pytree,
 _unstack_pytree,
@@ -1873,7 +1873,7 @@ def forward(self, x, y):
 return x + x

 def false_fn(x):
-return x[:2]
+return x[:2].clone()

 return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
@@ -1883,7 +1883,7 @@ def forward(self, x, y):
 return x + x

 def false_fn(x):
-return x[:2]
+return x[:2].clone()

 return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
@@ -1924,7 +1924,8 @@ def forward(self, l_x_):
 def forward(self, l_x_):
 l_x__1 = l_x_
 getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
-return (getitem,)""",
+clone = getitem.clone(); getitem = None
+return (clone,)""",
 )
 # We could successfully export branches that return different sizes
 torch._dynamo.export(mod)(torch.randn(3, 2))
@@ -3302,7 +3303,12 @@ def forward(self, x):

 def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
 def f_branch_return_multiple_tensors(pred, x, y):
-return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
+return cond(
+pred,
+lambda x: (x.clone(), x.clone()),
+lambda x: (x.clone(), x.clone()),
+[y],
+)

 example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
 gm, _ = torch._dynamo.export(
@@ -3324,10 +3330,10 @@ def forward(self, x):

 def test_cond_raise_user_error_on_mismatch_return_length(self):
 def true_fn(x):
-return x
+return x.clone()

 def false_fn(x):
-return (x, x)
+return (x.clone(), x.clone())

 def f_mismatch_return_length(x):
 return cond(torch.tensor(100), true_fn, false_fn, [x])
@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):
|
||||||
|
|
||||||
def test_map_pytree_return(self):
|
def test_map_pytree_return(self):
|
||||||
def _construct_pytree(a):
|
def _construct_pytree(a):
|
||||||
return (a, [[[a]]], a, (a, (a,), a), {"a": a})
|
return (
|
||||||
|
a.clone(),
|
||||||
|
[[[a.clone()]]],
|
||||||
|
a.clone(),
|
||||||
|
(a.clone(), (a.clone(),), a.clone()),
|
||||||
|
{"a": a.clone()},
|
||||||
|
)
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
def inner_f(xs):
|
def inner_f(xs):
|
||||||
|
|
@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
|
||||||
body_graph,
|
body_graph,
|
||||||
"""\
|
"""\
|
||||||
def forward(self, child : torch.Tensor):
|
def forward(self, child : torch.Tensor):
|
||||||
return (child, child, child, child, child, child, child)""",
|
child_1 = child.clone()
|
||||||
|
child_2 = child.clone()
|
||||||
|
child_3 = child.clone()
|
||||||
|
child_4 = child.clone()
|
||||||
|
child_5 = child.clone()
|
||||||
|
child_6 = child.clone()
|
||||||
|
child_7 = child.clone(); child = None
|
||||||
|
return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_kwargs(self):
|
def test_map_kwargs(self):
|
||||||
|
|
@ -6902,7 +6915,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
def test(pred, x):
|
def test(pred, x):
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return x
|
return x.clone()
|
||||||
|
|
||||||
def false_fn(x):
|
def false_fn(x):
|
||||||
return -x
|
return -x
|
||||||
|
|
@ -6926,7 +6939,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||||
|
|
||||||
def test(pred, mode, x):
|
def test(pred, mode, x):
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return x
|
return x.clone()
|
||||||
|
|
||||||
def false_fn(x):
|
def false_fn(x):
|
||||||
return -x
|
return -x
|
||||||
|
|
|
||||||
|
|
@ -5931,7 +5931,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||||
from functorch.experimental.control_flow import cond
|
from functorch.experimental.control_flow import cond
|
||||||
|
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return x
|
return x.clone()
|
||||||
|
|
||||||
def false_fn(x):
|
def false_fn(x):
|
||||||
return x.sin()
|
return x.sin()
|
||||||
|
|
|
||||||
|
|
@ -7604,20 +7604,25 @@ def forward(self, b_a_buffer, x):
|
||||||
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
|
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
@testing.expectedFailureCppRuntime
|
|
||||||
def test_export_associative_scan_lifted_buffers(self):
|
def test_export_associative_scan_lifted_buffers(self):
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
combine_mode = "pointwise"
|
combine_mode = "pointwise"
|
||||||
|
|
||||||
|
class A(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self.buffer.cos()
|
||||||
|
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_buffer(
|
self.a = A()
|
||||||
"buf", torch.ones(3, 2, device=device), persistent=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def combine_fn(self, x, y):
|
def combine_fn(self, x, y):
|
||||||
return x + y * self.buf
|
return (x + y) * self.a()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return associative_scan(
|
return associative_scan(
|
||||||
|
|
|
||||||
|
|
@ -4572,17 +4572,17 @@ class <lambda>(torch.nn.Module):
|
||||||
|
|
||||||
body_graph_0 = self.body_graph_0
|
body_graph_0 = self.body_graph_0
|
||||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
|
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
|
||||||
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
|
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||||
|
|
||||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
|
||||||
|
|
||||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
|
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
|
||||||
|
|
||||||
body_graph_1 = self.body_graph_1
|
body_graph_1 = self.body_graph_1
|
||||||
map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
|
map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
|
||||||
getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
|
getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
|
||||||
|
|
||||||
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
|
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None
|
||||||
|
|
||||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
|
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
|
||||||
return (add_1,)
|
return (add_1,)
|
||||||
|
|
@ -4635,9 +4635,9 @@ class <lambda>(torch.nn.Module):
|
||||||
|
|
||||||
body_graph_0 = self.body_graph_0
|
body_graph_0 = self.body_graph_0
|
||||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
|
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
|
||||||
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
|
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
|
||||||
|
|
||||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
|
||||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
||||||
return (add,)
|
return (add,)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -6,7 +6,7 @@ import unittest
 import torch
 import torch.utils._pytree as pytree
 from functorch.experimental import control_flow
-from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
+from functorch.experimental.control_flow import cond
 from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.associative_scan import (
 _fake_associative_scan,
@@ -36,7 +36,6 @@ from torch.testing._internal.common_utils import (
 TEST_WITH_CROSSREF,
 TEST_WITH_TORCHDYNAMO,
 TestCase,
-xfailIfTorchDynamo,
 )

@ -159,7 +158,7 @@ def get_scan_combine_fn(name, associative=True, parameters=None):
|
||||||
def RNN(x: torch.Tensor, y: torch.Tensor):
|
def RNN(x: torch.Tensor, y: torch.Tensor):
|
||||||
c_new = y @ parameters[0] + parameters[1]
|
c_new = y @ parameters[0] + parameters[1]
|
||||||
h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
|
h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
|
||||||
return h_new, h_new
|
return h_new, h_new.clone()
|
||||||
|
|
||||||
def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
|
def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
|
||||||
h_new = torch.tanh(x[0] + x[1] + y)
|
h_new = torch.tanh(x[0] + x[1] + y)
|
||||||
|
|
@@ -888,8 +887,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
 )

 def test_cond_autograd_pytree_input(self):
+# TODO: This is an unexpected behavior for cond
+# Without this additional multiplication,
+# the output of the backward graph would alias the
+# inputs, as the gradients are just 1s and thus get optimized
 def true_fn(x):
-return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
+return (x["t"][0] * 2.0) + x["t"][1]["b"] * x["t"][2][0]

 def false_fn(x):
 return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
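A rough illustration of the comment in the hunk above (this snippet is not part of the patch): the gradient reaching each operand of a plain addition is an all-ones tensor, so without the extra multiplication the traced backward graph can return trivially shared constant gradients that the alias check would flag.

import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# d/da (a + b).sum() and d/db (a + b).sum() are both all-ones tensors; in a
# traced backward graph such trivial gradients may share storage.
ga, gb = torch.autograd.grad((a + b).sum(), (a, b))
print(ga, gb)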
@ -966,10 +969,10 @@ def forward(self, pred_1):
|
||||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||||
def test_cond_autograd_same_pytree_output(self):
|
def test_cond_autograd_same_pytree_output(self):
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return {"res": [x["t"][0], (x["t"][2][0],)]}
|
return {"res": [x["t"][0].clone(), (x["t"][2][0].clone(),)]}
|
||||||
|
|
||||||
def false_fn(x):
|
def false_fn(x):
|
||||||
return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
|
return {"res": [x["t"][1]["b"].clone(), (x["t"][2][0].clone(),)]}
|
||||||
|
|
||||||
a = torch.randn(4, requires_grad=True)
|
a = torch.randn(4, requires_grad=True)
|
||||||
b = torch.randn(4, requires_grad=True)
|
b = torch.randn(4, requires_grad=True)
|
||||||
|
|
@ -1007,9 +1010,7 @@ def forward(self, pred_1):
|
||||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
||||||
getitem = cond[0]
|
getitem = cond[0]
|
||||||
getitem_1 = cond[1]; cond = None
|
getitem_1 = cond[1]; cond = None
|
||||||
view = torch.ops.aten.view.default(getitem, [4]); getitem = None
|
return {'res': [getitem, (getitem_1,)]}""", # noqa: B950
|
||||||
view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
|
|
||||||
return {'res': [view, (view_1,)]}""", # noqa: B950
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||||
|
|
@ -1901,12 +1902,15 @@ def forward(self, pred_1, x_1):
|
||||||
{
|
{
|
||||||
"i": x["i"] * y["j"][0][0],
|
"i": x["i"] * y["j"][0][0],
|
||||||
"k": torch.tensor(0.0),
|
"k": torch.tensor(0.0),
|
||||||
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
"j": (
|
||||||
|
[x["j"][1][0]["o"].clone()],
|
||||||
|
[{"o": torch.sin(x["i"])}],
|
||||||
|
),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"i": x["i"] * y["j"][0][0],
|
"i": x["i"] * y["j"][0][0],
|
||||||
"k": torch.tensor(0.0),
|
"k": torch.tensor(0.0),
|
||||||
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
"j": ([x["j"][1][0]["o"].clone()], [{"o": torch.sin(x["i"])}]),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -2395,8 +2399,11 @@ def forward(self, pred_1, x_1):
 init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)

 with self.assertRaisesRegex(
+# Should be
+# torch._dynamo.exc.Unsupported,
+# "Encountered aliasing during higher order op tracing for HOP.*"
 torch._dynamo.exc.UncapturedHigherOrderOpError,
-"Expected init and carry to have same metadata.*",
+"scan must be captured completely with torch.compile.*",
 ):
 scan_fct(wrong_carry_shape, init, x, dim=dim)
@ -3311,6 +3318,114 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
|
||||||
return (carry, out_1)""", # noqa: B950
|
return (carry, out_1)""", # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_scan_input_mutation(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_input_mutation(x, y):
|
||||||
|
x.add_(1)
|
||||||
|
return x + y, x + y + 2
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
init = torch.randn(2, 2, device=device)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
scan(fct_input_mutation, init, x, dim=0)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_scan_input_carry_alias(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_input_output_alias(x, y):
|
||||||
|
return (x[0], x[1] + y[1]), (x[1] + y[1] + 1, x[1] + y[1] + 2)
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
y = torch.randn(3, 2, 2, device=device)
|
||||||
|
inp = (x, y)
|
||||||
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
scan(fct_input_output_alias, init, inp, dim=0)
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
def test_scan_input_output_alias(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_input_output_alias(x, y):
|
||||||
|
return (x[0] + 1, x[1] + y[1]), (x[1], x[1] + y[1] + 2)
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
y = torch.randn(3, 2, 2, device=device)
|
||||||
|
inp = (x, y)
|
||||||
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
scan(fct_input_output_alias, init, inp, dim=0)
|
||||||
|
|
||||||
|
@unittest.skipIf(not SM70OrLater, "triton")
|
||||||
|
@requires_cuda
|
||||||
|
def test_scan_carry_carry_alias(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_carry_carry_alias(x, y):
|
||||||
|
c = x[0] + y[1]
|
||||||
|
return (c, c), (x[0] + y[1], x[0] + y[1] + 1)
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
y = torch.randn(3, 2, 2, device=device)
|
||||||
|
inp = (x, y)
|
||||||
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
scan(fct_carry_carry_alias, init, inp, dim=0)
|
||||||
|
|
||||||
|
@unittest.skipIf(not SM70OrLater, "triton")
|
||||||
|
@requires_cuda
|
||||||
|
def test_scan_carry_output_alias(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_carry_output_alias(x, y):
|
||||||
|
c = x[0] + y[1]
|
||||||
|
return (x[0] + y[1], c), (c, x[0] + y[1] + 1)
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
y = torch.randn(3, 2, 2, device=device)
|
||||||
|
inp = (x, y)
|
||||||
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
scan(fct_carry_output_alias, init, inp, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class AssociativeScanModels:
|
class AssociativeScanModels:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -4158,7 +4273,7 @@ class GraphModule(torch.nn.Module):
|
||||||
)
|
)
|
||||||
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
||||||
def combine_fn(x, y):
|
def combine_fn(x, y):
|
||||||
val = cond(torch.sum(y) > 0.0, lambda y: y + 0.0, lambda y: 1.0 - y, (y,))
|
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
|
||||||
return x * val
|
return x * val
|
||||||
|
|
||||||
inp = torch.randn(3, 10, 1, device=device)
|
inp = torch.randn(3, 10, 1, device=device)
|
||||||
|
|
@ -4775,8 +4890,10 @@ class GraphModule(torch.nn.Module):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
# Should be
|
# Should be
|
||||||
RuntimeError,
|
# torch._dynamo.exc.Unsupported,
|
||||||
"Combine_fn might be modifying the input!",
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"associative_scan must be captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
associative_scan(fct_input_mutation, x, 0)
|
associative_scan(fct_input_mutation, x, 0)
|
||||||
|
|
||||||
|
|
@ -4793,11 +4910,35 @@ class GraphModule(torch.nn.Module):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
# Should be
|
# Should be
|
||||||
RuntimeError,
|
# torch._dynamo.exc.Unsupported,
|
||||||
"Combine_fn might be aliasing the input!",
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"associative_scan must be captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
associative_scan(fct_input_output_alias, inp, 0)
|
associative_scan(fct_input_output_alias, inp, 0)
|
||||||
|
|
||||||
|
@unittest.skipIf(not SM70OrLater, "triton")
|
||||||
|
@requires_cuda
|
||||||
|
def test_associative_scan_output_output_alias(self):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
def fct_output_output_alias(x, y):
|
||||||
|
c = x[0] + y[1]
|
||||||
|
return c, c
|
||||||
|
|
||||||
|
x = torch.randn(3, 2, 2, device=device)
|
||||||
|
y = torch.randn(3, 2, 2, device=device)
|
||||||
|
inp = (x, y)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"associative_scan must be captured completely with torch.compile.*",
|
||||||
|
):
|
||||||
|
associative_scan(fct_output_output_alias, inp, 0)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
||||||
@skipIfNoDynamoSupport
|
@skipIfNoDynamoSupport
|
||||||
|
|
@ -5597,8 +5738,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
||||||
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
def test_cond_functionalized_input_mutation_on_true_branch(self):
|
||||||
def test_cond_functionalized_input_mutation_on_true_brancte(self):
|
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
view_x = x.view(x.shape)
|
view_x = x.view(x.shape)
|
||||||
view_x.add_(1)
|
view_x.add_(1)
|
||||||
|
|
@ -5632,13 +5772,13 @@ def forward(self, x_1):
|
||||||
# torch.cond triggers the check of the branches because the predicate
|
# torch.cond triggers the check of the branches because the predicate
|
||||||
# is a SymBool.
|
# is a SymBool.
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
torch._dynamo.exc.TorchRuntimeError,
|
||||||
|
"cond_true might be modifying the input!",
|
||||||
):
|
):
|
||||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||||
*example_inputs
|
*example_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return x.sin().sum()
|
return x.sin().sum()
|
||||||
|
|
@ -5673,16 +5813,16 @@ def forward(self, x_1):
|
||||||
# torch.cond triggers the check of the branches because the predicate
|
# torch.cond triggers the check of the branches because the predicate
|
||||||
# is a SymBool.
|
# is a SymBool.
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
torch._dynamo.exc.TorchRuntimeError,
|
||||||
|
"cond_false might be modifying the input!",
|
||||||
):
|
):
|
||||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||||
*example_inputs
|
*example_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
def test_cond_functionalized_output_alias_input(self):
|
def test_cond_functionalized_output_alias_input(self):
|
||||||
def true_fn(x):
|
def true_fn(x):
|
||||||
return x
|
return x.clone()
|
||||||
|
|
||||||
def false_fn(x):
|
def false_fn(x):
|
||||||
view_x = x.view(x.shape)
|
view_x = x.view(x.shape)
|
||||||
|
|
@ -5707,13 +5847,16 @@ def forward(self, x_1):
|
||||||
# torch.cond triggers the check of the branches because the predicate
|
# torch.cond triggers the check of the branches because the predicate
|
||||||
# is a SymBool.
|
# is a SymBool.
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||||
*example_inputs
|
*example_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
def test_cond_functionalized_nested_input_mutation(self):
|
def test_cond_functionalized_nested_input_mutation(self):
|
||||||
def true_true_fn(x):
|
def true_true_fn(x):
|
||||||
x.add_(4)
|
x.add_(4)
|
||||||
|
|
@ -5735,13 +5878,13 @@ def forward(self, x_1):
|
||||||
|
|
||||||
example_inputs = (torch.ones(4, 5),)
|
example_inputs = (torch.ones(4, 5),)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
torch._dynamo.exc.TorchRuntimeError,
|
||||||
|
"cond_true might be modifying the input!",
|
||||||
):
|
):
|
||||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||||
*example_inputs
|
*example_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
||||||
def true_true_fn(x):
|
def true_true_fn(x):
|
||||||
x.add_(4)
|
x.add_(4)
|
||||||
|
|
@ -5768,7 +5911,11 @@ def forward(self, x_1):
|
||||||
f(example_input_func)
|
f(example_input_func)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -5786,7 +5933,11 @@ def forward(self, x_1):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "One of torch.cond branch"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
||||||
|
|
||||||
|
|
@ -5807,8 +5958,11 @@ def forward(self, x_1):
|
||||||
example_input_func = to_fun_old(example_input)
|
example_input_func = to_fun_old(example_input)
|
||||||
torch._enable_functionalization(reapply_views=False)
|
torch._enable_functionalization(reapply_views=False)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError,
|
# Should be
|
||||||
"One of torch.cond branch might be aliasing",
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
f(example_input_func)
|
f(example_input_func)
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -5838,8 +5992,11 @@ def forward(self, x_1):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError,
|
# Should be
|
||||||
"One of torch.cond branch might be aliasing",
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
||||||
|
|
||||||
|
|
@ -5974,8 +6131,11 @@ def forward(self, arg0_1):
|
||||||
|
|
||||||
x = torch.randn(4)
|
x = torch.randn(4)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError,
|
# Should be
|
||||||
"Unmatched output spec from torch.cond branches",
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(f)(x, torch.tensor(False))
|
make_fx(f)(x, torch.tensor(False))
|
||||||
|
|
||||||
|
|
@ -6147,8 +6307,11 @@ def forward(self, arg0_1):
|
||||||
|
|
||||||
x = torch.randn(4)
|
x = torch.randn(4)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError,
|
# Should be
|
||||||
"Unmatched output spec from torch.cond branches",
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||||
|
|
||||||
|
|
@ -6431,7 +6594,8 @@ def forward(self, arg0_1):
|
||||||
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
|
torch._dynamo.exc.TorchRuntimeError,
|
||||||
|
"map might be modifying the input!",
|
||||||
):
|
):
|
||||||
functional_f(*example_inputs)
|
functional_f(*example_inputs)
|
||||||
|
|
||||||
|
|
@ -6446,7 +6610,7 @@ def forward(self, arg0_1):
|
||||||
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is mutating the input!"
|
torch._dynamo.exc.TorchRuntimeError, "map might be modifying the input!"
|
||||||
):
|
):
|
||||||
functional_f(*example_inputs)
|
functional_f(*example_inputs)
|
||||||
|
|
||||||
|
|
@ -6484,7 +6648,11 @@ def forward(self, arg0_1):
|
||||||
example_inputs = (torch.ones(3, 2, 4),)
|
example_inputs = (torch.ones(3, 2, 4),)
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.TorchRuntimeError, "torch.map is aliasing the input!"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"map doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
functional_f(*example_inputs)
|
functional_f(*example_inputs)
|
||||||
|
|
||||||
|
|
@ -6843,10 +7011,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||||
gm.code.strip(),
|
gm.code.strip(),
|
||||||
"""\
|
"""\
|
||||||
def forward(self, pred_1, x_1):
|
def forward(self, pred_1, x_1):
|
||||||
|
unbind = torch.ops.aten.unbind.int(x_1)
|
||||||
|
getitem = unbind[0]; getitem = None
|
||||||
|
getitem_1 = unbind[1]; unbind = getitem_1 = None
|
||||||
body_graph_0 = self.body_graph_0
|
body_graph_0 = self.body_graph_0
|
||||||
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
|
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
|
||||||
getitem = map_impl[0]; map_impl = None
|
getitem_2 = map_impl[0]; map_impl = None
|
||||||
return getitem""",
|
return getitem_2""",
|
||||||
)
|
)
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
gm.body_graph_0.code.strip(),
|
gm.body_graph_0.code.strip(),
|
||||||
|
|
@ -7082,7 +7253,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||||
return torch.cond(
|
return torch.cond(
|
||||||
pred=torch.tensor([True]),
|
pred=torch.tensor([True]),
|
||||||
true_fn=lambda x: x + 100,
|
true_fn=lambda x: x + 100,
|
||||||
false_fn=lambda x: x,
|
false_fn=lambda x: x.clone(),
|
||||||
operands=(x,),
|
operands=(x,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -7096,7 +7267,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||||
return torch.cond(
|
return torch.cond(
|
||||||
pred=x.sum() < y.sum(),
|
pred=x.sum() < y.sum(),
|
||||||
true_fn=lambda x, y: x + 100,
|
true_fn=lambda x, y: x + 100,
|
||||||
false_fn=lambda x, y: y,
|
false_fn=lambda x, y: y.clone(),
|
||||||
operands=(x, y),
|
operands=(x, y),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -7155,7 +7326,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||||
return torch.cond(
|
return torch.cond(
|
||||||
pred=torch.tensor([True]),
|
pred=torch.tensor([True]),
|
||||||
true_fn=lambda x: (x + c, x - c),
|
true_fn=lambda x: (x + c, x - c),
|
||||||
false_fn=lambda x: (x, x),
|
false_fn=lambda x: (x.clone(), x.clone()),
|
||||||
operands=(x,),
|
operands=(x,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -7165,7 +7336,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||||
return torch.cond(
|
return torch.cond(
|
||||||
pred=torch.tensor([True]),
|
pred=torch.tensor([True]),
|
||||||
true_fn=lambda x: (x + 1, x - 1),
|
true_fn=lambda x: (x + 1, x - 1),
|
||||||
false_fn=lambda x: (x, x),
|
false_fn=lambda x: (x.clone(), x.clone()),
|
||||||
operands=(x,),
|
operands=(x,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -7361,8 +7532,6 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
||||||
functional_f(example_init, example_inputs), f(example_init, example_inputs)
|
functional_f(example_init, example_inputs), f(example_init, example_inputs)
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
@xfailIfTorchDynamo
|
|
||||||
def test_scan_functionalized_elem_mutation(self):
|
def test_scan_functionalized_elem_mutation(self):
|
||||||
def add1(x, y):
|
def add1(x, y):
|
||||||
x.add_(4)
|
x.add_(4)
|
||||||
|
|
@ -7375,8 +7544,15 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
||||||
example_init = torch.ones(5, 4)
|
example_init = torch.ones(5, 4)
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
UnsupportedAliasMutationException,
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||||
"Combine_fn might be modifying the input!",
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||||
|
# RuntimeError,
|
||||||
|
# "torch.scan might be modifying the input!",
|
||||||
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||||
|
# torch._dynamo.exc.TorchDynamoException,
|
||||||
|
# "Unexpected exception when running generated GraphModule.*"
|
||||||
|
Exception,
|
||||||
|
".*",
|
||||||
):
|
):
|
||||||
functional_f(example_init, example_inputs)
|
functional_f(example_init, example_inputs)
|
||||||
|
|
||||||
|
|
@ -7389,13 +7565,19 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
||||||
|
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
UnsupportedAliasMutationException,
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||||
"Combine_fn might be modifying the input!",
|
# Should be
|
||||||
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||||
|
# RuntimeError,
|
||||||
|
# "torch.scan might be modifying the input!",
|
||||||
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||||
|
# torch._dynamo.exc.TorchDynamoException,
|
||||||
|
# "Unexpected exception when running generated GraphModule.*"
|
||||||
|
Exception,
|
||||||
|
".*",
|
||||||
):
|
):
|
||||||
functional_f(example_init, example_inputs)
|
functional_f(example_init, example_inputs)
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/126988
|
|
||||||
@xfailIfTorchDynamo
|
|
||||||
def test_scan_functionalized_elem_alias(self):
|
def test_scan_functionalized_elem_alias(self):
|
||||||
def add(x, y):
|
def add(x, y):
|
||||||
return x, x
|
return x, x
|
||||||
|
|
@ -7407,7 +7589,16 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor,
|
||||||
example_init = torch.ones(5, 4)
|
example_init = torch.ones(5, 4)
|
||||||
functional_f = torch.func.functionalize(f)
|
functional_f = torch.func.functionalize(f)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!"
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
||||||
|
# Should be
|
||||||
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
||||||
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
# "scan must be captured completely with torch.compile.*",
|
||||||
|
Exception,
|
||||||
|
".*",
|
||||||
):
|
):
|
||||||
functional_f(example_init, example_inputs)
|
functional_f(example_init, example_inputs)
|
||||||
|
|
||||||
|
|
@ -7917,7 +8108,11 @@ class GraphModule(torch.nn.Module):
|
||||||
x = torch.randn(2, 2)
|
x = torch.randn(2, 2)
|
||||||
for f in ALIAS_FN:
|
for f in ALIAS_FN:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
torch.compile(fn)(f, x)
|
torch.compile(fn)(f, x)
|
||||||
|
|
||||||
|
|
@ -7933,7 +8128,11 @@ class GraphModule(torch.nn.Module):
|
||||||
# as a result of auto lifting.
|
# as a result of auto lifting.
|
||||||
for view_f in ALIAS_FN[1:]:
|
for view_f in ALIAS_FN[1:]:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.BackendCompilerFailed, "might be aliasing the input"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
torch.compile(fn)(view_f, x)
|
torch.compile(fn)(view_f, x)
|
||||||
|
|
||||||
|
|
@ -7950,12 +8149,20 @@ class GraphModule(torch.nn.Module):
|
||||||
x = torch.randn(2, 2)
|
x = torch.randn(2, 2)
|
||||||
for f in ALIAS_FN:
|
for f in ALIAS_FN:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
torch.compile(fn)(f, x)
|
torch.compile(fn)(f, x)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.BackendCompilerFailed, "might be modifying the input"
|
# Should be
|
||||||
|
# torch._dynamo.exc.Unsupported,
|
||||||
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
||||||
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
||||||
):
|
):
|
||||||
with torch.inference_mode(inference_mode):
|
with torch.inference_mode(inference_mode):
|
||||||
torch.compile(fn)(f, x)
|
torch.compile(fn)(f, x)
|
||||||
|
|
|
||||||
|
|
@@ -570,7 +570,7 @@ class CondTests(TestCase):
 return torch.cond(p, true_fn, false_fn, [a, b])

 # AssertionError: Output aliasing is currently not supported...
-with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
+with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
 torch.compile(Model())(
 torch.tensor(True),
 torch.randn(10, 20),
@@ -2831,11 +2831,13 @@ class SubgraphTracer(fx.Tracer):
 return MutationInfo(False, "")

 def has_aliasing(self):
+from torch._higher_order_ops.utils import _collect_fake_inputs
+
 input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()

 for node in self.graph.nodes:
 if node.op == "placeholder":
-example_value = node.meta["example_value"]
+example_value = _collect_fake_inputs([node])[0]
 if isinstance(example_value, torch.Tensor):
 storage = StorageWeakRef(example_value._typed_storage())
 if storage in input_storages:
@@ -2848,9 +2850,9 @@ class SubgraphTracer(fx.Tracer):

 output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
 out_nodes = self.graph.find_nodes(op="output")[0]
-for out_node in out_nodes.args[0]:
+for out_node in pytree.tree_leaves(out_nodes.args[0]):
 if out_node:
-example_value = out_node.meta["example_value"]
+example_value = _collect_fake_inputs([out_node])[0]
 assert not isinstance(example_value, list)
 if isinstance(example_value, torch.Tensor):
 storage = StorageWeakRef(example_value._typed_storage())
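A simplified sketch of the idea behind has_aliasing above, using public tensor APIs instead of the tracer's StorageWeakRef bookkeeping (illustrative only, not from the patch): two tensors alias when they share underlying storage, so recording the storages of the placeholders and comparing the storages of the outputs against them flags input-output aliasing.

import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Simplified stand-in for the StorageWeakRef comparison in SubgraphTracer.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.randn(4)
view = x[:2]            # a view: same storage as x
clone = x[:2].clone()   # fresh storage

print(shares_storage(x, view))   # True  -> would be flagged as aliasing
print(shares_storage(x, clone))  # False -> passes the check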
@@ -961,6 +961,9 @@ class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable


 class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
+supports_input_mutation = False
+supports_aliasing = False
+
 @raise_hard_error_if_graph_break(
 reason="Cond doesn't work unless it is captured completely with torch.compile."
 )
@@ -1058,6 +1061,8 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
 "cond",
 source_target=self.value,
 should_flatten_outputs=True,
+supports_input_mutation=self.supports_input_mutation,
+supports_aliasing=self.supports_aliasing,
 )

 if not only_consist_of(ret_val, (TensorVariable,)):
@@ -1184,6 +1189,9 @@ def validate_subgraph_output_types(output: VariableTracker):


 class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
+supports_input_mutation = False
+supports_aliasing = False
+
 @raise_hard_error_if_graph_break(
 reason="while_loop doesn't work unless it is captured completely with torch.compile."
 )
@@ -1299,6 +1307,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
 # So it's best we always enforce the ordering of carried_inputs the same as outputs
 # with "flatten_manual".
 set_subgraph_inputs="flatten_manual",
+supports_input_mutation=self.supports_input_mutation,
+supports_aliasing=self.supports_aliasing,
 )
 cond_nn_modules = dict(tx.output.nn_modules)
 validate_subgraph_output_types(cond_r)
@@ -1337,6 +1347,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
 source_target=self.value,
 set_subgraph_inputs="flatten_manual",
 should_flatten_outputs=True,
+supports_input_mutation=False,
+supports_aliasing=False,
 )
 validate_subgraph_output_types(body_r)
@@ -1420,6 +1432,9 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):


 class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
+supports_input_mutation = False
+supports_aliasing = False
+
 @raise_hard_error_if_graph_break(
 reason="associative_scan must be captured completely with torch.compile."
 )
@@ -1513,6 +1528,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 description="associative_scan_combine_fn",
 source_target=self.value,
 set_subgraph_inputs="flatten_manual",
+supports_input_mutation=self.supports_input_mutation,
+supports_aliasing=self.supports_aliasing,
 )

 # Ensure that the output of scan is a flattened list of elements,
@@ -1562,30 +1579,23 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 )

 combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
+combine_freevars_proxy = tuple(combine_lifted_freevars.keys())

-from torch._higher_order_ops.utils import (
-_has_potential_branch_input_alias,
-_has_potential_branch_input_mutation,
-_maybe_fake_tracing,
+# Compute the proxies for the input check
+proxy_vars_inputcheck = (
+tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy
 )

+from torch._higher_order_ops.utils import _maybe_fake_tracing
 from torch._inductor.utils import is_pointwise_use
-from torch._subclasses.fake_tensor import FakeTensor

 with tx.fake_mode:
-xs_fake = [
-first_slice_copy(leaf.proxy.node.meta["example_value"].clone())
-for leaf in itertools.chain(xs_vars, xs_vars)
-]
-additional_fake = [
-leaf.proxy.node.meta["example_value"].clone()
-for leaf in additional_inputs_vars
-] + [
+sub_args_fake = [
 leaf.node.meta["example_value"].clone()
-if isinstance(leaf.node.meta["example_value"], FakeTensor)
+if hasattr(leaf.node.meta["example_value"], "clone")
 else leaf.node.meta["example_value"]
-for leaf in combine_lifted_freevars.keys()
+for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
 ]
-sub_args_fake = xs_fake + additional_fake
 pre_dispatch = False

 fx = _maybe_fake_tracing(
@@ -1601,15 +1611,6 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 "For combine_mode='pointwise', the combine_fn needs to be pointwise"
 )

-if _has_potential_branch_input_mutation(
-combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
-):
-raise RuntimeError("Combine_fn might be modifying the input!") # noqa: F541
-if _has_potential_branch_input_alias(
-combine_gm, sub_args_fake, pre_dispatch=pre_dispatch
-):
-raise RuntimeError("Combine_fn might be aliasing the input!") # noqa: F541
-
 combine_fn_name = tx.output.install_subgraph(
 "associative_scan_combine_fn", combine_gm
 )
@@ -1641,6 +1642,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):


 class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
+supports_input_mutation = False
+supports_aliasing = False
+
 @raise_hard_error_if_graph_break(
 reason="scan must be captured completely with torch.compile."
 )
@@ -1752,7 +1756,10 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 description="scan_combine_fn",
 source_target=self.value,
 set_subgraph_inputs="flatten_manual",
+supports_input_mutation=self.supports_input_mutation,
+supports_aliasing=self.supports_aliasing,
 )

 # Ensure that the output of scan is a flattened list of elements,
 # because downstream operations assume that the output of HOPs
 # is flattened
|
||||||
|
|
||||||
|
|
||||||
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
|
supports_input_mutation = False
|
||||||
|
supports_aliasing = False
|
||||||
|
|
||||||
@raise_hard_error_if_graph_break(
|
@raise_hard_error_if_graph_break(
|
||||||
reason="map doesn't work unless it is captured completely with torch.compile."
|
reason="map doesn't work unless it is captured completely with torch.compile."
|
||||||
)
|
)
|
||||||
|
|
@ -1910,6 +1920,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
source_target=self.value,
|
source_target=self.value,
|
||||||
set_subgraph_inputs="flatten_manual",
|
set_subgraph_inputs="flatten_manual",
|
||||||
should_flatten_outputs=True,
|
should_flatten_outputs=True,
|
||||||
|
supports_input_mutation=self.supports_input_mutation,
|
||||||
|
supports_aliasing=self.supports_aliasing,
|
||||||
)
|
)
|
||||||
|
|
||||||
body_nn_modules = dict(tx.output.nn_modules)
|
body_nn_modules = dict(tx.output.nn_modules)
|
||||||
|
|
|
||||||
|
|
@ -434,12 +434,27 @@ def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, additional_inputs):
|
||||||
|
|
||||||
@associative_scan_op.py_functionalize_impl
|
@associative_scan_op.py_functionalize_impl
|
||||||
def associative_scan_functionalize(ctx, combine_fn, xs, additional_inputs):
|
def associative_scan_functionalize(ctx, combine_fn, xs, additional_inputs):
|
||||||
|
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||||
|
|
||||||
unwrapped_xs = ctx.unwrap_tensors(xs)
|
unwrapped_xs = ctx.unwrap_tensors(xs)
|
||||||
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
|
||||||
with ctx.redispatch_to_next():
|
with ctx.redispatch_to_next():
|
||||||
functional_combine_fn = ctx.functionalize(
|
functional_combine_fn = ctx.functionalize(
|
||||||
_maybe_run_with_interpreter(combine_fn)
|
_maybe_run_with_interpreter(combine_fn)
|
||||||
)
|
)
|
||||||
|
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||||
|
sample_unwrapped_xs_sliced = [
|
||||||
|
first_slice_copy(inp) for inp in itertools.chain(unwrapped_xs, unwrapped_xs)
|
||||||
|
]
|
||||||
|
sample_inputs = list(
|
||||||
|
itertools.chain(
|
||||||
|
sample_unwrapped_xs_sliced,
|
||||||
|
unwrapped_additional_inputs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_check_alias_and_mutation(
|
||||||
|
combine_fn, sample_inputs, "associative_scan", pre_dispatch
|
||||||
|
)
|
||||||
ret = associative_scan_op(
|
ret = associative_scan_op(
|
||||||
functional_combine_fn,
|
functional_combine_fn,
|
||||||
unwrapped_xs,
|
unwrapped_xs,
|
||||||
|
|
|
||||||
|
|
@ -136,10 +136,10 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
||||||
for ph in subgraph.graph.find_nodes(op="placeholder")
|
for ph in subgraph.graph.find_nodes(op="placeholder")
|
||||||
]
|
]
|
||||||
(
|
(
|
||||||
mutated_inp_idx,
|
|
||||||
inp_inp_alias,
|
inp_inp_alias,
|
||||||
inp_out_alias,
|
inp_out_alias,
|
||||||
out_out_alias,
|
out_out_alias,
|
||||||
|
mutated_inp_idx,
|
||||||
output,
|
output,
|
||||||
) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args)
|
) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,6 @@ from torch._C._functorch import (
|
||||||
from torch._dispatch.python import suspend_functionalization
|
from torch._dispatch.python import suspend_functionalization
|
||||||
from torch._functorch.utils import exposed_in
|
from torch._functorch.utils import exposed_in
|
||||||
from torch._higher_order_ops.utils import (
|
from torch._higher_order_ops.utils import (
|
||||||
_has_potential_branch_input_alias,
|
|
||||||
_has_potential_branch_input_mutation,
|
|
||||||
_maybe_reenter_make_fx,
|
_maybe_reenter_make_fx,
|
||||||
_maybe_run_with_interpreter,
|
_maybe_run_with_interpreter,
|
||||||
_set_compilation_env,
|
_set_compilation_env,
|
||||||
|
|
@ -27,7 +25,6 @@ from torch._higher_order_ops.utils import (
|
||||||
save_tensors_and_symints_for_backward,
|
save_tensors_and_symints_for_backward,
|
||||||
saved_tensors_and_symints,
|
saved_tensors_and_symints,
|
||||||
unique_graph_id,
|
unique_graph_id,
|
||||||
UnsupportedAliasMutationException,
|
|
||||||
validate_subgraph_args_types,
|
validate_subgraph_args_types,
|
||||||
)
|
)
|
||||||
from torch._ops import HigherOrderOperator
|
from torch._ops import HigherOrderOperator
|
||||||
|
|
@ -668,29 +665,18 @@ def _merge_tensors(
|
||||||
|
|
||||||
@cond_op.py_functionalize_impl
|
@cond_op.py_functionalize_impl
|
||||||
def cond_func(ctx, pred, true_fn, false_fn, inputs):
|
def cond_func(ctx, pred, true_fn, false_fn, inputs):
|
||||||
|
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
||||||
|
|
||||||
unwrapped_inputs = ctx.unwrap_tensors(inputs)
|
unwrapped_inputs = ctx.unwrap_tensors(inputs)
|
||||||
unwrapped_pred = ctx.unwrap_tensors(pred)
|
unwrapped_pred = ctx.unwrap_tensors(pred)
|
||||||
with ctx.redispatch_to_next():
|
with ctx.redispatch_to_next():
|
||||||
functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
|
functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
|
||||||
functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
|
functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
|
||||||
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
||||||
for branch in [true_fn, false_fn]:
|
for branch, branch_name in [(true_fn, "cond_true"), (false_fn, "cond_false")]:
|
||||||
if _has_potential_branch_input_mutation(
|
_check_alias_and_mutation(
|
||||||
branch, unwrapped_inputs, pre_dispatch=pre_dispatch
|
branch, unwrapped_inputs, branch_name, pre_dispatch
|
||||||
):
|
)
|
||||||
raise UnsupportedAliasMutationException(
|
|
||||||
"One of torch.cond branch might be modifying the input! "
|
|
||||||
"Consider cloning the input before modifying it. "
|
|
||||||
)
|
|
||||||
for branch in [true_fn, false_fn]:
|
|
||||||
if _has_potential_branch_input_alias(
|
|
||||||
branch, unwrapped_inputs, pre_dispatch=pre_dispatch
|
|
||||||
):
|
|
||||||
raise UnsupportedAliasMutationException(
|
|
||||||
"One of torch.cond branch might be aliasing the input! "
|
|
||||||
"If you are returning a view of the input, please make sure "
|
|
||||||
"to clone it. "
|
|
||||||
)
|
|
||||||
|
|
||||||
cond_return = cond_op(
|
cond_return = cond_op(
|
||||||
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
|
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
|
||||||
|
|
|
||||||
|
|
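As context for the reworked cond path above: the single _check_alias_and_mutation call now rejects branches that mutate or alias their operands. Below is a minimal sketch of branches written to satisfy that check; it is illustrative only, the branch bodies are made up for the example and are not code from this PR.

import torch
from torch import cond

def true_fn(x):
    # Clone the slice so the branch output does not alias the input.
    return x[:2].clone()

def false_fn(x):
    # A fresh tensor: no alias of x and no in-place update of x.
    return x[:2] * 2

x = torch.randn(4)
out = cond(x.sum() > 0, true_fn, false_fn, [x])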
@@ -420,6 +420,9 @@ def flex_attention_functionalize(
         functional_score_mod = ctx.functionalize(score_mod)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
         with TransformGetItemToIndex():
+            # TODO: So far only the input mutations are checked
+            # In the other HOPs, also aliases are checked which is
+            # omitted here
             mutates = _has_potential_branch_input_mutation(
                 score_mod, example_vals, pre_dispatch
             )
@@ -3,12 +3,9 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     autograd_not_implemented,
     reenter_make_fx,
     unique_graph_id,
-    UnsupportedAliasMutationException,
 )
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -91,24 +88,18 @@ def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
 
 @hints_wrapper.py_functionalize_impl
 def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_args = ctx.unwrap_tensors(args)
     unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
     unwrapped_hints = ctx.unwrap_tensors(hints)
     with ctx.redispatch_to_next():
         functional_body_fn = ctx.functionalize(body_fn)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        if _has_potential_branch_input_mutation(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            body_fn, unwrapped_args, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "body_fn of hints_wrapper might be aliasing the input!"
-            )
+        _check_alias_and_mutation(
+            body_fn, unwrapped_args, "hints_wrapper", pre_dispatch
+        )
         outputs = hints_wrapper(
             functional_body_fn,
             unwrapped_args,
@@ -7,13 +7,7 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._dispatch.python import suspend_functionalization
-from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
-    _maybe_run_with_interpreter,
-    reenter_make_fx,
-    UnsupportedAliasMutationException,
-)
+from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch._subclasses.functional_tensor import disable_functional_mode
@@ -284,23 +278,15 @@ def map_fake_tensor_mode(mode, f, xs, args):
 
 @map_impl.py_functionalize_impl
 def map_functionalize(ctx, f, xs, pos_args):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_args = ctx.unwrap_tensors(pos_args)
     wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))
 
     with ctx.redispatch_to_next():
-        with disable_proxy_modes_tracing():
-            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
+        example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
         pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
-        if _has_potential_branch_input_mutation(
-            f, example_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException("torch.map is mutating the input!")
-
-        if _has_potential_branch_input_alias(
-            f, example_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
+        _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)
 
         map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
         return ctx.wrap_tensors(map_return)
@@ -10,8 +10,6 @@ import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.cond import create_bw_fn, materialize_as_graph
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     _maybe_compile_and_run_fn,
     check_meta_consistency,
     first_slice_copy,
@@ -19,7 +17,6 @@ from torch._higher_order_ops.utils import (
     save_tensors_and_symints_for_backward,
     saved_tensors_and_symints,
     unique_graph_id,
-    UnsupportedAliasMutationException,
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
@@ -861,12 +858,19 @@ def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
 
 @scan_op.py_functionalize_impl
 def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
+    from torch._higher_order_ops.utils import (
+        _check_alias_and_mutation,
+        _maybe_run_with_interpreter,
+    )
+
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_init = ctx.unwrap_tensors(init)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
 
     with ctx.redispatch_to_next():
-        functional_combine_fn = ctx.functionalize(combine_fn)
-        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        functional_combine_fn = ctx.functionalize(
+            _maybe_run_with_interpreter(combine_fn)
+        )
         sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs]
         sample_inputs = list(
             itertools.chain(
@@ -875,18 +879,8 @@ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
                 unwrapped_additional_inputs,
             )
         )
-        if _has_potential_branch_input_mutation(
-            combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "Combine_fn might be modifying the input!"
-            )
-        if _has_potential_branch_input_alias(
-            combine_fn, sample_inputs, pre_dispatch=pre_dispatch
-        ):
-            raise UnsupportedAliasMutationException(
-                "Combine_fn might be aliasing the input!"
-            )
+        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
+        _check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch)
         ret = scan_op(
             functional_combine_fn,
             unwrapped_init,
@@ -248,34 +248,6 @@ def _set_compilation_env():
         torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
 
 
-def _detect_input_mutation(gm: torch.fx.GraphModule) -> bool:
-    example_inputs = [
-        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
-    ]
-    inp_mutation, _, _, _ = check_input_alias_and_mutation(gm, example_inputs)
-    if len(inp_mutation) > 0:
-        return True
-
-    for _, module in gm.named_children():
-        if isinstance(module, torch.fx.GraphModule):
-            if _detect_input_mutation(module):
-                return True
-
-    return False
-
-
-def _detect_input_alias(gm: torch.fx.GraphModule) -> bool:
-    example_inputs = [
-        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
-    ]
-    _, inp_inp_alias_map, inp_out_alias_map, _ = check_input_alias_and_mutation(
-        gm, example_inputs
-    )
-    if len(inp_out_alias_map) > 0 or len(inp_inp_alias_map) > 0:
-        return True
-    return False
-
-
 # The invariant here is that we always trace the branch with fake tensor
 def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
     fake_mode = detect_fake_mode(inputs)
@@ -301,7 +273,7 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
     return gm
 
 
-def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
+def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
     try:
         gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
     except UnsupportedAliasMutationException:
@@ -311,43 +283,113 @@ def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
     except Exception as e:
         raise e
 
-    return _detect_input_mutation(gm) or _detect_input_alias(gm)
+    example_inputs = [
+        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
+    ]
+    (
+        inp_inp_alias_map,
+        inp_out_alias_map,
+        out_out_alias_map,
+        inp_mutation,
+    ) = check_input_alias_and_mutation(gm, example_inputs)
+    return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
 
 
-def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
-    """
-    Dispatch-trace the branch with inputs and check if
-    producing graph has mutable op on the input. This is
-    bit restrictive as the branch must be traceable.
-    """
-    try:
-        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
-    except UnsupportedAliasMutationException:
-        # this can happen when nested cond_op is
-        # functionalized
-        return True
-    except Exception as e:
-        raise e
-
-    return _detect_input_mutation(gm)
+def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
+    if any(len(a) > 0 for a in aliases):
+        # TODO: Investigate here further which node is exactly aliasing
+        raise RuntimeError(
+            f"{name} where aliases appear. "
+            + f"In particular, these inputs \
+            {set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} "  # noqa: C401
+            + "get aliased. Please ensure that this doesn't happen."
+        )
+    if len(input_mutations):
+        # TODO: Investigate here further which node is exactly mutating the inputs
+        raise RuntimeError(
+            f"{name} where the inputs are mutated. "
+            + f"In particular, these nodes are mutating the inputs \
+            {set(el for el in input_mutations)}."  # noqa: C401
+            + "Please ensure that this doesn't happen."
+        )
 
 
-def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
-    """
-    Dispatch-trace the branch with inputs and check if
-    producing graph has output aliasing the branch input. This is
-    bit restrictive as the branch must be traceable.
-    """
-    try:
-        gm = _maybe_fake_tracing(branch, inputs, pre_dispatch)
-    except UnsupportedAliasMutationException:
-        # this can happen when nested cond_op is
-        # functionalized
-        return True
-    except Exception as e:
-        raise e
-
-    return _detect_input_alias(gm)
+def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
+    (
+        _,
+        _,
+        _,
+    ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
+
+    return len(inp_mutation) > 0
+
+
+def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
+    (
+        inp_inp_alias_map,
+        inp_out_alias_map,
+        out_out_alias_map,
+    ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
+    return (
+        any(
+            (
+                len(inp_inp_alias_map) > 0,
+                len(inp_out_alias_map) > 0,
+                len(out_out_alias_map) > 0,
+            )
+        ),
+        len(inp_mutation) > 0,
+    )
+
+
+def _collect_fake_inputs(inputs):
+    from torch._subclasses.fake_tensor import FakeTensor
+
+    # Get the example values of the inputs.
+    inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
+    for inp in inputs:
+        if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
+            inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
+            if hasattr(inp, "meta"):
+                val = inp.meta["example_value"]
+                if isinstance(val, torch.Tensor):
+                    if torch._C._functorch.is_batchedtensor(
+                        val
+                    ) or torch._C._functorch.is_functionaltensor(val):
+                        # This case is for batched or functional tensors
+                        # Unwrap the tensors
+                        while torch._C._functorch.is_batchedtensor(
+                            val
+                        ) or torch._C._functorch.is_functionaltensor(val):
+                            val = torch._C._functorch.get_unwrapped(val)
+                        assert isinstance(val, FakeTensor)
+                        inputs_fake.append(val)
+                    else:
+                        # This is the standard case of a TensorVariable
+                        assert isinstance(val, FakeTensor)
+                        inputs_fake.append(val)
+                else:
+                    # This case is for SymInts and other non-Tensor elements
+                    assert not isinstance(val, torch.Tensor)
+                    inputs_fake.append(val)
+        else:
+            # This case is for ints
+            assert isinstance(inp, int)
+            inputs_fake.append(inp)
+
+    return inputs_fake
+
+
+def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
+    aliases, inp_mutation = has_potential_input_alias_or_mutation(
+        graph_module, inputs_fake, pre_dispatch=pre_dispatch
+    )
+    if aliases:
+        raise RuntimeError(
+            f"{name} might be aliasing the input or the output!"
+        )  # noqa: F541
+    if inp_mutation:
+        raise RuntimeError(f"{name} might be modifying the input!")  # noqa: F541
+
+
 def unique_graph_id(proxy_mode, prefix):
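The _check_alias_and_mutation helper added above is what the individual HOPs now call. A rough sketch of exercising it in isolation follows; the toy aliasing_fn and the fake inputs are assumptions for the example and are not part of this PR.

import torch
from torch._higher_order_ops.utils import _check_alias_and_mutation
from torch._subclasses.fake_tensor import FakeTensorMode

def aliasing_fn(x):
    # The output is a view of the input, i.e. an input/output alias.
    return x.view(-1)

with FakeTensorMode():
    fake_inputs = [torch.empty(2, 3)]

try:
    _check_alias_and_mutation(aliasing_fn, fake_inputs, "my_hop", pre_dispatch=False)
except RuntimeError as e:
    # Expected to report: "my_hop might be aliasing the input or the output!"
    print(e)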
@@ -699,27 +741,29 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
     ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
 
 
+# TODO: Return a more detailed information as to which node
+# causes a mutation or an alias. This may requires a per operator tensor version checking
 def check_input_alias_and_mutation(
     gm: torch.fx.GraphModule,
     fake_args: list[FakeTensor],
-) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]:
+) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
     (
-        mutated_inputs,
         inp_inp_alias_map,
         inp_out_alias_map,
         out_out_alias_map,
+        mutated_inputs,
     ) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1]
-    return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map
+    return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
 
 
 def check_input_alias_and_mutation_return_ouputs(
     gm: torch.fx.GraphModule,
     fake_args: list[FakeTensor],
 ) -> tuple[
+    dict[int, int],
+    dict[int, int],
+    dict[int, int],
     list[int],
-    dict[int, int],
-    dict[int, int],
-    dict[int, int],
     Union[tuple[Any, ...], list[Any]],
 ]:
     # We want to disable active functional, proxy and fake modes if any.
@@ -825,10 +869,10 @@ def check_input_alias_and_mutation_return_ouputs(
         if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map
     }
     return (
-        mutated_inputs,
         inp_inp_alias_map,
         inp_out_alias_map,
         out_out_alias_map,
+        mutated_inputs,
         outputs,
     )
@@ -6,14 +6,11 @@ import torch
 import torch.utils._pytree as pytree
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
-    _has_potential_branch_input_alias,
-    _has_potential_branch_input_mutation,
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
     check_meta_consistency,
     reenter_make_fx,
-    UnsupportedAliasMutationException,
     validate_subgraph_args_types,
 )
 from torch._ops import HigherOrderOperator
@@ -400,6 +397,8 @@ def while_loop_fake_tensor_mode(
 
 @while_loop_op.py_functionalize_impl
 def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
+    from torch._higher_order_ops.utils import _check_alias_and_mutation
+
     unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
     unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
@@ -411,19 +410,7 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
             (cond_fn, "cond_fn"),
             (body_fn, "body_fn"),
         ]:
-            if _has_potential_branch_input_mutation(
-                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    f"torch.while_loop's {fn_name} might be modifying the input!"
-                )
-
-            if _has_potential_branch_input_alias(
-                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
-            ):
-                raise UnsupportedAliasMutationException(
-                    f"torch.while_loop's {fn_name} might be aliasing the input!"
-                )
+            _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
         ret = while_loop_op(
             functional_cond_fn,
             functional_body_fn,
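With the rework, torch.while_loop runs one _check_alias_and_mutation call per sub-function instead of two separate mutation and alias checks. Below is a hedged sketch of a loop whose cond_fn and body_fn pass the check because they return fresh tensors; it is an illustrative example, not code from this PR.

import torch
from torch._higher_order_ops.while_loop import while_loop

def cond_fn(i, x):
    return i < 5

def body_fn(i, x):
    # Return fresh carries: mutating i or x in place, or returning views of
    # them, would be rejected by the check above.
    return i + 1, x * 2.0

i0 = torch.tensor(0)
x0 = torch.ones(3)
i_final, x_final = while_loop(cond_fn, body_fn, (i0, x0))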
@@ -7898,7 +7898,7 @@ class WhileLoop(ExternKernel):
         # Handling input mutations
         mutated_idxs = check_input_alias_and_mutation(
             body_fn.graph.module, fake_all_inputs
-        )[0]
+        )[3]
         mutated_idx_set = OrderedSet(mutated_idxs)
         mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set]
         real_outputs = {
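The [0] to [3] switch above follows from the reordered return value of check_input_alias_and_mutation: the alias maps come first and the mutated input indices last. A rough sketch of the new unpacking order; the make_fx-traced toy function and the fake arguments are illustrative assumptions, not code from this PR.

import torch
from torch._higher_order_ops.utils import check_input_alias_and_mutation
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def body(x):
    return x.sin()

with FakeTensorMode():
    fake_args = [torch.empty(3)]

gm = make_fx(body)(*fake_args)

# New order: input/input, input/output and output/output alias maps first,
# then the mutated input indices (hence the index 3 in the hunk above).
inp_inp, inp_out, out_out, mutated_idxs = check_input_alias_and_mutation(gm, fake_args)
print(mutated_idxs)  # expected to be [] for this purely functional body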