diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index d8accd80c9b..2c975d67309 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -606,72 +606,6 @@ broadcast: other inplace fallback - THTensor* other ]] -[[ - name: _th_lt - cpu_bool: True - cuda_bool: True - cpu_bfloat16: True - variants: - - function - return: argument 0 - options: - - cname: ltValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real other - - cname: ltTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] -[[ - name: _th_lt_byte - cpu_bool: True - cuda_bool: True - cpu_bfloat16: True - variants: - - function - return: argument 0 - options: - - cname: ltValueByte - arguments: - - arg: THByteTensor* result - output: True - - THTensor* self - - real other - - cname: ltTensorByte - arguments: - - arg: THByteTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] -[[ - name: _th_lt_ - cpu_bool: True - cuda_bool: True - cpu_bfloat16: True - return: self - variants: function - options: - - cname: ltValueT - arguments: - - THTensor* self - - THTensor* self - - real other - - cname: ltTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - arg: THTensor* other -]] [[ name: _th_gt cpu_bool: True diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 84f4d7b520e..dc577d6860d 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -3810,13 +3810,7 @@ inline Tensor Tensor::scatter_add(Dimname dim, const Tensor & index, const Tenso #endif inline Tensor & Tensor::lt_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lt_(const_cast<Tensor&>(*this), other); - 
break; - default: - AT_ERROR("lt_ not implemented for ", at::toString(type_set())); - } + return TypeDefault::lt_(const_cast<Tensor&>(*this), other); #else static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Scalar"}).value(); return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, Scalar>( @@ -3825,13 +3819,7 @@ inline Tensor & Tensor::lt_(Scalar other) const { } inline Tensor & Tensor::lt_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lt_(const_cast<Tensor&>(*this), other); - break; - default: - AT_ERROR("lt_ not implemented for ", at::toString(type_set())); - } + return TypeDefault::lt_(const_cast<Tensor&>(*this), other); #else static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Tensor"}).value(); return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, const Tensor &>( diff --git a/aten/src/ATen/detail/ScalarTypeConversions.h b/aten/src/ATen/detail/ScalarTypeConversions.h index 10f53bd1736..5c197b2d440 100644 --- a/aten/src/ATen/detail/ScalarTypeConversions.h +++ b/aten/src/ATen/detail/ScalarTypeConversions.h @@ -9,14 +9,14 @@ namespace at { namespace detail { template <typename T> inline T load(const void* data, ScalarType src_type) { - return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src_type, "load", [&]() { + return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src_type, "load", [&]() { return at::convert<T>(*(scalar_t*)data); }); } template <typename T> inline void store(T value, void* dst, ScalarType dst_type) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, dst_type, "store", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, dst_type, "store", [&]() { *(scalar_t*)dst = at::convert<scalar_t>(value); }); } diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 0c75df68b76..2af31f364fd 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ 
b/aten/src/ATen/native/BinaryOps.cpp @@ -15,6 +15,7 @@ DEFINE_DISPATCH(mul_stub); DEFINE_DISPATCH(div_stub); DEFINE_DISPATCH(atan2_stub); DEFINE_DISPATCH(logical_xor_stub); +DEFINE_DISPATCH(lt_stub); static constexpr char alpha_mismatch_err[] = "For integral input tensors, argument alpha must not be a floating point number."; @@ -140,6 +141,18 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) { return tensor; } +static void check_convert(Scalar scalar, ScalarType scalarType) { + // Validate that it is possible to convert scalar to tensor dtype without overflow + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{ + scalar.to<scalar_t>(); + }); +} + +static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, const Tensor& tensor) { + check_convert(scalar, tensor.scalar_type()); + return wrapped_scalar_tensor(scalar); +} + Tensor add(const Tensor& self, Scalar other, Scalar alpha) { return native::add(self, wrapped_scalar_tensor(other), alpha); } @@ -203,5 +216,69 @@ Tensor& logical_xor_(Tensor& self, const Tensor& other) { return native::logical_xor_out(self, self, other); } +template <typename Stub> +static inline Tensor& comparison_op_impl_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) { + auto iter = TensorIterator::comparison_op(result, self, other, + /*check_mem_overlap=*/true); + stub(iter.device_type(), iter); + return result; +} + +template <typename Stub> +Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) { + TORCH_CHECK(result.scalar_type() == kBool, + "The output tensor of lt must be a bool, but was ", result.scalar_type()); + // Validate that it is possible to convert zero-dim tensor's dtype to other dtype without overflow + if (self.scalar_type() != other.scalar_type()) { + if (self.dim() != 0 && other.dim() == 0) { + check_convert(other.item(), self.scalar_type()); + } else if (self.dim() == 0 && other.dim() != 0) { + check_convert(self.item(), 
other.scalar_type()); + } + } + return native::comparison_op_impl_out(result, self, other, stub); +} + +template <typename Stub> +Tensor comparison_op(const Tensor& self, const Tensor& other, Stub& stub) { + Tensor result = at::empty({0}, self.options().dtype(kBool)); + return native::comparison_op_out(result, self, other, stub); +} + +// To avoid overflow during type promotion we will check that both dtypes of self and other are the same +template <typename Stub> +Tensor& comparison_op_(Tensor& self, const Tensor& other, Stub& stub) { + TORCH_CHECK(self.dtype() == other.dtype(), + "Expected object of scalar type ", self.dtype(), " but got scalar type ", + other.dtype(), " for argument 'other'"); + return native::comparison_op_impl_out(self, self, other, stub); +} + +// Validates that it is possible to convert Scalar other to self's dtype without overflow. +// This behavior is unique to comparison ops; arithmetic operations don't do this. +// In the future, we should reconsider this inconsistency and decide if we want to add the same check to arithmetic ops. 
+template <typename Stub> +Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, Stub& stub) { + return native::comparison_op_out(result, self, wrapped_scalar_tensor_and_check_convert(other, self), stub); +} + +template <typename Stub> +Tensor comparison_op(const Tensor& self, Scalar other, Stub& stub) { + Tensor result = at::empty({0}, self.options().dtype(kBool)); + return native::comparison_op_out(result, self, other, stub); +} + +template <typename Stub> +Tensor& comparison_op_(Tensor& self, Scalar other, Stub& stub) { + return native::comparison_op_impl_out(self, self, wrapped_scalar_tensor_and_check_convert(other, self), stub); +} + +Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, lt_stub); } +Tensor lt(const Tensor& self, const Tensor& other) { return comparison_op(self, other, lt_stub); } +Tensor& lt_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, lt_stub); } +Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, lt_stub); } +Tensor lt(const Tensor& self, Scalar other) { return comparison_op(self, other, lt_stub); } +Tensor& lt_(Tensor& self, Scalar other) { return comparison_op_(self, other, lt_stub); } + } } // namespace at diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 41d1a8f2a04..a8c23fe97ec 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -16,5 +16,6 @@ DECLARE_DISPATCH(binary_fn, mul_stub); DECLARE_DISPATCH(binary_fn, div_stub); DECLARE_DISPATCH(binary_fn, atan2_stub); DECLARE_DISPATCH(binary_fn, logical_xor_stub); +DECLARE_DISPATCH(binary_fn, lt_stub); }} // namespace at::native diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index 5f20195b47a..93226879c3b 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -94,26 +94,6 @@ Tensor 
gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s return legacy::cpu::_th_gather(self, dim, index); } -Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { - if (result.dtype() == at::ScalarType::Byte) { - AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ - "please use 'out' parameter with dtype torch.bool instead."); - return legacy::cpu::_th_lt_byte_out(result, self, other); - } else { - return legacy::cpu::_th_lt_out(result, self, other); - } -} - -Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { - if (result.dtype() == at::ScalarType::Byte) { - AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ - "please use 'out' parameter with dtype torch.bool instead."); - return legacy::cpu::_th_lt_byte_out(result, self, value); - } else { - return legacy::cpu::_th_lt_out(result, self, value); - } -} - Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { if (result.dtype() == at::ScalarType::Byte) { AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp index b235c058d68..444bafc965f 100644 --- a/aten/src/ATen/native/TensorIterator.cpp +++ b/aten/src/ATen/native/TensorIterator.cpp @@ -145,7 +145,6 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype) void TensorIterator::compute_types() { bool missing_dtypes = false; bool missing_output_dtypes = false; - bool has_read_write_op = false; ScalarType common_dtype = dtype(); for (auto& op : operands_) { if (!op.tensor.defined() && !op.is_type_defined()) { @@ -154,14 +153,10 @@ void TensorIterator::compute_types() { missing_output_dtypes = true; } } - if (op.is_read_write) { - has_read_write_op = true; - } } if 
(compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) { TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs"); - TORCH_CHECK(!has_read_write_op, "unable to compute and promote common dtype based only on inputs if input is same as output"); } bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE); @@ -611,6 +606,19 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, return iter; } +TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, + const Tensor& b, bool check_mem_overlap) { + auto iter = TensorIterator(); + iter.set_check_mem_overlap(check_mem_overlap); + iter.add_output(out); + iter.add_input(a); + iter.add_input(b); + iter.allow_cpu_scalars_ = true; + iter.compute_common_dtype_only_for_inputs(); + iter.build(); + return iter; +} + TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a, bool check_mem_overlap) { auto iter = TensorIterator(); diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h index 47d75a1533c..82259e8ffff 100644 --- a/aten/src/ATen/native/TensorIterator.h +++ b/aten/src/ATen/native/TensorIterator.h @@ -156,6 +156,8 @@ struct CAFFE2_API TensorIterator { static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b, bool check_mem_overlap = false); + static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b, + bool check_mem_overlap = false); static TensorIterator unary_op(Tensor& out, const Tensor& a, bool check_mem_overlap = false); static TensorIterator nullary_op(Tensor& out); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 326bdcbf9ec..233e67b8f19 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -103,6 +103,24 @@ void 
logical_xor_kernel(TensorIterator& iter) { }); } +void lt_kernel(TensorIterator& iter) { + if (iter.dtype() == ScalarType::Bool) { + AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, iter.input_dtype(), "lt_cpu", [&]() { + cpu_kernel(iter, + [=](scalar_t a, scalar_t b) -> bool { + return a < b; + }); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "lt_cpu", [&]() { + cpu_kernel(iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return a < b; + }); + }); + } +} + } // anonymous namespace @@ -112,5 +130,6 @@ REGISTER_DISPATCH(mul_stub, &mul_kernel); REGISTER_DISPATCH(div_stub, &div_kernel); REGISTER_DISPATCH(atan2_stub, &atan2_kernel); REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel); +REGISTER_DISPATCH(lt_stub, &lt_kernel); }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu index 4b6cba748f6..38a4f3654a3 100644 --- a/aten/src/ATen/native/cuda/BinaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryOpsKernel.cu @@ -76,11 +76,28 @@ void logical_xor_kernel_cuda(TensorIterator& iter) { }); } +void lt_kernel_cuda(TensorIterator& iter) { + if (iter.dtype() == ScalarType::Bool) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return a < b; + }); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "lt_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a < b; + }); + }); + } +} + REGISTER_DISPATCH(add_stub, &add_kernel_cuda); REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda); REGISTER_DISPATCH(div_stub, &div_kernel_cuda); REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda); REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda); REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda); +REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda); }} // namespace at::native diff --git 
a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index 892f0d8c88a..3e8f9d3d872 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -90,26 +90,6 @@ Tensor gather_cuda(const Tensor & self, int64_t dim, const Tensor & index, bool return legacy::cuda::_th_gather(self, dim, index); } -Tensor & lt_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { - if (result.dtype() == at::ScalarType::Byte) { - AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ - "please use 'out' parameter with dtype torch.bool instead."); - return legacy::cuda::_th_lt_byte_out(result, self, other); - } else { - return legacy::cuda::_th_lt_out(result, self, other); - } -} - -Tensor & lt_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { - if (result.dtype() == at::ScalarType::Byte) { - AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ - "please use 'out' parameter with dtype torch.bool instead."); - return legacy::cuda::_th_lt_byte_out(result, self, value); - } else { - return legacy::cuda::_th_lt_out(result, self, value); - } -} - Tensor & le_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { if (result.dtype() == at::ScalarType::Byte) { AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8ad5064e320..7b7156d7ab0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4008,16 +4008,10 @@ - func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) use_c10_dispatcher: unboxed_only variants: method - dispatch: - CPU: legacy::cpu::_th_lt_ - CUDA: legacy::cuda::_th_lt_ - func: lt_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!) use_c10_dispatcher: unboxed_only variants: method - dispatch: - CPU: legacy::cpu::_th_lt_ - CUDA: legacy::cuda::_th_lt_ - func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) use_c10_dispatcher: unboxed_only @@ -4614,30 +4608,30 @@ - func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: lt_scalar_out_cpu - CUDA: lt_scalar_out_cuda + CPU: lt_out + CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu - func: lt.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_lt - CUDA: legacy::cuda::_th_lt + CPU: lt + CUDA: lt QuantizedCPU: lt_quantized_cpu - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: lt_out_cpu - CUDA: lt_out_cuda + CPU: lt_out + CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu - func: lt.Tensor(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_lt - CUDA: legacy::cuda::_th_lt + CPU: lt + CUDA: lt QuantizedCPU: lt_quantized_cpu - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) 
diff --git a/aten/src/ATen/test/tensor_iterator_test.cpp b/aten/src/ATen/test/tensor_iterator_test.cpp index 0ba5da05810..ef4d17a2f8b 100644 --- a/aten/src/ATen/test/tensor_iterator_test.cpp +++ b/aten/src/ATen/test/tensor_iterator_test.cpp @@ -45,9 +45,9 @@ TEST(TensorIteratorTest, MixedDevices) { Tensor random_tensor_for_type(at::ScalarType scalar_type) { if (at::isFloatingType(scalar_type)) { - return at::randn({5, 5}, kCPU); + return at::randn({5, 5}, at::device(kCPU).dtype(scalar_type)); } else { - return at::randint(1, 10, {5, 5}, kCPU); + return at::randint(1, 10, {5, 5}, at::device(kCPU).dtype(scalar_type)); } } @@ -89,9 +89,30 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) { ASSERT_ANY_THROW(out.equal(expected)); \ } +// An alternative way to calculate a < b is (b - a).clamp_min(0).to(kBool). +// To prevent overflow in the subtraction (b - a) for unsigned types (uint8, bool), +// we convert both inputs to int first. +#define COMPARISON_TEST_ITER_FOR_TYPE(ctype,name) \ +TEST(TensorIteratorTest, ComparisonLoopBinary_##name) { \ + auto in1 = random_tensor_for_type(k##name); \ + auto in2 = random_tensor_for_type(k##name); \ + Tensor out = at::empty({0}, in1.options().dtype(kBool)); \ + Tensor diff; \ + if (k##name == kByte || k##name == kBool) { \ + diff = in2.to(kInt).sub(in1.to(kInt)); \ + } else { \ + diff = in2.sub(in1); \ + } \ + auto expected = diff.clamp_min(0).to(kBool); \ + auto iter = TensorIterator::comparison_op(out, in1, in2, true); \ + at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> bool { return a < b; }); \ + EXPECT_TRUE(out.equal(expected)); \ +} + AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE) AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE) AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE) +AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE) TEST(TensorIteratorTest, SerialLoopSingleThread) { std::thread::id thread_id = std::this_thread::get_id(); @@ -142,16 +163,6 @@ TEST(TensorIteratorTest, 
DoNotComputeCommonDTypeInputOnly) { EXPECT_TRUE(iter.dtype(2) == at::kDouble); } -TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfInputSameAsOutput) { - Tensor inout = at::ones({1, 1}, at::dtype(at::kFloat)); - auto iter = at::TensorIterator(); - iter.add_output(inout); - iter.add_input(inout); - iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); - iter.compute_common_dtype_only_for_inputs(); - ASSERT_ANY_THROW(iter.build()); -} - TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) { Tensor out; auto iter = at::TensorIterator(); diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 174e6df519e..ad343674951 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -35,8 +35,7 @@ TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \ *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ } - -TENSOR_IMPLEMENT_LOGICAL(lt,<) + TENSOR_IMPLEMENT_LOGICAL(gt,>) TENSOR_IMPLEMENT_LOGICAL(le,<=) TENSOR_IMPLEMENT_LOGICAL(ge,>=) @@ -57,7 +56,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=) *r__data = (*ta_data OP *tb_data) ? 
1 : 0;); \ } \ -TENSOR_IMPLEMENT_LOGICAL_BYTE(lt,<) TENSOR_IMPLEMENT_LOGICAL_BYTE(gt,>) TENSOR_IMPLEMENT_LOGICAL_BYTE(le,<=) TENSOR_IMPLEMENT_LOGICAL_BYTE(ge,>=) diff --git a/aten/src/THC/THCTensorMathCompareT.cuh b/aten/src/THC/THCTensorMathCompareT.cuh index 936c825d655..f31ad31aa84 100644 --- a/aten/src/THC/THCTensorMathCompareT.cuh +++ b/aten/src/THC/THCTensorMathCompareT.cuh @@ -8,13 +8,6 @@ #include #include -template -struct TensorLTOp { - __device__ inline void operator()(TOut* out, T* a, T* b) { - *out = ScalarConvert::to(THCNumerics::lt(*a, *b)); - } -}; - template struct TensorGTOp { __device__ inline void operator()(TOut* out, T* a, T* b) { diff --git a/aten/src/THC/generic/THCTensorMathCompareT.cu b/aten/src/THC/generic/THCTensorMathCompareT.cu index 106b4d70aea..0183429de33 100644 --- a/aten/src/THC/generic/THCTensorMathCompareT.cu +++ b/aten/src/THC/generic/THCTensorMathCompareT.cu @@ -2,14 +2,6 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.cu" #else -void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, - TensorLTOp()); -} - void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); @@ -50,13 +42,6 @@ void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *s bool>()); } -void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, - TensorLTOp()); -} void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2) { @@ -98,13 +83,6 @@ void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, T scalar_t>()); } -void 
THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, - TensorLTOp()); -} void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) { diff --git a/aten/src/THC/generic/THCTensorMathCompareT.h b/aten/src/THC/generic/THCTensorMathCompareT.h index d4387ceb2a2..7e6ebfe59cb 100644 --- a/aten/src/THC/generic/THCTensorMathCompareT.h +++ b/aten/src/THC/generic/THCTensorMathCompareT.h @@ -2,21 +2,18 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.h" #else -THC_API void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(leTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(geTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(eqTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(leTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void 
THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(leTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(geTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); diff --git a/test/test_torch.py b/test/test_torch.py index f3f12f5a913..bba04604e9e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6205,6 +6205,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], for idx in iter_indices(x): self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) + def test_comparison_ops_must_take_bool_output(self): + with self.assertRaisesRegex(RuntimeError, 'The output tensor of lt must be a bool'): + torch.lt(torch.tensor([True]), torch.tensor([False]), out=torch.empty(1, dtype=torch.uint8)) + + def test_inplace_comparison_ops_require_inputs_have_same_dtype(self): + with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): + torch.tensor([1], dtype=torch.int).lt_(torch.tensor([2], dtype=torch.long)) + + def test_comparison_ops_check_for_scalar_overflow(self): + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + torch.tensor([1 << 5], dtype=torch.uint8) < (1 << 20) + + def test_comparison_ops_check_for_zerodim_tensor_overflow(self): + with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + torch.tensor([1 << 5], dtype=torch.uint8) < torch.tensor(1 << 20, dtype=torch.int32) + torch.tensor(1 << 40, dtype=torch.int64) < torch.tensor([1 << 30], dtype=torch.int32) + def test_bitwise_ops(self): x = torch.randn(5, 5).gt(0) y = torch.randn(5, 5).gt(0) @@ -6740,11 +6757,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], e1.fill_diagonal_(v, wrap=True) self.assertEqual(e1, e2) - def test_function_unwrap_message(self): - 
self.assertRaisesRegex(RuntimeError, ' call to _th_lt', - lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double)) - - # Functions to test negative dimension wrapping METHOD = 1 INPLACE_METHOD = 2 @@ -11145,10 +11157,6 @@ class TestTorchDeviceType(TestCase): byteRes = torch.empty_like(x, device=device).byte() boolRes = torch.empty_like(x, device=device).bool() - torch.lt(x, b, out=byteRes) - torch.lt(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) - torch.le(x, b, out=byteRes) torch.le(x, b, out=boolRes) self.assertEqual(byteRes.bool(), boolRes) @@ -11169,7 +11177,7 @@ class TestTorchDeviceType(TestCase): torch.ne(x, b, out=boolRes) self.assertEqual(byteRes.bool(), boolRes) - self.assertEquals(len(warningsCount), 6) + self.assertEquals(len(warningsCount), 5) # Bool Tensor x = torch.tensor([True, False, True, False], device=device) diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 7c7de388991..233c9a34c09 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -227,6 +227,87 @@ class TestTypePromotion(TestCase): self.assertEqual(torch.result_type(torch.tensor([1., 1.], dtype=torch.float), 1.), torch.float) self.assertEqual(torch.result_type(torch.tensor(1., dtype=torch.float), torch.tensor(1, dtype=torch.double)), torch.double) + def test_comparison_ops_with_type_promotion(self): + value_for_type = { + torch.uint8: (1 << 5), + torch.int8: (1 << 5), + torch.int16: (1 << 10), + torch.int32: (1 << 20), + torch.int64: (1 << 35), + torch.float16: (1 << 10), + torch.float32: (1 << 20), + torch.float64: (1 << 35) + } + comparison_ops = [ + dict( + name="lt", + out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(1, dtype=torch.bool, device=d)), + ret_op=lambda x, y: torch.lt(x, y), + compare_op=lambda x, y: x < y, + ), + ] + device = self.device + for op in comparison_ops: + for dt1 in torch.testing.get_all_math_dtypes(device): + for dt2 in torch.testing.get_all_math_dtypes(device): + 
val1 = value_for_type[dt1] + val2 = value_for_type[dt2] + t1 = torch.tensor([val1], dtype=dt1, device=device) + t2 = torch.tensor([val2], dtype=dt2, device=device) + expected = torch.tensor([op["compare_op"](val1, val2)], dtype=torch.bool) + + out_res = op["out_op"](t1, t2, device) + self.assertEqual(out_res, expected) + self.assertTrue(out_res.dtype == torch.bool) + self.assertTrue(t1.dtype == dt1) + self.assertTrue(t2.dtype == dt2) + + out_res = op["ret_op"](t1, t2) + self.assertEqual(out_res, expected) + self.assertTrue(out_res.dtype == torch.bool) + self.assertTrue(t1.dtype == dt1) + self.assertTrue(t2.dtype == dt2) + + # test that comparing a zero dim tensor with another zero dim tensor has type promotion behavior + t1 = torch.tensor(val1, dtype=dt1, device=device) + t2 = torch.tensor(val2, dtype=dt2, device=device) + expected = torch.tensor(op["compare_op"](val1, val2), dtype=torch.bool) + + out_res = op["out_op"](t1, t2, device) + self.assertEqual(out_res, expected) + self.assertTrue(out_res.dtype == torch.bool) + self.assertTrue(t1.dtype == dt1) + self.assertTrue(t2.dtype == dt2) + + out_res = op["ret_op"](t1, t2) + self.assertEqual(out_res, expected) + self.assertTrue(out_res.dtype == torch.bool) + self.assertTrue(t1.dtype == dt1) + self.assertTrue(t2.dtype == dt2) + + def test_lt_with_type_promotion(self): + for dt in torch.testing.get_all_math_dtypes(self.device): + x = torch.tensor([0], dtype=dt, device=self.device) + expected = torch.tensor([True], dtype=torch.bool, device=self.device) + + actual = x < 0.5 + self.assertEqual(actual, expected) + self.assertTrue(actual.dtype == torch.bool) + + actual = x < torch.tensor(0.5) + self.assertEqual(actual, expected) + self.assertTrue(actual.dtype == torch.bool) + + x = torch.tensor(0, dtype=dt, device=self.device) + expected = torch.tensor(True, dtype=torch.bool, device=self.device) + actual = x < 0.5 + self.assertEqual(actual, expected) + self.assertTrue(actual.dtype == torch.bool) + + actual = x < 
torch.tensor(0.5) + self.assertEqual(actual, expected) + self.assertTrue(actual.dtype == torch.bool) + @unittest.skipIf(not torch.cuda.is_available(), "no cuda") class TestTypePromotionCuda(TestTypePromotion): def setUp(self):