[dynamic shapes] make backed_size_oblivious behavior consistent between symbolic_shapes/inductor (#164796)

Summary: check the backed_size_oblivious config and take the guard_or_ static-evaluation path directly in inductor's guard_or_false/guard_or_true, so inductor's guard_or calls behave the same way as the symbolic_shapes ones.
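As a rough usage sketch (mirroring the unit test added below), the behavior being made consistent is exercised by enabling the config and compiling a slicing function through inductor:

import torch

# backed_size_oblivious is a torch.fx.experimental._config flag; with it
# patched on, inductor's guard_or_false/guard_or_true now take the same
# static-evaluation path as symbolic_shapes.
@torch.compile(backend="inductor", fullgraph=True, dynamic=True)
def f(x):
    return x[:5]  # slicing emits guard_or-style bound checks

with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    f(torch.randn(10, 10))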

Test Plan: CI and unit test added.

Differential Revision: D84009392

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164796
Approved by: https://github.com/laithsakka
Author: Pian Pawakapan, 2025-10-08 00:19:03 +00:00; committed by PyTorch MergeBot
parent e98c4e835b
commit bd3b98a8a5
2 changed files with 23 additions and 0 deletions


@@ -1426,6 +1426,15 @@ class f(torch.nn.Module):
            f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
        )

    @fresh_cache()
    def test_slice_backed_size_oblivious(self):
        @torch.compile(backend="inductor", fullgraph=True, dynamic=True)
        def f(x):
            return x[:5]

        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
            f(torch.randn(10, 10))

    def test_baddbmm_symint(self):
        from torch._subclasses.fake_tensor import FakeTensorMode

@@ -453,9 +453,23 @@ class SizeVarAllocator:
    # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py
    # but operates on sympy expressions instead of symnodes. see Note [guard_or_].
    def guard_or_false(self, left):
        import torch.fx.experimental._config as exp_config

        if exp_config.backed_size_oblivious:
            static_val = self.shape_env._maybe_evaluate_static(left)
            if static_val is not None:
                return static_val
            return False
        return self.evaluate_expr(left, fallback_value=False)

    def guard_or_true(self, left):
        import torch.fx.experimental._config as exp_config

        if exp_config.backed_size_oblivious:
            static_val = self.shape_env._maybe_evaluate_static(left)
            if static_val is not None:
                return static_val
            return True
        return self.evaluate_expr(left, fallback_value=True)

    # The evaluate functions evaluate some symbolic sympy expression
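For intuition, here is a small standalone sketch of the guard_or_false semantics above (toy_guard_or_false, hint_value, and the use of plain sympy are illustrative stand-ins, not PyTorch APIs): when backed_size_oblivious is on, only statically decidable expressions keep their value and everything else falls back to the default, whereas the default path may still decide via the backed hint.

import sympy

def toy_guard_or_false(expr, backed_size_oblivious, hint_value):
    # "Statically known" here means the expression simplifies to a concrete
    # boolean without consulting any runtime hint.
    simplified = sympy.simplify(expr)
    static_val = bool(simplified) if simplified in (sympy.true, sympy.false) else None

    if backed_size_oblivious:
        # Size-oblivious mode: ignore backed hints; undecidable -> fallback False.
        return static_val if static_val is not None else False

    # Default mode: stand-in for evaluate_expr(..., fallback_value=False),
    # which may decide using the backed hint.
    return static_val if static_val is not None else hint_value

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(toy_guard_or_false(sympy.Ge(s0, 0), True, hint_value=True))   # True: statically decidable
print(toy_guard_or_false(sympy.Ge(s0, 5), True, hint_value=True))   # False: oblivious fallback
print(toy_guard_or_false(sympy.Ge(s0, 5), False, hint_value=True))  # True: backed hint wins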