[export] support SymInt minlength for torch.bincount() (#152497)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152497
Approved by: https://github.com/angelayi
This commit is contained in:
Pian Pawakapan 2025-05-01 00:45:53 +00:00 committed by PyTorch MergeBot
parent ad9e209ea3
commit 5521e6b671
3 changed files with 18 additions and 2 deletions

View File

@ -1195,7 +1195,7 @@
CompositeExplicitAutograd: binary_cross_entropy_with_logits
autogen: binary_cross_entropy_with_logits.out
- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
- func: bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor
variants: function, method
dispatch:
CPU: _bincount_cpu

View File

@ -775,6 +775,21 @@ graph():
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(gm(*args), m(*args))
def test_unbacked_bincount(self):
class Foo(torch.nn.Module):
def forward(self, xs):
u0, u1 = xs.tolist()
x = torch.ones(u0, dtype=torch.int64)
y = torch.bincount(x, minlength=u1)
return y
m = Foo()
x = torch.tensor([20, 10])
ep = export(m, (x,))
self.assertTrue(torch.allclose(ep.module()(x), m(x)))
y = torch.tensor([5, 10])
self.assertTrue(torch.allclose(ep.module()(y), m(y)))
@requires_gpu
def test_export_custom_triton_kernel(self):
@triton.jit

View File

@ -844,7 +844,8 @@ def bincount(fake_mode, func, inputs, weights=None, minlength=0):
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
_constrain_range_for_size(new_size, min=minlength)
_constrain_range_for_size(new_size)
torch._check(new_size >= minlength)
return inputs.new_empty(new_size)