From b57fe3cc66cfc348d56a7f41fa74b784119fa89a Mon Sep 17 00:00:00 2001 From: bhushan Date: Mon, 11 Mar 2019 08:55:01 -0700 Subject: [PATCH] Introducing array-like sequence methods __contains__ (#17733) Summary: for tensor Fixes: #17000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17733 Differential Revision: D14401952 Pulled By: soumith fbshipit-source-id: c841b128c5a1fceda1094323ed4ef1d0cf494909 --- test/test_torch.py | 11 +++++++++++ torch/tensor.py | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/test/test_torch.py b/test/test_torch.py index 4960b239de1..b42dd9e41eb 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8151,6 +8151,17 @@ class _TestTorchMixin(object): val = torch.tensor(42) self.assertEqual(reversed(val), torch.tensor(42)) + def test_contains(self): + x = torch.arange(0, 10) + self.assertEqual(4 in x, True) + self.assertEqual(12 in x, False) + + x = torch.arange(1, 10).view(3, 3) + val = torch.arange(1, 4) + self.assertEqual(val in x, True) + val += 10 + self.assertEqual(val in x, False) + @staticmethod def _test_rot90(self, use_cuda=False): device = torch.device("cuda" if use_cuda else "cpu") diff --git a/torch/tensor.py b/torch/tensor.py index bc4be2b15f0..17c6fa86145 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -7,6 +7,7 @@ import warnings import weakref from torch._six import imap from torch._C import _add_docstr +from numbers import Number # NB: If you subclass Tensor, and want to share the subclassed class @@ -426,6 +427,17 @@ class Tensor(torch._C._TensorBase): array = array.astype('uint8') return torch.from_numpy(array) + def __contains__(self, element): + r"""Check if `element` is present in tensor + + Arguments: + element (Tensor or scalar): element to be checked + for presence in current tensor" + """ + if isinstance(element, (torch.Tensor, Number)): + return (element == self).any().item() + return NotImplemented + @property def __cuda_array_interface__(self): """Array view description for cuda tensors.