mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] support SymInt minlength for torch.bincount() (#152497)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152497 Approved by: https://github.com/angelayi
This commit is contained in:
parent
ad9e209ea3
commit
5521e6b671
|
|
@ -1195,7 +1195,7 @@
|
|||
CompositeExplicitAutograd: binary_cross_entropy_with_logits
|
||||
autogen: binary_cross_entropy_with_logits.out
|
||||
|
||||
- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
|
||||
- func: bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU: _bincount_cpu
|
||||
|
|
|
|||
|
|
@ -775,6 +775,21 @@ graph():
|
|||
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
|
||||
self.assertEqual(gm(*args), m(*args))
|
||||
|
||||
def test_unbacked_bincount(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, xs):
|
||||
u0, u1 = xs.tolist()
|
||||
x = torch.ones(u0, dtype=torch.int64)
|
||||
y = torch.bincount(x, minlength=u1)
|
||||
return y
|
||||
|
||||
m = Foo()
|
||||
x = torch.tensor([20, 10])
|
||||
ep = export(m, (x,))
|
||||
self.assertTrue(torch.allclose(ep.module()(x), m(x)))
|
||||
y = torch.tensor([5, 10])
|
||||
self.assertTrue(torch.allclose(ep.module()(y), m(y)))
|
||||
|
||||
@requires_gpu
|
||||
def test_export_custom_triton_kernel(self):
|
||||
@triton.jit
|
||||
|
|
|
|||
|
|
@ -844,7 +844,8 @@ def bincount(fake_mode, func, inputs, weights=None, minlength=0):
|
|||
|
||||
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
|
||||
|
||||
_constrain_range_for_size(new_size, min=minlength)
|
||||
_constrain_range_for_size(new_size)
|
||||
torch._check(new_size >= minlength)
|
||||
return inputs.new_empty(new_size)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user