[dynamo] Restrict support for out= variants of torch operators (#140202)
There has been a series of attempts to support resizing in out= variants of torch operators like `torch.sigmoid(x, out=y)`, i.e., cases where `y` would have a different shape before and after the expression. Prior to this patch, we had some checks to graph break if the shape changed. This patch
1. extends the existing check to graph break on any shape change, not just for `TensorVariable`s with a source field, and
2. removes an old code path which was introduced to address the shape change, but became obsolete in that regard once the extra graph-break-on-shape-change checks were added.

Moreover, this old code path is unsound: it tries to replace references to the old `TensorVariable` with the new one returned by `wrap_fx_proxy`, but it only does the replacement in `symbolic_locals`, which breaks when cells are involved. In general the old `TensorVariable` could be _anywhere_ (think of the `replace_all` we had for immutable VTs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140202
Approved by: https://github.com/jansel
ghstack dependencies: #140035, #140036, #140149, #140150, #140151, #140201
This commit is contained in:
parent 65615915ed
commit 39d1c91c33
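For context, the resizing behavior at issue can be reproduced in eager mode; a minimal sketch mirroring the new tests below:

import torch

x = torch.zeros(3, 5)
out = torch.empty(0)       # starts out with shape (0,)
torch.sigmoid(x, out=out)  # eager resizes `out` in place to match the result
print(out.size())          # torch.Size([3, 5])

It is this in-place resize of `out` that Dynamo cannot soundly propagate to every reference of the old `TensorVariable`, hence the graph break.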
@@ -2225,6 +2225,75 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         res = opt_m(x)
         self.assertTrue(same(ref, res))
 
+    def test_out_root_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out = torch.empty(0)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_nested_cell_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out = torch.empty(0)
+
+                def capture():
+                    return out  # Force `out` to be a nested cell
+
+                torch.sigmoid(x, out=out)
+                return out.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual((3, 5), res)
+
+    def test_out_root_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            out1 = torch.empty(0)
+            out2 = torch.empty(0, dtype=torch.long)
+
+            def run():
+                x = torch.zeros(3, 5)
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
+    def test_out_nested_cell_tuple_shape_change(self):
+        @torch.compile(backend="eager")
+        def fn():
+            def run():
+                x = torch.zeros(3, 5)
+                out1 = torch.empty(0)
+                out2 = torch.empty(0, dtype=torch.long)
+
+                def capture():
+                    # Force `out1` and `out2` to be nested cells
+                    return out1, out2
+
+                torch.sort(x, out=(out1, out2))
+                return out1.size(), out2.size()
+
+            return run()
+
+        res = fn()
+        self.assertEqual(((3, 5), (3, 5)), res)
+
     def test_slice_into_list_mutable(self):
         class Mod(torch.nn.Module):
             def forward(self, listy):
@@ -2642,12 +2642,6 @@ class InstructionTranslatorBase(
     def fake_mode(self):
         return self.output.tracing_context.fake_mode
 
-    def find_symbolic_locals_name(self, tensor_variable):
-        for key, value in self.symbolic_locals.items():
-            if value is tensor_variable:
-                return key
-        return None
-
     @contextlib.contextmanager
     def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
        """
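To see why rewriting references only in `symbolic_locals` breaks when cells are involved, consider a plain-Python sketch (the names below are illustrative, not Dynamo internals): an object captured by a closure lives in a cell, so updating a locals-style mapping leaves the cell's reference stale.

def outer():
    out = object()  # stands in for the old TensorVariable

    def inner():
        return out  # `out` is captured in a closure cell, not in the mapping

    symbolic_locals = {"out": out}
    symbolic_locals["out"] = object()  # rewrite the mapping only
    # The cell still points at the old object, so the rewrite is incomplete.
    assert inner() is not symbolic_locals["out"]

outer()

This is exactly the failure mode the nested-cell tests above exercise, and why the patch drops the rewrite in favor of a graph break.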
@@ -973,24 +973,24 @@ Either create the tensor outside the compiled region, or do not set the tensor t
             isinstance(kwargs["out"], variables.ConstantVariable)
             and kwargs["out"].as_python_constant() is None
         ):
-            # out variants of torch operators like torch.sort and
-            # torch.sigmoid mutate the tensors in the out field. Track such
-            # tensors and rewrite the symbolic locals.
+            # out variants of torch operators like torch.sort and torch.sigmoid
+            # mutate the tensors in the out field.
+            #
+            # However, it's non-trivial to update all references of the old
+            # `TensorVariable` to the new one returned (`result_var`), so we
+            # take the conservative approach to graph break on size changes, and
+            # assume other cases can fall through soundly.
+            #
+            # Note that although these tensor variables would hold different
+            # proxies, the in-place mutation semantics is preserved in the FX
+            # graph, so we won't have correctness issues.
             if isinstance(tensor_variable, TupleVariable):
                 assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
-                output_tensor_names = [
-                    tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
-                ]
-                for idx, name in enumerate(output_tensor_names):
-                    if name in tx.symbolic_locals:
-                        tx.symbolic_locals[name] = tensor_variable.items[idx]
                 for out_tensor, result_tensor in zip(
                     kwargs["out"].items, tensor_variable.items
                 ):
                     if (
-                        out_tensor.source
-                        and out_tensor in tx.output.graphargs
-                        and isinstance(out_tensor, variables.TensorVariable)
+                        isinstance(out_tensor, variables.TensorVariable)
                         and isinstance(result_tensor, variables.TensorVariable)
                         and out_tensor._size
                         != result_tensor._size  # we actually want to compare None values here
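With the check extended to any shape change, a resizing out= call now triggers a graph break instead of an unsound locals rewrite. As a hedged sketch, one way to observe the break is `torch._dynamo.explain` (field names may vary slightly across releases):

import torch

def f(x, out):
    torch.sigmoid(x, out=out)  # resizes `out` from (0,) to x's shape
    return out

explanation = torch._dynamo.explain(f)(torch.zeros(3, 5), torch.empty(0))
print(explanation.graph_break_count)  # expect at least one break from the resize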
@@ -1003,11 +1003,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                 assert "example_value" in kwargs["out"].proxy.node.meta
                 fake_tensor = tensor_variable.proxy.node.meta["example_value"]
                 fake_out = kwargs["out"].proxy.node.meta["example_value"]
-                if (
-                    kwargs["out"].source
-                    and kwargs["out"] in tx.output.graphargs
-                    and fake_out_shape != fake_tensor.shape
-                ):
+                if fake_out_shape != fake_tensor.shape:
                     # It's hard to get out variants with resizing on graph inputs work
                     # properly across dynamo/aot/inductor, just fall back.
                     unimplemented("out variants with resizing on graph inputs")
@@ -1017,9 +1013,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     unimplemented(
                         "out= op was called where output tensor was non-contiguous"
                     )
-            name = tx.find_symbolic_locals_name(kwargs["out"])
-            if name in tx.symbolic_locals:
-                tx.symbolic_locals[name] = tensor_variable
         elif (
             isinstance(tensor_variable, ConstantVariable)
             and tensor_variable.value is None
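The non-contiguous guard retained above likewise falls back to eager rather than tracing. An eager-mode illustration of the pattern it rejects (the compiled behavior is the graph break named in the message):

import torch

x = torch.zeros(3, 5)
out = torch.empty(5, 3).t()  # a non-contiguous view with the right shape
torch.sigmoid(x, out=out)    # fine in eager; Dynamo falls back on this pattern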