Update unbacked symints in torch.nonzero more precisely (#137663)

### Summary
The fake impl for `nonzero` sets the unbacked SymInt's upper bound to `sys.maxsize - 1` whenever the input tensor's shape contains any SymInts. This PR constrains the range more precisely by bounding the product of the input shape's dimensions, using the recorded upper bound of each SymInt.
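For intuition, here is a minimal sketch (not the PR's code; the shape bounds below are hypothetical) of why the tighter bound is valid: `nonzero` can return at most `numel()` rows, so the unbacked SymInt for the row count is bounded by the product of the per-dimension upper bounds rather than `sys.maxsize - 1`:

```python
import math
import sys

# Hypothetical upper bounds for a tensor of shape (s0, s1, 5), where
# s0 <= 100 and s1 <= 7 are dynamic dims and the last dim is static.
dim_upper_bounds = [100, 7, 5]

# nonzero() returns at most numel() rows, so the unbacked SymInt that
# represents the row count is bounded by the product of the per-dim bounds.
maxval = math.prod(dim_upper_bounds)
print(maxval)  # 3500, far tighter than sys.maxsize - 1
assert maxval < sys.maxsize - 1
```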

See https://github.com/pytorch/pytorch/pull/134899 for a merged solution to a similar problem in a different op.

### Test plan
Added a unit test to verify the upper-bound reduction calculation (`python test/export/test_export.py TestExport.test_nonzero_dynamic`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137663
Approved by: https://github.com/ezyang
commit dd688099af (parent 8fa0479dd8)
Author: Jack Zhang
Date: 2024-10-28 20:57:20 +00:00
Committed by: PyTorch MergeBot

2 changed files with 71 additions and 0 deletions

`test/export/test_export.py`:

@@ -619,6 +619,61 @@ graph():
         for vr_upper in vr_upper_bounds:
             self.assertTrue(vr_upper <= expected_upper_bound)
 
+    def test_nonzero_dynamic(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, as_tuple: bool) -> torch.Tensor:
+                return torch.nonzero(x, as_tuple=as_tuple)
+
+        # Case 1 and 2: as_tuple is True and as_tuple is False.
+        for as_tuple in [True, False]:
+            example_args = (torch.randn(3, 4, 5), as_tuple)
+            dim0_x_max, dim1_x_max = 100, 7
+            dynamic_shapes = {
+                "x": {
+                    0: Dim("dim0_x", max=dim0_x_max),
+                    1: Dim("dim1_x_max", max=dim1_x_max),
+                },
+                "as_tuple": None,
+            }
+            m = M()
+            exported_program: torch.export.ExportedProgram = export(
+                m, args=example_args, dynamic_shapes=dynamic_shapes
+            )
+
+            # Test that the expected upper bound is among the range constraints.
+            expected_upper_bound = dim0_x_max * dim1_x_max * 5
+            vr_upper_bounds = [
+                vr.upper for vr in exported_program.range_constraints.values()
+            ]
+            self.assertTrue(expected_upper_bound in set(vr_upper_bounds))
+            # Test that none of the upper bounds are larger.
+            for vr_upper in vr_upper_bounds:
+                self.assertTrue(vr_upper <= expected_upper_bound)
+
+        # Case 3: Test special case when input has zero dimensions and a nonzero
+        # scalar value.
+        example_args = (torch.tensor(10), as_tuple)
+        dim0_x_max = 100
+        dynamic_shapes = {
+            "x": None,
+            "as_tuple": None,
+        }
+        m = M()
+        exported_program: torch.export.ExportedProgram = export(
+            m, args=example_args, dynamic_shapes=dynamic_shapes
+        )
+
+        # Test that the expected upper bound is equal to 1, since our output
+        # for this edge case should always be a tensor of size 1.
+        vr_upper_bounds = [
+            vr.upper for vr in exported_program.range_constraints.values()
+        ]
+        for vr_upper in vr_upper_bounds:
+            self.assertEqual(vr_upper, 1)
+
     def test_setgrad_lifted_tensor(self):
         class M(torch.nn.Module):
             def forward(self, x, y):
The `nonzero` fake impl:

@@ -396,6 +396,11 @@ def local_scalar_dense(fake_mode, func, arg):
     return r
 
 
+@register_op_impl(torch.ops.aten.nonzero_numpy.default)
+def nonzero_numpy(fake_mode, func, arg):
+    return torch.ops.aten.nonzero.default(arg).unbind(1)
+
+
 @register_op_impl(torch.ops.aten.nonzero.default)
 def nonzero(fake_mode, func, arg):
     if (
@@ -411,6 +416,8 @@ def nonzero(fake_mode, func, arg):
         _constrain_range_for_size,
         has_free_symbols,
     )
+    from torch.utils._sympy.numbers import IntInfinity
+    from torch.utils._sympy.value_ranges import bound_sympy
 
     if not has_free_symbols(arg.numel()) and arg.numel() == 0:
         # If numel is zero, then the output size must be zero.
@@ -429,6 +436,15 @@ def nonzero(fake_mode, func, arg):
         if not has_free_symbols(arg.numel()):
             maxval = int(arg.numel())
+        else:
+            prod_node = math.prod(arg.shape).node
+            prod_range = bound_sympy(
+                prod_node.expr, prod_node.shape_env.var_to_range
+            )
+            if isinstance(prod_range.upper, IntInfinity):
+                maxval = sys.maxsize - 1
+            else:
+                maxval = prod_range.upper
 
         _constrain_range_for_size(nnz, max=maxval)
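For context, a standalone sketch (not the PR's code path; the symbols and ranges below are hypothetical) of how `bound_sympy` derives the product's upper bound from per-symbol ranges:

```python
import sympy

from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

# Hypothetical shape symbols with ranges like those export records, e.g.
# s0 in [2, 100] and s1 in [2, 7] for Dim(..., max=100) and Dim(..., max=7).
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
ranges = {s0: ValueRanges(2, 100), s1: ValueRanges(2, 7)}

# Bound the product s0 * s1 * 5, i.e. numel() for a tensor of shape (s0, s1, 5).
prod_range = bound_sympy(s0 * s1 * 5, ranges)
print(prod_range.upper)  # 3500
```

When a symbol's range is unbounded above, the product's upper bound comes back infinite, which is why the fake impl keeps the `sys.maxsize - 1` fallback for that case.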