[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)

For reshape/view: switches the fast paths (0-element tensors, skipping dimensions that already have the correct length) from guard_size_oblivious to guard_or_false, and modifies the loop that accumulates input elements to defer data-dependent divisibility guards, raising a UserError only if we run out of dimensions: a graph break for compile, a hard error for export.

For infer_size: assumes that if the user passes us an unbacked size, it is probably not -1.

Will think about the changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150127
Approved by: https://github.com/laithsakka
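As a concrete illustration of the new behavior, a minimal sketch mirroring the test_simple_unbacked_view test added in this diff:

import torch
from torch.export import export

class Foo(torch.nn.Module):
    def forward(self, x):
        u0 = x.item()          # unbacked size u0
        y = torch.empty(5, u0)
        return y.view(u0, 5)   # [5, u0] -> [u0, 5]

# Exports without any torch._check_is_size / torch._check hints; the
# data-dependent divisibility guard is deferred and never needs to fire.
ep = export(Foo(), (torch.tensor([9]),))
assert ep.module()(torch.tensor([8])).size(0) == 8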
parent b37fa20771
commit 54f736155b
@@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
         self.assertEqual(
             report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
         )
-        self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)")
+        self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")

     def test_dedup_data_dependent_failure(self):
         class M(torch.nn.Module):
@@ -480,6 +480,7 @@ class TestDraftExport(TestCase):
             return torch.nn.functional.linear(masked, weight, bias)

         x = torch.zeros(10)
+        x[0] += 1
         inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
         draft_ep = draft_export(M(), inp)
         ep = export(M(), inp)
@@ -4301,7 +4301,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
+                # Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
                 return r.view(items[0], items[2])

         M = M_v0
@@ -4310,9 +4310,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             "The following call raised this error(.*\n)+"
             f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
             "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
-            f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
+            f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
+            f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
+            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
+            f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
         ):
             export(N(), (t,), strict=strict)

@@ -4320,59 +4321,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             def forward(self, t):
                 items = [t[i].item() for i in range(t.numel())]
                 r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
+                # TODO(pianpwk): this isn't the suggested fixes.
+                # fix issue with % being interpreted as PythonMod instead of Mod
+                torch._check(items[1] == items[2])
                 return r.view(items[0], items[2])

         M = M_v1
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
-            f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v2(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                return r.view(items[0], items[2])
-
-        M = M_v2
-        with self.assertRaisesRegex(
-            error_type,
-            "The following call raised this error(.*\n)+"
-            f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
-            "To fix the error, insert one of the following checks before this call.*:\n"
-            f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
-            f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
-            f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
-        ):
-            export(N(), (t,), strict=strict)
-
-        class M_v3(torch.nn.Module):
-            def forward(self, t):
-                items = [t[i].item() for i in range(t.numel())]
-                r = torch.randn([items[0], items[1]])
-                # Could not guard on data-dependent expression Eq(u2, -1)
-                torch._check(items[2] != -1)
-                # Could not guard on data-dependent expression u2 >= 0
-                torch._check(items[2] >= 0)
-                # Could not guard on data-dependent expression Eq(u1, u2)
-                torch._check(items[2] == r.shape[1])
-                return r.view(items[0], items[2])
-
-        M = M_v3
         export(N(), (t,), strict=strict)

     def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@@ -4484,6 +4438,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             fixes=[], # nothing to fix!
         )

+    def test_simple_unbacked_view(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                u0 = x.item()
+                y = torch.empty(5, u0)
+                return y.view(u0, 5)  # [5, u0] -> [u0, 5]
+
+        ep = export(Foo(), (torch.tensor([9]),))
+        self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
+        self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
+
+        class Foov2(torch.nn.Module):
+            def forward(self, xs):
+                xsl = xs.tolist()
+                a, b = xsl
+                x = torch.zeros(a)
+                return x.reshape(b)
+
+        xs = torch.tensor([4, 4])
+        ep = export(Foov2(), (xs,))
+        self.assertEqual(ep.module()(xs).size(0), 4)
+        self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
+
     def test_no_suggested_fixes_for_data_dependent_errors(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
@@ -7549,22 +7526,19 @@ def forward(self, b_a_buffer, x):
             len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
         )

-    def test_check_is_size_error(self):
+    def test_no_check_is_size_error(self):
         class Module(torch.nn.Module):
             def forward(self, x):
                 a = x.item()
-                # We cannot automatically infer a is a size here because view
-                # accepts -1
                 return torch.randn(24).view(a, 4)

         f = Module()
-        if is_non_strict_test(self._testMethodName):
-            error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
-        else:
-            error = torch._dynamo.exc.UserError
-        error_msg = r"Could not guard on data-dependent expression"
-        with self.assertRaisesRegex(error, error_msg):
-            _ = export(f, (torch.tensor(6),))
+        ep = export(f, (torch.tensor(6),))
+        ep.module()(torch.tensor(6))
+        with self.assertRaisesRegex(
+            RuntimeError, r"Runtime assertion failed for .* u.* 6"
+        ):
+            ep.module()(torch.tensor(5))

     def test_is_non_negative_check_function(self):
         import sympy as sp
@@ -13487,7 +13461,7 @@ def forward(self, x):
             node.target == torch.ops.aten._assert_scalar.default
             for node in ep.graph.nodes
         ].count(True)
-        self.assertEqual(num_asserts, 1)
+        self.assertEqual(num_asserts, 2)
         with self.assertRaises(RuntimeError):
             ep.module()(torch.randn(4, 2))

@@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
+    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
+
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if d == -1:
+        if guard_or_false(d == -1):
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        elif d >= 0:
-            newsize *= d
         else:
-            torch._check(False, lambda: f"invalid shape dimension {d}")
+            torch._check(
+                d >= 0,
+                lambda: (
+                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
+                    "If this was meant to be inferred, please explicitly pass in -1."
+                ),
+            )
+            newsize *= d
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
-        from torch.fx.experimental.symbolic_shapes import definitely_true
-
         torch._check(
             newsize != 0,
             lambda: (
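A small sketch of the wildcard semantics this touches (plain eager calls; the f wrapper and values are illustrative, since u only becomes an unbacked symbol under torch.compile or export):

import torch

x = torch.arange(24)

# A literal -1 is the wildcard: infer_size computes its size as
# numel / prod(other dims).
assert x.reshape(4, -1).shape == (4, 6)

def f(t):
    u = t.item()  # an unbacked symbol u0 under torch.compile / torch.export
    # With this PR, infer_size assumes the unbacked `u` is not -1
    # (guard_or_false(d == -1)), so `u` is treated as a concrete dim;
    # pass a literal -1 if you want a dim inferred.
    return x.reshape(u, -1)

assert f(torch.tensor(4)).shape == (4, 6)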
@@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:


 def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
+    from torch._dynamo.exc import UserError, UserErrorType
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

     # Creates a valid shape
     shape = utils.extract_shape_from_varargs(shape, validate=False)
@@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
         shape = utils.infer_size(shape, a.numel())

     # Special-cases tensors with no elements
-    if guard_size_oblivious(a.numel() == 0):
+    if guard_or_false(a.numel() == 0):
         return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

     # Special-cases reshaping zero dim tensors
@@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
         return torch.as_strided(a, [dim0, dim1], [dim1, 1])

     # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
+    shape_numel = reduce(operator.mul, shape, 1)
+    torch._check(
+        a.numel() == shape_numel,
+        f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
+    )
+    deferred: list[Callable[[], bool]] = []

     # NOTE [Reshape Algorithm]
     # This algorithm works by attempting to greedily construct the desired dimensions in
@@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
             continue

         # Skips dimensions that are already the correct length
-        if guard_size_oblivious(length == a_.shape[idx]):
+        if guard_or_false(length == a_.shape[idx]):
             idx = idx + 1
             continue

         # Gathers enough original dimensions such that this new dimension can be created
         # Note that this accumulation will terminate because we've verified a and the shape
         # specify the same number of elements above
+        def maybe_throw_dde():
+            # NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
+            # tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
+            # divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
+            # figure out a valid reshape later in the loop.
+            # But we failed, either by running out of dimensions, or we couldn't figure out the strides,
+            # and we've decided to re-raise to either graph break out, or provide the exact guard so the user
+            # can torch._check() to avoid this.
+            for f in deferred:
+                f()
+
         accum = a_.shape[idx]
         end = idx
-        while guard_size_oblivious(accum % length != 0):
+        while guard_or_true(accum % length != 0):
+            deferred.append(lambda: bool(accum % length != 0))
+            if end == a_.ndim - 1:
+                maybe_throw_dde()
             end = end + 1
             accum = accum * a_.shape[end]
         if end != idx:
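For the accumulation loop above, a rough eager-mode trace of the NOTE [Reshape Algorithm] strategy (an illustration, not code from this PR):

import torch

a = torch.arange(24).reshape(2, 3, 4)

# Reshaping (2, 3, 4) -> (6, 4): for target length 6, accum starts at 2;
# 2 % 6 != 0, so the loop gathers the next input dim (accum = 2 * 3 = 6),
# then dims 0..1 are flattened, and no split is needed since accum == length.
# Target length 4 then matches the remaining input dim and is skipped.
flat = a.flatten(0, 1)
assert flat.shape == torch.Size([6, 4])
assert torch.equal(a.reshape(6, 4), flat)

# With unbacked sizes, each `accum % length != 0` test that cannot be decided
# is deferred; it is only re-raised (maybe_throw_dde) if the loop runs out of
# input dimensions or the flatten cannot be expressed as a view.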
@@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
                 if allow_copy:
                     return prims.reshape(a, shape)

+                maybe_throw_dde()
                 msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                 raise ValueError(msg)

             a_ = flatten(a_, idx, end)

-        # Splits the (possibly flattened) dimension to create the desired dim length
-        if guard_size_oblivious(accum != length):
+        # Splits the (possibly flattened) dimension to create the desired dim length.
+        # guard_or_true is safe due to the tail unsqueeze routine.
+        if guard_or_true(accum != length):
             a_ = prims.split_dim(a_, idx, length)

         idx = idx + 1
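If a deferred guard does get re-raised, the suggested fixes name the exact torch._check to insert; a sketch in the spirit of the M_v1 test above (module and inputs are illustrative; per the TODO in this diff, the equality form is used because the modulus form still hits a PythonMod/Mod issue):

import torch
from torch.export import export

class N(torch.nn.Module):
    def forward(self, t):
        items = [t[i].item() for i in range(t.numel())]
        r = torch.randn([items[0], items[1]])
        # Resolves the deferred guard Ne(Mod(u1, u2), 0): once u1 == u2 is
        # known, u1 % u2 == 0 and the view is provably valid.
        torch._check(items[1] == items[2])
        return r.view(items[0], items[2])

ep = export(N(), (torch.tensor([3, 4, 4]),))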