mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Don't attempt to compute hints for unbacked expressions (#132060)
This breaks the inference we made that if you cat an N-D tensor with a 1-D tensor of size (u0,), the u0 must be zero, but no one really wanted that anyway... Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132060 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
8fff976355
commit
fc32732596
|
|
@ -8755,7 +8755,9 @@ def ___make_guard_fn():
|
||||||
z = y.item()
|
z = y.item()
|
||||||
return torch.cat([x, torch.ones(z)])
|
return torch.cat([x, torch.ones(z)])
|
||||||
|
|
||||||
fn(torch.randn(2, 3), torch.tensor([0]))
|
self.assertRaises(
|
||||||
|
RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([0]))
|
||||||
|
)
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
|
RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5411,13 +5411,17 @@ class CommonTemplate:
|
||||||
z = y.item()
|
z = y.item()
|
||||||
return torch.cat([x, x.new_ones(z)])
|
return torch.cat([x, x.new_ones(z)])
|
||||||
|
|
||||||
self.common(
|
with self.assertRaisesRegex(
|
||||||
fn,
|
RuntimeError,
|
||||||
(
|
"Expected 2-D tensors, but got 1-D for tensor number 1 in the list",
|
||||||
torch.randn([2, 3]),
|
):
|
||||||
torch.tensor([0]),
|
self.common(
|
||||||
),
|
fn,
|
||||||
)
|
(
|
||||||
|
torch.randn([2, 3]),
|
||||||
|
torch.tensor([0]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||||
def test_cat_unbacked_empty_1d(self):
|
def test_cat_unbacked_empty_1d(self):
|
||||||
|
|
|
||||||
|
|
@ -2780,7 +2780,17 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
|
||||||
assert tensor.ndim == 1 # we've already checked this above
|
assert tensor.ndim == 1 # we've already checked this above
|
||||||
# Don't suggest the legacy behavior in the error message
|
# Don't suggest the legacy behavior in the error message
|
||||||
torch._check(
|
torch._check(
|
||||||
tensor.shape[0] == 0,
|
# NB: it is not enough to simply assert that tensor.shape[0] == 0;
|
||||||
|
# this MUST be true even under guard size oblivious.
|
||||||
|
# Effectively, we must actually know that the shape is zero,
|
||||||
|
# passing an unbacked SymInt which we will defer a runtime
|
||||||
|
# assert on won't cut it. This is a policy decision (size
|
||||||
|
# oblivious semantics say that u0 tensors never are inferred
|
||||||
|
# to be zero size, even if they must be that for the cat to go
|
||||||
|
# through), and is load bearing for our Inductor lowerings
|
||||||
|
# (which assume that size oblivious tests are OK to determine
|
||||||
|
# if a shape is permissibly zero.)
|
||||||
|
guard_size_oblivious(tensor.shape[0] == 0),
|
||||||
lambda: f"Number of dimensions of tensors must match. "
|
lambda: f"Number of dimensions of tensors must match. "
|
||||||
f"Expected {example.ndim}-D tensors, but got 1-D for "
|
f"Expected {example.ndim}-D tensors, but got 1-D for "
|
||||||
f"tensor number {tensor_idx} in the list",
|
f"tensor number {tensor_idx} in the list",
|
||||||
|
|
|
||||||
|
|
@ -110,9 +110,15 @@ class SymNode:
|
||||||
# in sync, so we've deleted it for now.)
|
# in sync, so we've deleted it for now.)
|
||||||
|
|
||||||
def compute_hint():
|
def compute_hint():
|
||||||
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||||
|
|
||||||
# This occasionally gets exercised by, e.g.,
|
# This occasionally gets exercised by, e.g.,
|
||||||
# convert_shape_to_symint. It's just a nicety so you don't HAVE
|
# convert_shape_to_symint. It's just a nicety so you don't HAVE
|
||||||
# to have a correct hint on hand when making a SymNode.
|
# to have a correct hint on hand when making a SymNode.
|
||||||
|
# Don't attempt to compute for unbacked, this can be quite
|
||||||
|
# expensive.
|
||||||
|
if free_unbacked_symbols(self.expr):
|
||||||
|
return None
|
||||||
hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
|
hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
|
||||||
if hint is not None:
|
if hint is not None:
|
||||||
hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
|
hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user