diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 000927f282f..f3a9e8c9361 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -304,50 +304,6 @@
         - THTensor* self
         - THTensor* other
 ]]
-[[
-  name: _th_and
-  cpu_bool: True
-  cuda_bool: True
-  cname: __and__
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - arg: THTensor* self
-          broadcast: other fallback
-        - THTensor* other
-]]
-[[
-  name: _th_iand_
-  cname: __iand__
-  cpu_bool: True
-  cuda_bool: True
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - THTensor* self
-        - arg: THTensor* self
-          broadcast: other inplace fallback
-        - THTensor* other
-]]
 [[
   name: _th_or
   cname: __or__
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 82c607a9fc7..7d3c3f8817d 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -502,6 +502,7 @@ _(aten, native_tensor) \
 _(aten, native_zero) \
 _(aten, ne) \
 _(aten, neg) \
+_(aten, bitwise_and) \
 _(aten, bitwise_not) \
 _(aten, bitwise_xor) \
 _(aten, nll_loss) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 1980b142c30..af84b8ec333 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -15,6 +15,7 @@ DEFINE_DISPATCH(sub_stub);
 DEFINE_DISPATCH(mul_stub);
 DEFINE_DISPATCH(div_stub);
 DEFINE_DISPATCH(atan2_stub);
+DEFINE_DISPATCH(bitwise_and_stub);
 DEFINE_DISPATCH(bitwise_xor_stub);
 DEFINE_DISPATCH(logical_and_stub);
 DEFINE_DISPATCH(logical_or_stub);
@@ -234,6 +235,53 @@ Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
   return native::rsub(self, wrapped_scalar_tensor(other), alpha);
 }
 
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_mem_overlap=*/true);
+  bitwise_and_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor bitwise_and(const Tensor& self, const Tensor& other) {
+  Tensor result = at::empty({0}, self.options());
+  at::bitwise_and_out(result, self, other);
+  return result;
+}
+
+Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) {
+  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
+}
+
+Tensor bitwise_and(const Tensor& self, Scalar other) {
+  Tensor result = at::empty({0}, self.options());
+  return at::bitwise_and_out(result, self, other);
+}
+
+Tensor& bitwise_and_(Tensor& self, Scalar other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+// Legacy and interfaces (__and__, __iand__). They are aliased to the bitwise_and* functions.
+Tensor __and__(const Tensor& self, const Tensor& other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor __and__(const Tensor& self, Scalar other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor& __iand__(Tensor& self, const Tensor& other) {
+  return self.bitwise_and_(other);
+}
+
+Tensor& __iand__(Tensor& self, Scalar other) {
+  return self.bitwise_and_(other);
+}
+
 Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_op(result, self, other,
     /*check_mem_overlap=*/true);
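For context: with the functions above in place, torch.bitwise_and and the legacy & / &= operators are expected to run through the same dispatch stub, since __and__ and __iand__ simply forward to the bitwise_and* entry points. A minimal doctest-style sketch of that equivalence; the sample values are illustrative and assume a build that includes this patch:

    >>> import torch
    >>> a = torch.tensor([6, 5], dtype=torch.int32)
    >>> b = torch.tensor([3, 4], dtype=torch.int32)
    >>> torch.bitwise_and(a, b)   # new public function
    tensor([2, 4], dtype=torch.int32)
    >>> a & b                     # __and__, now an alias of bitwise_and
    tensor([2, 4], dtype=torch.int32)
    >>> a &= b                    # __iand__, now an alias of bitwise_and_
    >>> a
    tensor([2, 4], dtype=torch.int32)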
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index 8e32bbbf8a2..a8f2bb0a196 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -32,6 +32,7 @@ DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
 DECLARE_DISPATCH(binary_fn, mul_stub);
 DECLARE_DISPATCH(binary_fn, div_stub);
 DECLARE_DISPATCH(binary_fn, atan2_stub);
+DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
 DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_and_stub);
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 5f6e1dd23df..b7188e4d103 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -93,6 +93,27 @@ void div_kernel(TensorIterator& iter) {
   }
 }
 
+void bitwise_and_kernel(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    cpu_kernel(
+        iter,
+        [](bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
+      cpu_kernel_vec(
+          iter,
+          [](scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          },
+          [](Vec256<scalar_t> a, Vec256<scalar_t> b) {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -341,6 +362,7 @@ REGISTER_DISPATCH(sub_stub, &sub_kernel);
 REGISTER_DISPATCH(mul_stub, &mul_kernel);
 REGISTER_DISPATCH(div_stub, &div_kernel);
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
 REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
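The CPU kernel above has two branches: Bool tensors take a plain logical AND (a && b), while the dtypes covered by AT_DISPATCH_INTEGRAL_TYPES use bitwise & with a Vec256-vectorized path; floating-point dtypes are not dispatched and should fail with a RuntimeError. A rough doctest-style sketch of the expected behavior (the values here are illustrative, not taken from the test suite):

    >>> import torch
    >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([True, False, False]))
    tensor([ True, False, False])
    >>> torch.bitwise_and(torch.tensor([12, 10], dtype=torch.uint8), torch.tensor([10, 12], dtype=torch.uint8))
    tensor([8, 8], dtype=torch.uint8)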
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index 3517abdace5..e1396a39ddd 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -18,6 +18,24 @@ void atan2_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+void bitwise_and_kernel_cuda(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    gpu_kernel_with_scalars(
+        iter,
+        []GPU_LAMBDA(bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() {
+      gpu_kernel_with_scalars(
+          iter,
+          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel_cuda(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
@@ -97,6 +115,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
 }
 
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
 REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6a429d0d0bd..2dee634c3c9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3934,31 +3934,43 @@
 - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
 
+- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
 - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __and__.Tensor(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
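The native_functions.yaml entries above register the full overload surface: Tensor and Scalar variants, each in functional, in-place, and out= form. A short doctest-style sketch of how those overloads are expected to be reachable from Python (the values are illustrative and assume a build that includes this patch):

    >>> import torch
    >>> a = torch.tensor([1, 2, 3], dtype=torch.int16)
    >>> torch.bitwise_and(a, 1)                    # bitwise_and.Scalar
    tensor([1, 0, 1], dtype=torch.int16)
    >>> out = torch.empty(0, dtype=torch.int16)
    >>> torch.bitwise_and(a, torch.tensor([3, 3, 3], dtype=torch.int16), out=out)   # bitwise_and.Tensor_out
    tensor([1, 2, 3], dtype=torch.int16)
    >>> a.bitwise_and_(2)                          # bitwise_and_.Scalar
    tensor([0, 2, 2], dtype=torch.int16)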
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index 5b989582ad9..2f78aff2f6d 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -295,33 +295,6 @@ accreal THTensor_(sumall)(THTensor *tensor)
   return sum;
 }
 
-void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_BFLOAT16)
-  (void)r_;
-  (void)t;
-  (void)value;
-  return THError("bitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  if (r_Contig && tContig) {
-    scalar_t *tp = t->data<scalar_t>();
-    scalar_t *rp = r_->data<scalar_t>();
-    at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD * 100,
-        [&](int64_t start, int64_t end) {
-      for (auto i = start; i < end; i++) {
-        rp[i] = tp[i] & value;
-      }
-    });
-  } else {
-    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data & value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-  }
-#endif
-}
-
 scalar_t THTensor_(minall)(THTensor *tensor)
 {
   scalar_t theMin;
diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp
index ee47094f7f3..8ffc6aa117a 100644
--- a/aten/src/TH/generic/THTensorMath.cpp
+++ b/aten/src/TH/generic/THTensorMath.cpp
@@ -22,39 +22,6 @@
 // sense (rather than just having cut the file down the middle, which is
 // what I did when I split these up originally).
 
-void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
-  (void)r_;
-  (void)t;
-  (void)src;
-  return THError("cbitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int64_t srcSize = THTensor_(nElement)(src);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  int srcContig = THTensor_(isContiguous)(src);
-  if (srcSize == r_Size){
-    if (r_Contig && tContig && srcContig) {
-      scalar_t *tp = t->data<scalar_t>();
-      scalar_t *sp = src->data<scalar_t>();
-      scalar_t *rp = r_->data<scalar_t>();
-      at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
-          [&](int64_t start, int64_t end) {
-        for (auto i = start; i < end; i++) {
-          rp[i] = tp[i] & sp[i];
-        }
-      });
-    } else {
-      TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-    }
-  } else {
-    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;);
-  }
-#endif
-}
-
 void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src)
 {
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index ce9e690cac7..a262e684e91 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -74,8 +74,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);
 
 TH_API accreal THTensor_(sumall)(THTensor *t);
 
-TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value);
-TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src);
 TH_API void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value);
 TH_API void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src);
 
diff --git a/aten/src/THC/THCTensorMathPairwise.cu b/aten/src/THC/THCTensorMathPairwise.cu
index fe781feb0d5..b2724ef5270 100644
--- a/aten/src/THC/THCTensorMathPairwise.cu
+++ b/aten/src/THC/THCTensorMathPairwise.cu
@@ -259,20 +259,6 @@ struct TensorRShiftConstantOp {
   const T val;
 };
 
-template <typename T>
-struct TensorBitAndConstantOp {
-  TensorBitAndConstantOp(T v) : val(v) {}
-  __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = *in & val;
-  }
-
-  __device__ __forceinline__ void operator()(T* v) {
-    *v &= val;
-  }
-
-  const T val;
-};
-
 template <typename T>
 struct TensorBitOrConstantOp {
   TensorBitOrConstantOp(T v) : val(v) {}
@@ -287,20 +273,6 @@ struct TensorBitOrConstantOp {
   const T val;
 };
 
-template <typename T>
-struct TensorBitXorConstantOp {
-  TensorBitXorConstantOp(T v) : val(v) {}
-  __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = *in ^ val;
-  }
-
-  __device__ __forceinline__ void operator()(T* v) {
-    *v ^= val;
-  }
-
-  const T val;
-};
-
 #include <THC/generic/THCTensorMathPairwise.cu>
 #include <THC/THCGenerateAllTypes.h>
diff --git a/aten/src/THC/THCTensorMathPointwise.cuh b/aten/src/THC/THCTensorMathPointwise.cuh
index ece97e715cd..d26bdd1cac7 100644
--- a/aten/src/THC/THCTensorMathPointwise.cuh
+++ b/aten/src/THC/THCTensorMathPointwise.cuh
@@ -321,19 +321,6 @@ struct TensorRShiftOp {
   }
 };
 
-template <typename T>
-struct TensorBitAndOp {
-  __device__ __forceinline__ void
-  operator()(T* out, T* in) {
-    *out &= *in;
-  }
-
-  __device__ __forceinline__ void
-  operator()(T* out, T* in1, T* in2) {
-    *out = *in1 & *in2;
-  }
-};
-
 template <typename T>
 struct TensorBitOrOp {
   __device__ __forceinline__ void
@@ -347,17 +334,4 @@ struct TensorBitOrOp {
   }
 };
 
-template <typename T>
-struct TensorBitXorOp {
-  __device__ __forceinline__ void
-  operator()(T* out, T* in) {
-    *out ^= *in;
-  }
-
-  __device__ __forceinline__ void
-  operator()(T* out, T* in1, T* in2) {
-    *out = *in1 ^ *in2;
-  }
-};
-
 #endif // THC_TENSORMATH_POINTWISE_CUH
diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu
index 734347ec62c..226ab73fe70 100644
--- a/aten/src/THC/generic/THCTensorMathPairwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPairwise.cu
@@ -39,27 +39,6 @@ int THCTensor_(equal)(THCState *state, THCTensor *self_, THCTensor *src_)
 {
   return THCTensor_(equalImpl)(state, self_, src_);
 }
 
-void THCTensor_(bitand)(THCState* state, THCTensor *self_, THCTensor *src_, scalar_t value)
-{
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-  return THError("bitand only supported for integer type tensors");
-#else
-  if (self_ == src_) {
-    if (!THC_pointwiseApply1<scalar_t>(state, self_, TensorBitAndConstantOp<scalar_t>(value))) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  } else {
-    THCTensor_(resizeAs)(state, self_, src_);
-
-    if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, TensorBitAndConstantOp<scalar_t>(value))) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  }
-
-  THCudaCheck(cudaGetLastError());
-#endif
-}
-
 void THCTensor_(bitor)(THCState* state, THCTensor *self_, THCTensor *src_, scalar_t value)
 {
 #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
diff --git a/aten/src/THC/generic/THCTensorMathPairwise.h b/aten/src/THC/generic/THCTensorMathPairwise.h
index 246be6be705..bee23e2d05e 100644
--- a/aten/src/THC/generic/THCTensorMathPairwise.h
+++ b/aten/src/THC/generic/THCTensorMathPairwise.h
@@ -4,7 +4,6 @@
 
 THC_API int THCTensor_(equal)(THCState *state, THCTensor *self, THCTensor *src);
 
-THC_API void THCTensor_(bitand)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value);
 THC_API void THCTensor_(bitor)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value);
 
 #if !defined(THC_REAL_IS_BOOL)
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 4cef2d152f5..dee8374c979 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -5,33 +5,6 @@
 #include
 #include
 
-void THCTensor_(cbitand)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
-{
-#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-  return THError("cbitand is only supported for integer type tensors");
-#else
-  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
-  THArgCheck(THCTensor_(nElement)(state, src1) ==
-             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
-
-  if (self_ == src1) {
-    // self /= src2
-    if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src2, TensorBitAndOp<scalar_t>())) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  } else {
-    THCTensor_(resizeAs)(state, self_, src1);
-
-    // self = src1 / src2
-    if (!THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, self_, src1, src2, TensorBitAndOp<scalar_t>())) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  }
-
-  THCudaCheck(cudaGetLastError());
-#endif
-}
-
 void THCTensor_(cbitor)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
 {
 #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h
index 145ac545b04..765bfd6ab42 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.h
+++ b/aten/src/THC/generic/THCTensorMathPointwise.h
@@ -2,7 +2,6 @@
 #define THC_GENERIC_FILE "THC/generic/THCTensorMathPointwise.h"
 #else
 
-THC_API void THCTensor_(cbitand)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 THC_API void THCTensor_(cbitor)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 
 THC_API void THCTensor_(cmax)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index c13492eb936..40b65c6e589 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -193,6 +193,8 @@ view of a storage and defines numeric operations on it.
    .. automethod:: bincount
    .. automethod:: bitwise_not
    .. automethod:: bitwise_not_
+   .. automethod:: bitwise_and
+   .. automethod:: bitwise_and_
    .. automethod:: bitwise_xor
    .. automethod:: bitwise_xor_
    .. automethod:: bmm
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 361abe780df..1d888105418 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -198,6 +198,7 @@ Pointwise Ops
 .. autofunction:: atan
 .. autofunction:: atan2
 .. autofunction:: bitwise_not
+.. autofunction:: bitwise_and
 .. autofunction:: bitwise_xor
 .. autofunction:: ceil
 .. autofunction:: clamp
diff --git a/test/test_torch.py b/test/test_torch.py
index f894bd52a52..b2e9183a793 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6687,6 +6687,36 @@ class TestTorchDeviceType(TestCase):
             with self.assertRaises(RuntimeError):
                 a.bitwise_not_()
 
+    def test_bitwise_and(self, device):
+        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
+            b = torch.tensor([2, 1, 3], dtype=dtype, device=device)
+            expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device)
+            b_scalar = 2
+            expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device)
+
+            # standard version
+            self.assertEqual(torch.bitwise_and(a, b), expected_res)
+            self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar)
+
+            # out
+            c = torch.empty(0, dtype=dtype, device=device)
+            torch.bitwise_and(a, b, out=c)
+            self.assertEqual(c, expected_res)
+            torch.bitwise_and(a, b_scalar, out=c)
+            self.assertEqual(c, expected_res_scalar)
+
+            # in-place
+            a1 = a.clone()
+            a1.bitwise_and_(b)
+            self.assertEqual(a1, expected_res)
+            a.bitwise_and_(b_scalar)
+            self.assertEqual(a, expected_res_scalar)
+
+        self.assertEqual(torch.tensor([False, True, False], device=device),
+                         torch.bitwise_and(torch.tensor([True, True, False], device=device),
+                                           torch.tensor([False, True, False], device=device)))
+
     def test_bitwise_xor(self, device):
         for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
             a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
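The expected values in test_bitwise_and are easiest to check against plain Python integers, which follow the same two's-complement semantics for & on negative operands as the fixed-width dtypes in the loop (-2 behaves like ...11111110):

    >>> 1 & 2, (-2) & 1, 3 & 3        # a & b         -> [0, 0, 3]
    (0, 0, 3)
    >>> 1 & 2, (-2) & 2, 3 & 2        # a & 2 (scalar) -> [0, 2, 2]
    (0, 2, 2)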
""".format(**common_args)) +add_docstr(torch.bitwise_and, + r""" +bitwise_and(input, other, out=None) -> Tensor + +Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical AND. + +Args: + input: the first input tensor + other: the second input tensor + {out} + +Example: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) +""".format(**common_args)) + add_docstr(torch.bitwise_xor, r""" bitwise_xor(input, other, out=None) -> Tensor