mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Fix the decompositions of torch isin (#147519)
**Summary** Fixed two decomposition issues in `torch.isin`: - Issue 1: As reported in [#147329](https://github.com/pytorch/pytorch/issues/147329), the current decomposition does not support cases where `test_elements` is a scalar. This is now implemented by referring to ead970c8d0 `aten/src/ATen/native/TensorCompare.cpp` (L1004-L1008). - Issue 2: Found while enabling a unit test with `elements = 1` and `test_elements = torch.tensor([1, 2, 3, 4])`, where Inductor produced different results compared to eager mode. This issue is fixed by referring to ead970c8d0 `aten/src/ATen/native/cpu/TensorCompareKernel.cpp` (L329-L338). **Test Plan** ``` python test/inductor/test_torchinductor.py -k test_isin_tensor_scalar ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/147519 Approved by: https://github.com/jgong5, https://github.com/FFFrog, https://github.com/peterbell10
This commit is contained in:
parent
2c8cd41c1f
commit
c644f4c5fe
|
|
@ -12253,6 +12253,17 @@ class CommonTemplate:
|
|||
|
||||
self.common(forward, (a, b))
|
||||
|
||||
def test_isin_tensor_scalar(self):
    """Check torch.isin when one argument is a Python scalar.

    Covers both argument orders (scalar elements / tensor test_elements
    and vice versa) under invert=True and invert=False, comparing the
    compiled result against eager via self.common.
    """
    scalar = 1
    values = torch.tensor([1, 2, 3, 4])
    for invert in (True, False):
        for args in ((scalar, values), (values, scalar)):
            # Reset dynamo so each combination recompiles from scratch.
            torch._dynamo.reset()
            self.common(torch.isin, args, {"invert": invert})
|
||||
|
||||
def test_mul_index_expr(self):
|
||||
# Minified repro from https://github.com/pytorch/pytorch/issues/111884
|
||||
def forward():
|
||||
|
|
|
|||
|
|
@ -5091,7 +5091,10 @@ def isin(elements, test_elements, *, assume_unique=False, invert=False):
|
|||
if not isinstance(elements, torch.Tensor):
|
||||
elements = torch.tensor(elements, device=test_elements.device)
|
||||
if not isinstance(test_elements, torch.Tensor):
|
||||
test_elements = torch.tensor(test_elements, device=elements.device)
|
||||
if invert:
|
||||
return torch.ne(elements, test_elements)
|
||||
else:
|
||||
return torch.eq(elements, test_elements)
|
||||
|
||||
if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
|
||||
return isin_default(elements, test_elements, invert=invert)
|
||||
|
|
@ -5123,14 +5126,10 @@ def bernoulli(
|
|||
def isin_default(elements, test_elements, *, invert=False):
|
||||
if elements.numel() == 0:
|
||||
return torch.empty_like(elements, dtype=torch.bool)
|
||||
|
||||
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
|
||||
if not invert:
|
||||
cmp = x == test_elements
|
||||
else:
|
||||
cmp = x != test_elements
|
||||
dim = tuple(range(-1, -test_elements.ndim - 1, -1))
|
||||
return cmp.any(dim=dim)
|
||||
res = (x == test_elements).any(dim=dim)
|
||||
return ~res if invert else res
|
||||
|
||||
|
||||
def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user