[aoti] Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering (#150570)

Summary: Fix  cannot determine truth value of Relation error when propagating unbacked symint in lowering

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts
```

Differential Revision: D72331070

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150570
Approved by: https://github.com/angelayi, https://github.com/henryoier
This commit is contained in:
Shangdi Yu 2025-04-03 20:06:12 +00:00 committed by PyTorch MergeBot
parent c1d503529d
commit 51da241c0a
2 changed files with 51 additions and 2 deletions

View File

@ -3369,6 +3369,56 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(), example_inputs)
def test_aoti_runtime_asserts(self):
from torch._dispatch.python import enable_python_dispatcher
from torch.export._draft_export import draft_export
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define(
"mylib::foo",
"(Tensor a, Tensor b) -> Tensor",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl("mylib::foo", "cpu", lib=lib)
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a[: b.item()]
@torch.library.impl_abstract("mylib::foo", lib=lib)
def foo_fake_impl(a, b):
ctx = torch.library.get_ctx()
u = ctx.new_dynamic_size()
return torch.empty(u)
class M(torch.nn.Module):
def forward(self, a, b):
res = torch.ops.mylib.foo(a, b)
s = res.shape[0]
torch._check(s > 3)
torch._check(s < a.shape[0])
return a[s - 3]
example_inputs = (torch.randn(100), torch.tensor(10))
ep = draft_export(M(), example_inputs)
m = ep.module()
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
example_inputs = [
node.meta["val"] for node in m.graph.nodes if node.op == "placeholder"
]
fake_mode = example_inputs[0].fake_mode
with enable_python_dispatcher(), fake_mode:
FakeTensorProp(m, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)
# TODO: change to the tests below after MetadataMismatchError is fixed
# pt2_file = torch._inductor.aoti_compile_and_package(ep)
# optimized = torch._inductor.aoti_load_package(pt2_file)
# self.assertTrue(same(optimized(example_inputs), m(example_inputs)))
def test_index_put_with_none_index(self):
# index_put falls back in the deterministic mode
with DeterministicGuard(True):

View File

@ -2308,10 +2308,9 @@ class FakeTensorMode(TorchDispatchMode):
if (
self.propagate_real_tensors
and all(e.real_tensor is not None for e in flat_arg_fake_tensors)
# TODO: Handle SymFloat/SymBool
and not any(
(
isinstance(a, SymInt)
isinstance(a, py_sym_types)
and (syms := free_unbacked_symbols(a))
and self.shape_env is not None
and any(s not in self.shape_env.unbacked_var_to_val for s in syms)