Update unbacked symints in torch.nonzero more precisely (#137663)
### Summary
The fake impl for `nonzero` sets the symint's upper range to `sys.maxsize - 1` if there are any SymInts in the original input tensor shape. This PR constrains the range more intelligently, using the upper ranges of each SymInt in the input tensor shape. See https://github.com/pytorch/pytorch/pull/134899 for a merged solution to a similar problem in a different op.

### Test plan
Added a unit test to verify the upper-bound reduction calculation (`python test/export/test_export.py TestExport.test_nonzero_dynamic`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137663
Approved by: https://github.com/ezyang
This commit is contained in:
parent 8fa0479dd8
commit dd688099af
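Before the diff itself, a minimal sketch of the idea behind the change (a hypothetical helper, not the PR's code): `nonzero` can return at most `numel()` rows, so when every input dimension has a known upper bound, the output's unbacked symint can be capped by the product of those bounds instead of the blanket `sys.maxsize - 1`.

```python
import sys

def nnz_upper_bound(dim_upper_bounds):
    """Cap on torch.nonzero's output length, given per-dim upper bounds.

    dim_upper_bounds holds an int bound per dimension, or None when a
    dimension's range is unbounded (the old conservative case).
    This helper is illustrative only; the PR computes the bound via
    bound_sympy over the shape env's var_to_range.
    """
    maxval = 1
    for ub in dim_upper_bounds:
        if ub is None:
            return sys.maxsize - 1  # fall back to the old blanket bound
        maxval *= ub
    return maxval

# Input of shape (dim0 <= 100, dim1 <= 7, 5), as in the new test below:
assert nnz_upper_bound([100, 7, 5]) == 3500
```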
test/export/test_export.py:

```diff
@@ -619,6 +619,61 @@ graph():
         for vr_upper in vr_upper_bounds:
             self.assertTrue(vr_upper <= expected_upper_bound)
 
+    def test_nonzero_dynamic(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, as_tuple: bool) -> torch.Tensor:
+                return torch.nonzero(x, as_tuple=as_tuple)
+
+        # Case 1 and 2: as_tuple is True and as_tuple is False.
+        for as_tuple in [True, False]:
+            example_args = (torch.randn(3, 4, 5), as_tuple)
+            dim0_x_max, dim1_x_max = 100, 7
+            dynamic_shapes = {
+                "x": {
+                    0: Dim("dim0_x", max=dim0_x_max),
+                    1: Dim("dim1_x_max", max=dim1_x_max),
+                },
+                "as_tuple": None,
+            }
+            m = M()
+            exported_program: torch.export.ExportedProgram = export(
+                m, args=example_args, dynamic_shapes=dynamic_shapes
+            )
+
+            # Test that the expected upper bound is among the range constraints.
+            expected_upper_bound = dim0_x_max * dim1_x_max * 5
+            vr_upper_bounds = [
+                vr.upper for vr in exported_program.range_constraints.values()
+            ]
+            self.assertTrue(expected_upper_bound in set(vr_upper_bounds))
+            # Test that none of the upper bounds are larger.
+            for vr_upper in vr_upper_bounds:
+                self.assertTrue(vr_upper <= expected_upper_bound)
+
+        # Case 3: Test special case when input has zero dimensions and a nonzero
+        # scalar value.
+        example_args = (torch.tensor(10), as_tuple)
+        dim0_x_max = 100
+        dynamic_shapes = {
+            "x": None,
+            "as_tuple": None,
+        }
+        m = M()
+        exported_program: torch.export.ExportedProgram = export(
+            m, args=example_args, dynamic_shapes=dynamic_shapes
+        )
+
+        # Test that the expected upper bound is equal to 1, since our output
+        # for this edge case should always be a tensor of size 1.
+        vr_upper_bounds = [
+            vr.upper for vr in exported_program.range_constraints.values()
+        ]
+        for vr_upper in vr_upper_bounds:
+            self.assertEqual(vr_upper, 1)
+
     def test_setgrad_lifted_tensor(self):
         class M(torch.nn.Module):
             def forward(self, x, y):
```
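For reference on the arithmetic in the test: in Cases 1 and 2 the expected bound 3500 = 100 * 7 * 5 is the product of the two dynamic dimensions' maxima and the static third dimension; in Case 3 a 0-dim input has `numel() == 1`, so at most one index can be returned and every upper bound collapses to 1.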
torch/_subclasses/fake_impls.py:

```diff
@@ -396,6 +396,11 @@ def local_scalar_dense(fake_mode, func, arg):
     return r
 
 
+@register_op_impl(torch.ops.aten.nonzero_numpy.default)
+def nonzero_numpy(fake_mode, func, arg):
+    return torch.ops.aten.nonzero.default(arg).unbind(1)
+
+
 @register_op_impl(torch.ops.aten.nonzero.default)
 def nonzero(fake_mode, func, arg):
     if (
```
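As a side note (not part of the diff), a quick sketch of why `unbind(1)` reproduces the tuple form of `nonzero`: the default overload returns an (nnz, ndim) index matrix, and unbinding along dim 1 yields one 1-D index tensor per input dimension.

```python
import torch

x = torch.tensor([[0, 1], [2, 0]])
idx = torch.nonzero(x)      # index matrix of shape (nnz, ndim) == (2, 2)
rows, cols = idx.unbind(1)  # tuple form, one index tensor per dimension
assert torch.equal(rows, torch.tensor([0, 1]))
assert torch.equal(cols, torch.tensor([1, 0]))
```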
```diff
@@ -411,6 +416,8 @@ def nonzero(fake_mode, func, arg):
         _constrain_range_for_size,
         has_free_symbols,
     )
+    from torch.utils._sympy.numbers import IntInfinity
+    from torch.utils._sympy.value_ranges import bound_sympy
 
     if not has_free_symbols(arg.numel()) and arg.numel() == 0:
         # If numel is zero, then the output size must be zero.
```
```diff
@@ -429,6 +436,15 @@ def nonzero(fake_mode, func, arg):
 
         if not has_free_symbols(arg.numel()):
             maxval = int(arg.numel())
+        else:
+            prod_node = math.prod(arg.shape).node
+            prod_range = bound_sympy(
+                prod_node.expr, prod_node.shape_env.var_to_range
+            )
+            if isinstance(prod_range.upper, IntInfinity):
+                maxval = sys.maxsize - 1
+            else:
+                maxval = prod_range.upper
 
         _constrain_range_for_size(nnz, max=maxval)
 
```
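For intuition, a small standalone illustration of `bound_sympy` on a product of ranged symbols; the hand-built ranges here are stand-ins for the shape env's `var_to_range`, and the symbol names are made up for the example.

```python
import sympy

from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

# Two symbolic dims with known ranges, mimicking Dim(..., max=100) and
# Dim(..., max=7) from the test above.
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
var_to_range = {s0: ValueRanges(2, 100), s1: ValueRanges(2, 7)}

# Bound the product expression s0 * s1 * 5 over those ranges.
prod_range = bound_sympy(s0 * s1 * 5, var_to_range)
print(prod_range.upper)  # 3500 -- the cap installed for the nnz symint
```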