[Inductor] Fix the decompositions of torch isin (#147519)

**Summary**
Fixed two decomposition issues in `torch.isin`:

- Issue 1: As reported in [#147329](https://github.com/pytorch/pytorch/issues/147329), the current decomposition does not support cases where test_element is a scalar. This is now implemented by referring to the ead970c8d0/aten/src/ATen/native/TensorCompare.cpp (L1004-L1008)

- Issue 2: Found while enabling a unit test with `elements = 1` and `test_elements = torch.tensor([1, 2, 3, 4])`, where Inductor produced different results compared to eager mode. This issue is fixed by referring to ead970c8d0/aten/src/ATen/native/cpu/TensorCompareKernel.cpp (L329-L338)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_isin_tensor_scalar
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147519
Approved by: https://github.com/jgong5, https://github.com/FFFrog, https://github.com/peterbell10
This commit is contained in:
leslie-fang-intel 2025-02-21 17:35:46 -08:00 committed by PyTorch MergeBot
parent 2c8cd41c1f
commit c644f4c5fe
2 changed files with 17 additions and 7 deletions

View File

@ -12253,6 +12253,17 @@ class CommonTemplate:
self.common(forward, (a, b))
def test_isin_tensor_scalar(self):
for invert in [True, False]:
torch._dynamo.reset()
elements = 1
test_elements = torch.tensor([1, 2, 3, 4])
self.common(torch.isin, (elements, test_elements), {"invert": invert})
torch._dynamo.reset()
elements = torch.tensor([1, 2, 3, 4])
test_elements = 1
self.common(torch.isin, (elements, test_elements), {"invert": invert})
def test_mul_index_expr(self):
# Minified repro from https://github.com/pytorch/pytorch/issues/111884
def forward():

View File

@ -5091,7 +5091,10 @@ def isin(elements, test_elements, *, assume_unique=False, invert=False):
if not isinstance(elements, torch.Tensor):
elements = torch.tensor(elements, device=test_elements.device)
if not isinstance(test_elements, torch.Tensor):
test_elements = torch.tensor(test_elements, device=elements.device)
if invert:
return torch.ne(elements, test_elements)
else:
return torch.eq(elements, test_elements)
if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
return isin_default(elements, test_elements, invert=invert)
@ -5123,14 +5126,10 @@ def bernoulli(
def isin_default(elements, test_elements, *, invert=False):
if elements.numel() == 0:
return torch.empty_like(elements, dtype=torch.bool)
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
if not invert:
cmp = x == test_elements
else:
cmp = x != test_elements
dim = tuple(range(-1, -test_elements.ndim - 1, -1))
return cmp.any(dim=dim)
res = (x == test_elements).any(dim=dim)
return ~res if invert else res
def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):