mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aoti] Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering (#150570)
Summary: Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts ``` Differential Revision: D72331070 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150570 Approved by: https://github.com/angelayi, https://github.com/henryoier
This commit is contained in:
parent
c1d503529d
commit
51da241c0a
|
|
@ -3369,6 +3369,56 @@ class AOTInductorTestsTemplate:
|
|||
)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_aoti_runtime_asserts(self):
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch.export._draft_export import draft_export
|
||||
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor a, Tensor b) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a[: b.item()]
|
||||
|
||||
@torch.library.impl_abstract("mylib::foo", lib=lib)
|
||||
def foo_fake_impl(a, b):
|
||||
ctx = torch.library.get_ctx()
|
||||
u = ctx.new_dynamic_size()
|
||||
return torch.empty(u)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
res = torch.ops.mylib.foo(a, b)
|
||||
s = res.shape[0]
|
||||
torch._check(s > 3)
|
||||
torch._check(s < a.shape[0])
|
||||
return a[s - 3]
|
||||
|
||||
example_inputs = (torch.randn(100), torch.tensor(10))
|
||||
ep = draft_export(M(), example_inputs)
|
||||
m = ep.module()
|
||||
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
|
||||
example_inputs = [
|
||||
node.meta["val"] for node in m.graph.nodes if node.op == "placeholder"
|
||||
]
|
||||
fake_mode = example_inputs[0].fake_mode
|
||||
with enable_python_dispatcher(), fake_mode:
|
||||
FakeTensorProp(m, mode=fake_mode).propagate_dont_convert_inputs(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# TODO: change to the tests below after MetadataMismatchError is fixed
|
||||
# pt2_file = torch._inductor.aoti_compile_and_package(ep)
|
||||
# optimized = torch._inductor.aoti_load_package(pt2_file)
|
||||
|
||||
# self.assertTrue(same(optimized(example_inputs), m(example_inputs)))
|
||||
|
||||
def test_index_put_with_none_index(self):
|
||||
# index_put falls back in the deterministic mode
|
||||
with DeterministicGuard(True):
|
||||
|
|
|
|||
|
|
@ -2308,10 +2308,9 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
if (
|
||||
self.propagate_real_tensors
|
||||
and all(e.real_tensor is not None for e in flat_arg_fake_tensors)
|
||||
# TODO: Handle SymFloat/SymBool
|
||||
and not any(
|
||||
(
|
||||
isinstance(a, SymInt)
|
||||
isinstance(a, py_sym_types)
|
||||
and (syms := free_unbacked_symbols(a))
|
||||
and self.shape_env is not None
|
||||
and any(s not in self.shape_env.unbacked_var_to_val for s in syms)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user