mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Introducing array-like sequence methods __contains__ (#17733)
Summary: for tensor Fixes: #17000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17733 Differential Revision: D14401952 Pulled By: soumith fbshipit-source-id: c841b128c5a1fceda1094323ed4ef1d0cf494909
This commit is contained in:
parent
906f9efc57
commit
b57fe3cc66
|
|
@ -8151,6 +8151,17 @@ class _TestTorchMixin(object):
|
|||
val = torch.tensor(42)
|
||||
self.assertEqual(reversed(val), torch.tensor(42))
|
||||
|
||||
def test_contains(self):
|
||||
x = torch.arange(0, 10)
|
||||
self.assertEqual(4 in x, True)
|
||||
self.assertEqual(12 in x, False)
|
||||
|
||||
x = torch.arange(1, 10).view(3, 3)
|
||||
val = torch.arange(1, 4)
|
||||
self.assertEqual(val in x, True)
|
||||
val += 10
|
||||
self.assertEqual(val in x, False)
|
||||
|
||||
@staticmethod
|
||||
def _test_rot90(self, use_cuda=False):
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import warnings
|
|||
import weakref
|
||||
from torch._six import imap
|
||||
from torch._C import _add_docstr
|
||||
from numbers import Number
|
||||
|
||||
|
||||
# NB: If you subclass Tensor, and want to share the subclassed class
|
||||
|
|
@ -426,6 +427,17 @@ class Tensor(torch._C._TensorBase):
|
|||
array = array.astype('uint8')
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def __contains__(self, element):
|
||||
r"""Check if `element` is present in tensor
|
||||
|
||||
Arguments:
|
||||
element (Tensor or scalar): element to be checked
|
||||
for presence in current tensor"
|
||||
"""
|
||||
if isinstance(element, (torch.Tensor, Number)):
|
||||
return (element == self).any().item()
|
||||
return NotImplemented
|
||||
|
||||
@property
|
||||
def __cuda_array_interface__(self):
|
||||
"""Array view description for cuda tensors.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user