mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Bail on checking internal overlap when dealing with unbacked symints (#145385)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385 Approved by: https://github.com/ezyang
This commit is contained in:
parent
e1407f5aeb
commit
6f07847efe
|
|
@ -12,12 +12,22 @@ MemOverlap has_internal_overlap(const TensorBase& tensor) {
|
|||
MemOverlap has_internal_overlap(TensorImpl* t) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);
|
||||
|
||||
auto sizes = t->sym_sizes();
|
||||
auto strides = t->sym_strides();
|
||||
|
||||
// When we have unbacked symint strides, is_non_overlapping_and_dense
|
||||
// often results in guard on data dependent errors. For now
|
||||
// let us bail early if there are unbacked symint strides.
|
||||
for (const auto i : c10::irange(strides.size())) {
|
||||
if (!strides[i].has_hint()) {
|
||||
return MemOverlap::TooHard;
|
||||
}
|
||||
}
|
||||
|
||||
if (t->is_non_overlapping_and_dense()) {
|
||||
return MemOverlap::No;
|
||||
}
|
||||
|
||||
auto strides = t->sym_strides();
|
||||
auto sizes = t->sym_sizes();
|
||||
for (const auto i : c10::irange(strides.size())) {
|
||||
// NB: The size oblivious test is written very carefully here. When
|
||||
// unbacked SymInts are involved, we should try to conservatively report
|
||||
|
|
|
|||
|
|
@ -7614,6 +7614,26 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
opt = torch.compile(fn, fullgraph=True)
|
||||
opt(*inputs)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
@torch._dynamo.config.patch(assume_static_by_default=True)
|
||||
def test_symint_copy_into_unbacked_slice(self):
|
||||
@torch.compile()
|
||||
def fn(a, x):
|
||||
u0 = torch.tensor(x[0].to(torch.int64).item()).item()
|
||||
B, H, T, D = a.shape
|
||||
a_padding = torch.zeros((B, H, u0, D), dtype=torch.float64)
|
||||
b = torch.cat([a, a_padding], dim=2)
|
||||
c = torch.randn(B, H, 152, D)
|
||||
b[:, :, :152, :] = c
|
||||
return b
|
||||
|
||||
x = torch.tensor([0])
|
||||
torch._dynamo.decorators.mark_unbacked(x, 0)
|
||||
a = torch.zeros((1, 16, 152, 96))
|
||||
|
||||
# Previously would crash with guard on data dependent error
|
||||
fn(a, x)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_symint_fold_nontrivial_product_modulo(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
|
|
|
|||
|
|
@ -1066,7 +1066,7 @@ def forward(self, x_1):
|
|||
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
|
||||
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
|
||||
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
|
||||
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
|
||||
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
|
||||
Max = torch.sym_max
|
||||
self.assertEqual(
|
||||
cf(
|
||||
|
|
@ -1076,7 +1076,7 @@ def forward(self, x_1):
|
|||
device="meta",
|
||||
)
|
||||
),
|
||||
0,
|
||||
2,
|
||||
)
|
||||
|
||||
# Wobbling these to zero is OK too
|
||||
|
|
|
|||
|
|
@ -1009,7 +1009,7 @@ def squeeze(x, dim=None):
|
|||
for d, s in enumerate(x.get_size()):
|
||||
if not (
|
||||
d in dims
|
||||
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True))
|
||||
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
|
||||
):
|
||||
new_shape.append(s)
|
||||
|
||||
|
|
|
|||
|
|
@ -455,9 +455,15 @@ class SizeVarAllocator:
|
|||
# as this will ensure that you actually have a sympy'ified expression,
|
||||
# and will prevent you from incorrectly writing evaluate_expr(a == b)
|
||||
# which does the wrong thing if a or b is a sympy expression
|
||||
def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
|
||||
def evaluate_expr(
|
||||
self,
|
||||
left: Union[Expr, sympy.logic.boolalg.Boolean],
|
||||
size_oblivious: bool = False,
|
||||
) -> bool:
|
||||
assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
|
||||
return self.shape_env.evaluate_expr(sympy.sympify(left))
|
||||
return self.shape_env.evaluate_expr(
|
||||
sympy.sympify(left), size_oblivious=size_oblivious
|
||||
)
|
||||
|
||||
def evaluate_min(self, left: Expr, right: Expr) -> Expr:
|
||||
"""return the smaller of left and right, and guard on that choice"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user