Skip nonzero unbacked SymInt memo in inference mode (#122147)

Summary: In `torch.inference_mode()`, fake tensors don't track `_version` (the tensor version counter). This breaks unbacked SymInt memoization in `torch.nonzero` tracing, so we disable the memoization in inference mode.

Fixes https://github.com/pytorch/pytorch/issues/122127
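
For context, a minimal standalone sketch (not code from this PR) of the underlying constraint: tensors created under `torch.inference_mode()` are inference tensors that are expected not to track a version counter, so a `_version`-keyed memo has nothing to compare against.

```python
import torch

x = torch.randint(0, 2, (8,))
print(x._version)  # regular tensors track a mutation version counter (starts at 0)

with torch.inference_mode():
    y = torch.randint(0, 2, (8,))
    print(torch.is_inference_mode_enabled())  # True -- the check used by this fix
    try:
        # Inference tensors are expected to have no version counter, so
        # accessing it should raise (hence "arg._version N/A in inference mode").
        print(y._version)
    except RuntimeError as err:
        print(f"no version counter: {err}")
```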

Test Plan:

```
$ python test/inductor/test_unbacked_symints.py -k test_nonzero_in_inference_mode
...
----------------------------------------------------------------------
Ran 2 tests in 14.060s

OK
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122147
Approved by: https://github.com/ezyang
Author: Adnan Akhundov (2024-03-19 22:53:05 -07:00), committed by PyTorch MergeBot
Commit: 2e02e1efad (parent: 15a8185cd3)
2 changed files with 22 additions and 4 deletions

test/inductor/test_unbacked_symints.py

```diff
@@ -142,6 +142,20 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipCUDAIf(not HAS_CUDA, "requires cuda")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_nonzero_in_inference_mode(self, device):
+        def fn(x):
+            return torch.nonzero(x)
+
+        example_inputs = (torch.randint(0, 2, (128,), device=device),)
+
+        with torch.inference_mode():
+            actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+            expected = fn(*example_inputs)
+
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(
     TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu")
```

```diff
@@ -301,7 +301,9 @@ def nonzero(fake_mode, func, arg):
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.nonzero_memo is None:
+    if arg.nonzero_memo is not None:
+        nnz = arg.nonzero_memo
+    else:
         nnz = fake_mode.shape_env.create_unbacked_symint()
 
         # This is unsound, but it works well in practice
@@ -330,10 +332,12 @@ def nonzero(fake_mode, func, arg):
 
         _constrain_range_for_size(nnz, max=maxval)
 
-        arg._nonzero_memo = nnz
-        arg._nonzero_memo_vc = arg._version
+        if not torch.is_inference_mode_enabled():
+            # arg._version N/A in inference mode
+            arg._nonzero_memo = nnz
+            arg._nonzero_memo_vc = arg._version
 
-    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
+    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
 
 
 @register_op_impl(torch.ops.aten.masked_select.default)
```
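
As a design note, here is a standalone sketch of the memoization pattern this guard protects (the `NnzMemo` class and its fields are hypothetical illustrations, not the actual FakeTensor internals): the cached nnz value is only valid while the tensor's `_version` is unchanged, and when no version counter is tracked (inference mode) it is safer to skip caching and compute a fresh unbacked symbol each time.

```python
import torch

class NnzMemo:
    """Hypothetical sketch of version-counter-keyed memoization for nonzero's nnz."""

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor
        self._memo = None      # cached nnz value
        self._memo_vc = None   # tensor._version at caching time

    def get(self):
        # A memo is only trustworthy if the tensor hasn't been mutated since
        # it was recorded; otherwise return None so the caller recomputes.
        if self._memo is not None and self._memo_vc == self._tensor._version:
            return self._memo
        return None

    def put(self, nnz):
        # Mirrors the guard added in this PR: without a version counter
        # (inference mode) the memo could go stale undetected, so skip caching.
        if not torch.is_inference_mode_enabled():
            self._memo = nnz
            self._memo_vc = self._tensor._version
```

With this shape, repeated traces over an unmutated tensor reuse the cached value, while inference mode simply pays the small cost of recomputing per call instead of risking a stale memo.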