Introducing array-like sequence methods __contains__ (#17733)

Summary:
for tensor

Fixes: #17000
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17733

Differential Revision: D14401952

Pulled By: soumith

fbshipit-source-id: c841b128c5a1fceda1094323ed4ef1d0cf494909
This commit is contained in:
bhushan 2019-03-11 08:55:01 -07:00 committed by Facebook Github Bot
parent 906f9efc57
commit b57fe3cc66
2 changed files with 23 additions and 0 deletions

View File

@ -8151,6 +8151,17 @@ class _TestTorchMixin(object):
val = torch.tensor(42)
self.assertEqual(reversed(val), torch.tensor(42))
def test_contains(self):
x = torch.arange(0, 10)
self.assertEqual(4 in x, True)
self.assertEqual(12 in x, False)
x = torch.arange(1, 10).view(3, 3)
val = torch.arange(1, 4)
self.assertEqual(val in x, True)
val += 10
self.assertEqual(val in x, False)
@staticmethod
def _test_rot90(self, use_cuda=False):
device = torch.device("cuda" if use_cuda else "cpu")

View File

@ -7,6 +7,7 @@ import warnings
import weakref
from torch._six import imap
from torch._C import _add_docstr
from numbers import Number
# NB: If you subclass Tensor, and want to share the subclassed class
@ -426,6 +427,17 @@ class Tensor(torch._C._TensorBase):
array = array.astype('uint8')
return torch.from_numpy(array)
def __contains__(self, element):
r"""Check if `element` is present in tensor
Arguments:
element (Tensor or scalar): element to be checked
for presence in current tensor"
"""
if isinstance(element, (torch.Tensor, Number)):
return (element == self).any().item()
return NotImplemented
@property
def __cuda_array_interface__(self):
"""Array view description for cuda tensors.