Add op bitwise_and (#31104)
Summary:
Refer to https://github.com/pytorch/pytorch/pull/25665; this adds the `bitwise_and` operator.

Benchmark script:

```python
import timeit

# __and__ (out-of-place)
for n, t in [(10, 100000), (1000, 10000)]:
    print('__and__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a & b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                setup=f'import torch; '
                      f'a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); '
                      f'b = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}")',
                number=t))

# __iand__ (in-place; the original posting timed `a & b` here as well --
# `a &= b` is the statement that actually exercises __iand__)
for n, t in [(10, 100000), (1000, 10000)]:
    print('__iand__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a &= b\nif "{device}" == "cuda": torch.cuda.synchronize()',
                setup=f'import torch; '
                      f'a = torch.randint(0, 10, ({n},), dtype = {dtype}, device="{device}"); '
                      f'b = torch.tensor(5, dtype = {dtype}, device="{device}")',
                number=t))
```

Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**

Before:

```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.1766007635742426
device: cpu, dtype: torch.uint8, 100000 times    0.17322628945112228
device: cpu, dtype: torch.int16, 100000 times    0.17650844901800156
device: cpu, dtype: torch.int32, 100000 times    0.17711848113685846
device: cpu, dtype: torch.int64, 100000 times    0.18240160401910543
device: cuda, dtype: torch.int8, 100000 times    1.273967768996954
device: cuda, dtype: torch.uint8, 100000 times    1.2778537990525365
device: cuda, dtype: torch.int16, 100000 times    1.2753686187788844
device: cuda, dtype: torch.int32, 100000 times    1.2797665279358625
device: cuda, dtype: torch.int64, 100000 times    1.2933144550770521
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.031139614060521126
device: cpu, dtype: torch.uint8, 10000 times    0.03091452084481716
device: cpu, dtype: torch.int16, 10000 times    0.022756479680538177
device: cpu, dtype: torch.int32, 10000 times    0.025045674294233322
device: cpu, dtype: torch.int64, 10000 times    0.024164282716810703
device: cuda, dtype: torch.int8, 10000 times    0.12820732593536377
device: cuda, dtype: torch.uint8, 10000 times    0.12775669433176517
device: cuda, dtype: torch.int16, 10000 times    0.12697868794202805
device: cuda, dtype: torch.int32, 10000 times    0.12832533661276102
device: cuda, dtype: torch.int64, 10000 times    0.1280576130375266
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.3687064303085208
device: cpu, dtype: torch.uint8, 100000 times    0.36253443732857704
device: cpu, dtype: torch.int16, 100000 times    0.362891579978168
device: cpu, dtype: torch.int32, 100000 times    0.37680106051266193
device: cpu, dtype: torch.int64, 100000 times    0.3689364707097411
device: cuda, dtype: torch.int8, 100000 times    1.419940729625523
device: cuda, dtype: torch.uint8, 100000 times    1.4247053815051913
device: cuda, dtype: torch.int16, 100000 times    1.4191444097086787
device: cuda, dtype: torch.int32, 100000 times    1.4305962566286325
device: cuda, dtype: torch.int64, 100000 times    1.4567416654899716
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.06224383972585201
device: cpu, dtype: torch.uint8, 10000 times    0.06205617543309927
device: cpu, dtype: torch.int16, 10000 times    0.05016433447599411
device: cpu, dtype: torch.int32, 10000 times    0.05216377507895231
device: cpu, dtype: torch.int64, 10000 times    0.06139362137764692
device: cuda, dtype: torch.int8, 10000 times    0.14827249851077795
device: cuda, dtype: torch.uint8, 10000 times    0.14801877550780773
device: cuda, dtype: torch.int16, 10000 times    0.14952312968671322
device: cuda, dtype: torch.int32, 10000 times    0.14999118447303772
device: cuda, dtype: torch.int64, 10000 times    0.14951884001493454
```

After:

```
__and__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.23157884553074837
device: cpu, dtype: torch.uint8, 100000 times    0.23063660878688097
device: cpu, dtype: torch.int16, 100000 times    0.23005440644919872
device: cpu, dtype: torch.int32, 100000 times    0.23748818412423134
device: cpu, dtype: torch.int64, 100000 times    0.24106105230748653
device: cuda, dtype: torch.int8, 100000 times    1.4394256137311459
device: cuda, dtype: torch.uint8, 100000 times    1.4436759827658534
device: cuda, dtype: torch.int16, 100000 times    1.4631587155163288
device: cuda, dtype: torch.int32, 100000 times    1.459101552143693
device: cuda, dtype: torch.int64, 100000 times    1.4784048134461045
__and__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.028442862443625927
device: cpu, dtype: torch.uint8, 10000 times    0.028130197897553444
device: cpu, dtype: torch.int16, 10000 times    0.025318274274468422
device: cpu, dtype: torch.int32, 10000 times    0.02519288007169962
device: cpu, dtype: torch.int64, 10000 times    0.028299466706812382
device: cuda, dtype: torch.int8, 10000 times    0.14342594426125288
device: cuda, dtype: torch.uint8, 10000 times    0.145280827768147
device: cuda, dtype: torch.int16, 10000 times    0.14673697855323553
device: cuda, dtype: torch.int32, 10000 times    0.14499565307050943
device: cuda, dtype: torch.int64, 10000 times    0.14582364354282618
__iand__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times    0.25548241566866636
device: cpu, dtype: torch.uint8, 100000 times    0.2552562616765499
device: cpu, dtype: torch.int16, 100000 times    0.25905191246420145
device: cpu, dtype: torch.int32, 100000 times    0.26635489892214537
device: cpu, dtype: torch.int64, 100000 times    0.26269810926169157
device: cuda, dtype: torch.int8, 100000 times    1.485458506271243
device: cuda, dtype: torch.uint8, 100000 times    1.4742380809038877
device: cuda, dtype: torch.int16, 100000 times    1.507783885113895
device: cuda, dtype: torch.int32, 100000 times    1.4926990242674947
device: cuda, dtype: torch.int64, 100000 times    1.519851053133607
__iand__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times    0.03425929415971041
device: cpu, dtype: torch.uint8, 10000 times    0.03293587639927864
device: cpu, dtype: torch.int16, 10000 times    0.029559112153947353
device: cpu, dtype: torch.int32, 10000 times    0.030915481969714165
device: cpu, dtype: torch.int64, 10000 times    0.03292469773441553
device: cuda, dtype: torch.int8, 10000 times    0.15792148280888796
device: cuda, dtype: torch.uint8, 10000 times    0.16000914946198463
device: cuda, dtype: torch.int16, 10000 times    0.1600684942677617
device: cuda, dtype: torch.int32, 10000 times    0.16162546630948782
device: cuda, dtype: torch.int64, 10000 times    0.1629159888252616
```

Fixes https://github.com/pytorch/pytorch/issues/24508, https://github.com/pytorch/pytorch/issues/24509, https://github.com/pytorch/pytorch/issues/24655, and https://github.com/pytorch/pytorch/issues/24656.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/31104
Differential Revision: D18938930
Pulled By: VitalyFedyunin
fbshipit-source-id: a77e805a0b84e8ace16c6e648c2f67dad44f2e44
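For readers skimming the diff below, a quick usage sketch of the new surface area: `torch.bitwise_and` as a function, `Tensor.bitwise_and`/`bitwise_and_` as methods, an `out=` variant, and the legacy `__and__`/`__iand__` operators aliased to it. This is an editorial sketch, not part of the original message; expected values are taken from the tests added in this PR.

```python
import torch

a = torch.tensor([1, -2, 3], dtype=torch.int32)
b = torch.tensor([2, 1, 3], dtype=torch.int32)

torch.bitwise_and(a, b)           # tensor([0, 0, 3], dtype=torch.int32)
torch.bitwise_and(a, 2)           # Scalar overload: tensor([0, 2, 2], dtype=torch.int32)

out = torch.empty(0, dtype=torch.int32)
torch.bitwise_and(a, b, out=out)  # out variant resizes and fills `out`

a & b                             # __and__ is now an alias for bitwise_and
a.bitwise_and_(b)                 # in-place method; a becomes tensor([0, 0, 3])
```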
This commit is contained in:
parent 68f3782106
commit b47e9b97a2
```diff
@@ -304,50 +304,6 @@
         - THTensor* self
         - THTensor* other
 ]]
-[[
-  name: _th_and
-  cpu_bool: True
-  cuda_bool: True
-  cname: __and__
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - arg: THTensor* self
-          broadcast: other fallback
-        - THTensor* other
-]]
-[[
-  name: _th_iand_
-  cname: __iand__
-  cpu_bool: True
-  cuda_bool: True
-  variants:
-    - function
-  return: argument 0
-  options:
-    - cname: bitand
-      arguments:
-        - THTensor* self
-        - THTensor* self
-        - real other
-    - cname: cbitand
-      arguments:
-        - THTensor* self
-        - arg: THTensor* self
-          broadcast: other inplace fallback
-        - THTensor* other
-]]
 [[
   name: _th_or
   cname: __or__
```
```diff
@@ -502,6 +502,7 @@ _(aten, native_tensor) \
 _(aten, native_zero) \
 _(aten, ne) \
 _(aten, neg) \
+_(aten, bitwise_and) \
 _(aten, bitwise_not) \
 _(aten, bitwise_xor) \
 _(aten, nll_loss) \
```
```diff
@@ -15,6 +15,7 @@ DEFINE_DISPATCH(sub_stub);
 DEFINE_DISPATCH(mul_stub);
 DEFINE_DISPATCH(div_stub);
 DEFINE_DISPATCH(atan2_stub);
+DEFINE_DISPATCH(bitwise_and_stub);
 DEFINE_DISPATCH(bitwise_xor_stub);
 DEFINE_DISPATCH(logical_and_stub);
 DEFINE_DISPATCH(logical_or_stub);
```
```diff
@@ -234,6 +235,53 @@ Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
   return native::rsub(self, wrapped_scalar_tensor(other), alpha);
 }
 
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_op(result, self, other,
+                                        /*check_mem_overlap=*/true);
+  bitwise_and_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor bitwise_and(const Tensor& self, const Tensor& other) {
+  Tensor result = at::empty({0}, self.options());
+  at::bitwise_and_out(result, self, other);
+  return result;
+}
+
+Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) {
+  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
+}
+
+Tensor bitwise_and(const Tensor& self, Scalar other) {
+  Tensor result = at::empty({0}, self.options());
+  return at::bitwise_and_out(result, self, other);
+}
+
+Tensor& bitwise_and_(Tensor& self, Scalar other) {
+  return at::bitwise_and_out(self, self, other);
+}
+
+// Legacy and interfaces. They are aliased to bitwise_and* functions
+Tensor __and__(const Tensor& self, const Tensor& other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor __and__(const Tensor& self, Scalar other) {
+  return at::bitwise_and(self, other);
+}
+
+Tensor& __iand__(Tensor& self, const Tensor& other) {
+  return self.bitwise_and_(other);
+}
+
+Tensor& __iand__(Tensor& self, Scalar other) {
+  return self.bitwise_and_(other);
+}
+
 Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_op(result, self, other,
                                         /*check_mem_overlap=*/true);
```
```diff
@@ -32,6 +32,7 @@ DECLARE_DISPATCH(binary_fn_alpha, sub_stub);
 DECLARE_DISPATCH(binary_fn, mul_stub);
 DECLARE_DISPATCH(binary_fn, div_stub);
 DECLARE_DISPATCH(binary_fn, atan2_stub);
+DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
 DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_xor_stub);
 DECLARE_DISPATCH(binary_fn, logical_and_stub);
```
```diff
@@ -93,6 +93,27 @@ void div_kernel(TensorIterator& iter) {
   }
 }
 
+void bitwise_and_kernel(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    cpu_kernel(
+        iter,
+        [](bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() {
+      cpu_kernel_vec(
+          iter,
+          [](scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          },
+          [](Vec256<scalar_t> a, Vec256<scalar_t> b) {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
```
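Note the dtype branch in the CPU kernel above: bool tensors take a plain `cpu_kernel` computing logical AND (`&&` rather than `&` on `bool`), while integral dtypes go through `cpu_kernel_vec` with both a scalar and a `Vec256` vectorized lambda. A small editorial sketch of the observable behavior from Python (expected outputs per the docstring examples added later in this diff):

```python
import torch

# Integral dtypes take the vectorized bitwise path.
print(torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8),
                        torch.tensor([1, 0, 3], dtype=torch.int8)))
# tensor([1, 0, 3], dtype=torch.int8)

# Bool tensors take the logical-AND branch.
print(torch.bitwise_and(torch.tensor([True, True, False]),
                        torch.tensor([False, True, False])))
# tensor([False,  True, False])
```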
```diff
@@ -341,6 +362,7 @@ REGISTER_DISPATCH(sub_stub, &sub_kernel);
 REGISTER_DISPATCH(mul_stub, &mul_kernel);
 REGISTER_DISPATCH(div_stub, &div_kernel);
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
 REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel);
```
```diff
@@ -18,6 +18,24 @@ void atan2_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+void bitwise_and_kernel_cuda(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    gpu_kernel_with_scalars(
+        iter,
+        []GPU_LAMBDA(bool a, bool b) {
+          return a && b;
+        });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cuda", [&]() {
+      gpu_kernel_with_scalars(
+          iter,
+          []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+            return a & b;
+          });
+    });
+  }
+}
+
 void bitwise_xor_kernel_cuda(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     // Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
```
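The CUDA kernel mirrors the CPU one, with `gpu_kernel_with_scalars` also covering the tensor-scalar iterator cases. A minimal smoke test one might run on a CUDA build (an editorial sketch; it assumes a CUDA device is present and does nothing otherwise):

```python
import torch

if torch.cuda.is_available():
    a = torch.randint(0, 10, (1000,), dtype=torch.int64, device="cuda")
    b = torch.randint(0, 10, (1000,), dtype=torch.int64, device="cuda")
    # The legacy operator should agree with the new op on every element.
    assert torch.equal(a & b, torch.bitwise_and(a, b))
    torch.cuda.synchronize()
```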
```diff
@@ -97,6 +115,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
 }
 
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
+REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
 REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
 REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
 REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
```
```diff
@@ -3934,31 +3934,43 @@
 - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
 
+- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: bitwise_and_out
+    CUDA: bitwise_and_out
+
+- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
 - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __and__.Tensor(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: method, function
-  dispatch:
-    CPU: legacy::cpu::_th_and
-    CUDA: legacy::cuda::_th_and
 
 - func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: legacy::cpu::_th_iand_
-    CUDA: legacy::cuda::_th_iand_
 
 - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
   use_c10_dispatcher: full
```
```diff
@@ -295,33 +295,6 @@ accreal THTensor_(sumall)(THTensor *tensor)
   return sum;
 }
 
-void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_BFLOAT16)
-  (void)r_;
-  (void)t;
-  (void)value;
-  return THError("bitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  if (r_Contig && tContig) {
-    scalar_t *tp = t->data<scalar_t>();
-    scalar_t *rp = r_->data<scalar_t>();
-    at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD * 100,
-        [&](int64_t start, int64_t end) {
-      for (auto i = start; i < end; i++) {
-        rp[i] = tp[i] & value;
-      }
-    });
-  } else {
-    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data & value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-  }
-#endif
-}
-
 scalar_t THTensor_(minall)(THTensor *tensor)
 {
   scalar_t theMin;
```
```diff
@@ -22,39 +22,6 @@
 // sense (rather than just having cut the file down the middle, which is
 // what I did when I split these up originally).
 
-void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src)
-{
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
-  (void)r_;
-  (void)t;
-  (void)src;
-  return THError("cbitand is only supported for integer type tensors");
-#else
-  THTensor_(resizeAs)(r_, t);
-  int64_t r_Size = THTensor_(nElement)(r_);
-  int64_t srcSize = THTensor_(nElement)(src);
-  int r_Contig = THTensor_(isContiguous)(r_);
-  int tContig = THTensor_(isContiguous)(t);
-  int srcContig = THTensor_(isContiguous)(src);
-  if (srcSize == r_Size){
-    if (r_Contig && tContig && srcContig) {
-      scalar_t *tp = t->data<scalar_t>();
-      scalar_t *sp = src->data<scalar_t>();
-      scalar_t *rp = r_->data<scalar_t>();
-      at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
-          [&](int64_t start, int64_t end) {
-        for (auto i = start; i < end; i++) {
-          rp[i] = tp[i] & sp[i];
-        }
-      });
-    } else {
-      TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-    }
-  } else {
-    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data & *src_data;);
-  }
-#endif
-}
-
 void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src)
 {
```
```diff
@@ -74,8 +74,6 @@ TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);
 
 TH_API accreal THTensor_(sumall)(THTensor *t);
 
-TH_API void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value);
-TH_API void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src);
 TH_API void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value);
 TH_API void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src);
 
```
```diff
@@ -259,20 +259,6 @@ struct TensorRShiftConstantOp {
   const T val;
 };
 
-template <typename T>
-struct TensorBitAndConstantOp {
-  TensorBitAndConstantOp(T v) : val(v) {}
-  __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = *in & val;
-  }
-
-  __device__ __forceinline__ void operator()(T* v) {
-    *v &= val;
-  }
-
-  const T val;
-};
-
 template <typename T>
 struct TensorBitOrConstantOp {
   TensorBitOrConstantOp(T v) : val(v) {}
```
```diff
@@ -287,20 +273,6 @@ struct TensorBitOrConstantOp {
   const T val;
 };
 
-template <typename T>
-struct TensorBitXorConstantOp {
-  TensorBitXorConstantOp(T v) : val(v) {}
-  __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = *in ^ val;
-  }
-
-  __device__ __forceinline__ void operator()(T* v) {
-    *v ^= val;
-  }
-
-  const T val;
-};
-
 #include <THC/generic/THCTensorMathPairwise.cu>
 #include <THC/THCGenerateAllTypes.h>
 
```
```diff
@@ -321,19 +321,6 @@ struct TensorRShiftOp<double> {
   }
 };
 
-template <typename T>
-struct TensorBitAndOp {
-  __device__ __forceinline__ void
-  operator()(T* out, T* in) {
-    *out &= *in;
-  }
-
-  __device__ __forceinline__ void
-  operator()(T* out, T* in1, T* in2) {
-    *out = *in1 & *in2;
-  }
-};
-
 template <typename T>
 struct TensorBitOrOp {
   __device__ __forceinline__ void
```
```diff
@@ -347,17 +334,4 @@ struct TensorBitOrOp {
   }
 };
 
-template <typename T>
-struct TensorBitXorOp {
-  __device__ __forceinline__ void
-  operator()(T* out, T* in) {
-    *out ^= *in;
-  }
-
-  __device__ __forceinline__ void
-  operator()(T* out, T* in1, T* in2) {
-    *out = *in1 ^ *in2;
-  }
-};
-
 #endif // THC_TENSORMATH_POINTWISE_CUH
```
```diff
@@ -39,27 +39,6 @@ int THCTensor_(equal)(THCState *state, THCTensor *self_, THCTensor *src_) {
   return THCTensor_(equalImpl)(state, self_, src_);
 }
 
-void THCTensor_(bitand)(THCState* state, THCTensor *self_, THCTensor *src_, scalar_t value)
-{
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-  return THError("bitand only supported for integer type tensors");
-#else
-  if (self_ == src_) {
-    if (!THC_pointwiseApply1<scalar_t>(state, self_, TensorBitAndConstantOp<scalar_t>(value))) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  } else {
-    THCTensor_(resizeAs)(state, self_, src_);
-
-    if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, TensorBitAndConstantOp<scalar_t>(value))) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  }
-
-  THCudaCheck(cudaGetLastError());
-#endif
-}
-
 void THCTensor_(bitor)(THCState* state, THCTensor *self_, THCTensor *src_, scalar_t value)
 {
 #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
```
```diff
@@ -4,7 +4,6 @@
 
 THC_API int THCTensor_(equal)(THCState *state, THCTensor *self, THCTensor *src);
 
-THC_API void THCTensor_(bitand)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value);
 THC_API void THCTensor_(bitor)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value);
 
 #if !defined(THC_REAL_IS_BOOL)
```
```diff
@@ -5,33 +5,6 @@
 #include <ATen/MemoryOverlap.h>
 #include <ATen/NamedTensorUtils.h>
 
-void THCTensor_(cbitand)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
-{
-#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-  return THError("cbitand is only supported for integer type tensors");
-#else
-  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
-  THArgCheck(THCTensor_(nElement)(state, src1) ==
-             THCTensor_(nElement)(state, src2), 3, "sizes do not match");
-
-  if (self_ == src1) {
-    // self /= src2
-    if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src2, TensorBitAndOp<scalar_t>())) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  } else {
-    THCTensor_(resizeAs)(state, self_, src1);
-
-    // self = src1 / src2
-    if (!THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, self_, src1, src2, TensorBitAndOp<scalar_t>())) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  }
-
-  THCudaCheck(cudaGetLastError());
-#endif
-}
-
 void THCTensor_(cbitor)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
 {
 #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
```
```diff
@@ -2,7 +2,6 @@
 #define THC_GENERIC_FILE "THC/generic/THCTensorMathPointwise.h"
 #else
 
-THC_API void THCTensor_(cbitand)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 THC_API void THCTensor_(cbitor)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 
 THC_API void THCTensor_(cmax)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
```
```diff
@@ -193,6 +193,8 @@ view of a storage and defines numeric operations on it.
    .. automethod:: bincount
    .. automethod:: bitwise_not
    .. automethod:: bitwise_not_
+   .. automethod:: bitwise_and
+   .. automethod:: bitwise_and_
    .. automethod:: bitwise_xor
    .. automethod:: bitwise_xor_
    .. automethod:: bmm
```
```diff
@@ -198,6 +198,7 @@ Pointwise Ops
 .. autofunction:: atan
 .. autofunction:: atan2
 .. autofunction:: bitwise_not
+.. autofunction:: bitwise_and
 .. autofunction:: bitwise_xor
 .. autofunction:: ceil
 .. autofunction:: clamp
```
```diff
@@ -6687,6 +6687,36 @@ class TestTorchDeviceType(TestCase):
         with self.assertRaises(RuntimeError):
             a.bitwise_not_()
 
+    def test_bitwise_and(self, device):
+        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
+            b = torch.tensor([2, 1, 3], dtype=dtype, device=device)
+            expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device)
+            b_scalar = 2
+            expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device)
+
+            # standard version
+            self.assertEqual(torch.bitwise_and(a, b), expected_res)
+            self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar)
+
+            # out
+            c = torch.empty(0, dtype=dtype, device=device)
+            torch.bitwise_and(a, b, out=c)
+            self.assertEqual(c, expected_res)
+            torch.bitwise_and(a, b_scalar, out=c)
+            self.assertEqual(c, expected_res_scalar)
+
+            # in-place
+            a1 = a.clone()
+            a1.bitwise_and_(b)
+            self.assertEqual(a1, expected_res)
+            a.bitwise_and_(b_scalar)
+            self.assertEqual(a, expected_res_scalar)
+
+        self.assertEqual(torch.tensor([False, True, False], device=device),
+                         torch.bitwise_and(torch.tensor([True, True, False], device=device),
+                                           torch.tensor([False, True, False], device=device)))
+
     def test_bitwise_xor(self, device):
         for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
             a = torch.tensor([1, -2, 3], dtype=dtype, device=device)
```
```diff
@@ -557,6 +557,20 @@ bitwise_not_() -> Tensor
 In-place version of :meth:`~Tensor.bitwise_not`
 """)
 
+add_docstr_all('bitwise_and',
+               r"""
+bitwise_and() -> Tensor
+
+See :func:`torch.bitwise_and`
+""")
+
+add_docstr_all('bitwise_and_',
+               r"""
+bitwise_and_() -> Tensor
+
+In-place version of :meth:`~Tensor.bitwise_and`
+""")
+
 add_docstr_all('bitwise_xor',
                r"""
 bitwise_xor() -> Tensor
```
```diff
@@ -825,6 +825,26 @@ Example::
     torch.Size([10, 3, 5])
 """.format(**common_args))
 
+add_docstr(torch.bitwise_and,
+           r"""
+bitwise_and(input, other, out=None) -> Tensor
+
+Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of
+integral or Boolean types. For bool tensors, it computes the logical AND.
+
+Args:
+    input: the first input tensor
+    other: the second input tensor
+    {out}
+
+Example:
+
+    >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8))
+    tensor([1, 0, 3], dtype=torch.int8)
+    >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False]))
+    tensor([ False, True, False])
+""".format(**common_args))
+
 add_docstr(torch.bitwise_xor,
            r"""
 bitwise_xor(input, other, out=None) -> Tensor
```