Skip nonzero unbacked SymInt memo in inference mode (#122147)
Summary: In `torch.inference_mode()`, fake tensors don't have `_version`s. This breaks unbacked SymInt memoization in `torch.nonzero` tracing. Here we disable the latter in inference mode.

Fixes https://github.com/pytorch/pytorch/issues/122127

Test Plan:

```
$ python test/inductor/test_unbacked_symints.py -k test_nonzero_in_inference_mode
...
----------------------------------------------------------------------
Ran 2 tests in 14.060s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122147
Approved by: https://github.com/ezyang
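The constraint the summary relies on can be seen outside of tracing as well: tensors created under `torch.inference_mode()` are inference tensors and do not track a version counter, so reading `_version` raises. A minimal sketch of that behavior (not part of this PR; it uses a plain tensor rather than a fake tensor):

```python
# Sketch: inference tensors have no version counter, so the `arg._version`
# read used by the nonzero memo is unavailable under inference mode.
import torch

with torch.inference_mode():
    t = torch.randint(0, 2, (8,))
    print(t.is_inference())  # True
    try:
        _ = t._version       # inference tensors do not track version counters
    except RuntimeError as e:
        print("no _version in inference mode:", e)
```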
This commit is contained in:
parent
15a8185cd3
commit
2e02e1efad
@@ -142,6 +142,20 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipCUDAIf(not HAS_CUDA, "requires cuda")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_nonzero_in_inference_mode(self, device):
+        def fn(x):
+            return torch.nonzero(x)
+
+        example_inputs = (torch.randint(0, 2, (128,), device=device),)
+
+        with torch.inference_mode():
+            actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+
+        torch.testing.assert_close(actual, expected)
+
 
 instantiate_device_type_tests(
     TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu")
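For readers who want to reproduce the scenario outside the test harness, a hedged usage sketch follows; it mirrors the `@dynamo_config.patch` decorator above by setting the same config flag directly, and is not part of the diff:

```python
# Usage sketch: allow torch.compile to capture nonzero's data-dependent
# output shape, then run the compiled function under inference mode.
import torch
import torch._dynamo.config as dynamo_config

dynamo_config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def fn(x):
    return torch.nonzero(x)

with torch.inference_mode():
    out = fn(torch.randint(0, 2, (128,)))
```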
@@ -301,7 +301,9 @@ def nonzero(fake_mode, func, arg):
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if arg.nonzero_memo is None:
+    if arg.nonzero_memo is not None:
+        nnz = arg.nonzero_memo
+    else:
         nnz = fake_mode.shape_env.create_unbacked_symint()
 
         # This is unsound, but it works well in practice
@@ -330,10 +332,12 @@ def nonzero(fake_mode, func, arg):
 
         _constrain_range_for_size(nnz, max=maxval)
 
-        arg._nonzero_memo = nnz
-        arg._nonzero_memo_vc = arg._version
+        if not torch.is_inference_mode_enabled():
+            # arg._version N/A in inference mode
+            arg._nonzero_memo = nnz
+            arg._nonzero_memo_vc = arg._version
 
-    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
+    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
 
 
 @register_op_impl(torch.ops.aten.masked_select.default)
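For context, the memo that the new guard skips exists so that tracing `torch.nonzero` repeatedly on the same fake tensor reuses one unbacked SymInt rather than minting a fresh symbol per call. A hedged sketch of that behavior outside inference mode (not part of the PR; it assumes the default `ShapeEnv` settings permit dynamic output shape ops):

```python
# Sketch: with the memo active, two nonzero calls on the same fake tensor
# should report the same symbolic nnz in their shapes.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

x = torch.randint(0, 2, (128,))
with FakeTensorMode(shape_env=ShapeEnv()) as mode:
    fake_x = mode.from_tensor(x)
    a = torch.nonzero(fake_x)  # mints an unbacked SymInt and memoizes it
    b = torch.nonzero(fake_x)  # reuses the memoized symbol (version unchanged)
    print(a.shape, b.shape)    # e.g. torch.Size([u0, 1]) for both
```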