Bail on checking internal overlap when dealing with unbacked symints (#145385)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385
Approved by: https://github.com/ezyang
bobrenjc93 2025-01-23 18:07:12 +00:00 committed by PyTorch MergeBot
parent e1407f5aeb
commit 6f07847efe
5 changed files with 43 additions and 7 deletions

View File

@@ -12,12 +12,22 @@ MemOverlap has_internal_overlap(const TensorBase& tensor) {
 MemOverlap has_internal_overlap(TensorImpl* t) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);
 
+  auto sizes = t->sym_sizes();
+  auto strides = t->sym_strides();
+
+  // When we have unbacked symint strides, is_non_overlapping_and_dense
+  // often results in guard on data dependent errors. For now
+  // let us bail early if there are unbacked symint strides.
+  for (const auto i : c10::irange(strides.size())) {
+    if (!strides[i].has_hint()) {
+      return MemOverlap::TooHard;
+    }
+  }
+
   if (t->is_non_overlapping_and_dense()) {
     return MemOverlap::No;
   }
 
-  auto strides = t->sym_strides();
-  auto sizes = t->sym_sizes();
   for (const auto i : c10::irange(strides.size())) {
     // NB: The size oblivious test is written very carefully here. When
     // unbacked SymInts are involved, we should try to conservatively report
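For context on the new bail-out: SymInt::has_hint() is false precisely for unbacked symbols, i.e. integers produced by data-dependent ops (Tensor.item(), nonzero(), mark_unbacked) that have no example value to guard on. A minimal Python-side sketch of the same predicate, using the has_hint and create_unbacked_symint helpers from torch.fx.experimental.symbolic_shapes on a bare ShapeEnv:

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, has_hint

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # data-dependent size: no example value
print(has_hint(u0))  # False -> has_internal_overlap now returns MemOverlap::TooHard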

View File

@@ -7614,6 +7614,26 @@ utils_device.CURRENT_DEVICE == None""".split(
         opt = torch.compile(fn, fullgraph=True)
         opt(*inputs)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(assume_static_by_default=True)
+    def test_symint_copy_into_unbacked_slice(self):
+        @torch.compile()
+        def fn(a, x):
+            u0 = torch.tensor(x[0].to(torch.int64).item()).item()
+            B, H, T, D = a.shape
+            a_padding = torch.zeros((B, H, u0, D), dtype=torch.float64)
+            b = torch.cat([a, a_padding], dim=2)
+            c = torch.randn(B, H, 152, D)
+            b[:, :, :152, :] = c
+            return b
+
+        x = torch.tensor([0])
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+        a = torch.zeros((1, 16, 152, 96))
+
+        # Previously would crash with guard on data dependent error
+        fn(a, x)
+
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_symint_fold_nontrivial_product_modulo(self):
         @torch.compile(fullgraph=True)
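The test above is the end-to-end repro: mark_unbacked turns x.size(0), and hence u0, into an unbacked SymInt, and after the cat the slice b[:, :, :152, :] carries u0 in its strides, so the in-place copy used to trip a guard-on-data-dependent error inside the internal-overlap check. A small eager sketch (with a concrete stand-in for u0) of why the static-looking slice still depends on it:

import torch

a = torch.zeros((1, 16, 152, 96))
pad = torch.zeros((1, 16, 5, 96))  # stand-in for the (B, H, u0, D) padding
b = torch.cat([a, pad], dim=2)     # shape (1, 16, 157, 96)

# Contiguous strides scale with dim 2, i.e. with 152 + u0:
print(b.stride())                  # (241152, 15072, 96, 1)
print(b[:, :, :152, :].stride())   # same strides, symbolically containing u0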

View File

@@ -1066,7 +1066,7 @@ def forward(self, x_1):
         self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
         self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
         self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
-        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
+        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
         Max = torch.sym_max
         self.assertEqual(
             cf(
@@ -1076,7 +1076,7 @@ def forward(self, x_1):
                     device="meta",
                 )
             ),
-            0,
+            2,
         )
 
         # Wobbling these to zero is OK too
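The expected integers are the MemOverlap enum values (0 = No, 1 = Yes, 2 = TooHard). A (1,)-shaped tensor with stride u0 can never actually self-overlap, but the unbacked stride has no hint, so the check now conservatively answers TooHard rather than guessing. A sketch, assuming cf here wraps the debug binding torch._debug_has_internal_overlap:

import torch

# A hinted (concrete) stride still resolves precisely; only unbacked
# strides take the new early-bail path.
t = torch.empty_strided((1,), (7,), device="meta")
print(torch._debug_has_internal_overlap(t))  # 0, i.e. MemOverlap::No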

View File

@@ -1009,7 +1009,7 @@ def squeeze(x, dim=None):
     for d, s in enumerate(x.get_size()):
         if not (
             d in dims
-            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True))
+            and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
         ):
             new_shape.append(s)
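The bug being fixed here: size_oblivious was passed as a keyword to sympy.Eq, whose relational constructor only acts on an evaluate option, so the flag was silently dropped and never reached the shape environment; squeeze could therefore still guard on a data-dependent size. A sketch of the semantics the flag is meant to trigger, at the ShapeEnv level, assuming torch._check_is_size marks u0 size-like (which size-oblivious reasoning treats as >= 2):

import sympy
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)  # size-like: size-oblivious mode assumes u0 >= 2

# Plain evaluation of Eq(u0, 1) would raise a data-dependent guard error;
# the size-oblivious path decides it statically as False.
print(shape_env.evaluate_expr(sympy.Eq(u0.node.expr, 1), size_oblivious=True))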

View File

@@ -455,9 +455,15 @@ class SizeVarAllocator:
     # as this will ensure that you actually have a sympy'ified expression,
     # and will prevent you from incorrectly writing evaluate_expr(a == b)
     # which does the wrong thing if a or b is a sympy expression
-    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
+    def evaluate_expr(
+        self,
+        left: Union[Expr, sympy.logic.boolalg.Boolean],
+        size_oblivious: bool = False,
+    ) -> bool:
         assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
-        return self.shape_env.evaluate_expr(sympy.sympify(left))
+        return self.shape_env.evaluate_expr(
+            sympy.sympify(left), size_oblivious=size_oblivious
+        )
 
     def evaluate_min(self, left: Expr, right: Expr) -> Expr:
         """return the smaller of left and right, and guard on that choice"""