Migrate lt and lt_ from TH to ATen (#25998)

Summary:
https://github.com/pytorch/pytorch/issues/24593
https://github.com/pytorch/pytorch/issues/24727

**torch.lt(Tensor a, Tensor b)**
will compute the common dtype (the higher of the two) based on the inputs and then compare the values. The result is a Bool tensor.
```
>>> x = torch.tensor([0], dtype=torch.int)
>>> y = torch.tensor([0.5], dtype=torch.double)
>>> x < y
tensor([True])
```
Previously it was impossible to compare two tensors with different dtypes.
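For illustration, the comparison happens in the promoted dtype while the returned tensor is always bool (expected output shown; `torch.result_type` is used here only to make the promoted dtype explicit):
```
>>> torch.lt(x, y)
tensor([True])
>>> torch.lt(x, y).dtype
torch.bool
>>> torch.result_type(x, y)
torch.float64
```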

**torch.lt(Tensor a, Tensor b, out=c)**
will compute the common dtype (the higher of the two) based on the inputs and then compare the values. The result can be written only to a Bool tensor.
```
>>> x = torch.tensor([0], dtype=torch.int)
>>> y = torch.tensor([0.5], dtype=torch.double)
>>> z = torch.empty([1], dtype=torch.bool)
>>> torch.lt(x, y, out=z)
tensor([True])
```
Previously it was impossible to compare two tensors with different dtypes. Also, the result dtype could previously be either Bool or Byte (deprecated); now only a Bool result is accepted.
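For example, a deprecated uint8 `out` tensor is now rejected (illustrative snippet; the error text is abridged):
```
>>> x = torch.tensor([1], dtype=torch.int)
>>> y = torch.tensor([2], dtype=torch.int)
>>> torch.lt(x, y, out=torch.empty(1, dtype=torch.uint8))
RuntimeError: The output tensor of lt must be a bool, but was Byte
```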

**a.lt_(Tensor b)**
Expects a and b to have the same dtype; otherwise an overflow is possible (example: if 'a' is uint8 and 'b' is float32, 'a' would be promoted to float32, the result would also be float32, and casting it back to uint8 could overflow). Will not compute a common dtype. The result has the dtype of a.
```
>>> x = torch.tensor([0], dtype=torch.double)
>>> y = torch.tensor([0.5], dtype=torch.double)
>>> x.lt_(y)
tensor([1.], dtype=torch.float64)
```
Works similarly to the previous implementation.
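For example, mismatched dtypes now fail fast instead of being promoted (illustrative snippet; the error text is abridged):
```
>>> torch.tensor([1], dtype=torch.int).lt_(torch.tensor([2], dtype=torch.long))
RuntimeError: Expected object of scalar type Int but got scalar type Long for argument 'other'
```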

**torch.lt(Tensor a, Scalar b)**
will check that there is no overflow when converting b to a's dtype, then compute the common dtype and compare.
```
>>> x = torch.tensor([0], dtype=torch.double)
>>> x < 0.5
tensor([True])

>>> x = torch.tensor([0], dtype=torch.int)
>>> x < 0.5
tensor([True])
```
Fixes https://github.com/pytorch/pytorch/issues/22301.
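For example, a scalar that cannot be represented in a's dtype is rejected up front (illustrative snippet; the error text is abridged):
```
>>> torch.tensor([1], dtype=torch.uint8) < (1 << 20)
RuntimeError: value cannot be converted to type uint8_t without overflow
```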

**torch.lt(Tensor a, Scalar b, out=c)**
will check that there is no overflow when converting b to a's dtype, then compute the common dtype and compare. The result can be written only to a Bool tensor.
```
>>> x = torch.tensor([0], dtype=torch.double)
>>> z = torch.empty([1], dtype=torch.bool)
>>> torch.lt(x, 0.5, out=z)
tensor([True])
```
Previously the result dtype could be Bool or Byte (deprecated); now only a Bool result is accepted. The rest works similarly to the previous implementation.

**a.lt_(Scalar b)**
will check that there is no overflow when converting b to a's dtype, then compute the common dtype and compare. The result has the dtype of a.
```
>>> x = torch.tensor([0], dtype=torch.int)
>>> x.lt_(1)
tensor([1], dtype=torch.int32)
>>> x = torch.tensor([0], dtype=torch.int)
>>> x.lt_(1.0)
tensor([1], dtype=torch.int32)
```
Works similarly to the previous implementation.
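The same scalar overflow check applies to the in-place variant (illustrative snippet; the error text is abridged):
```
>>> x = torch.tensor([1], dtype=torch.uint8)
>>> x.lt_(1 << 20)
RuntimeError: value cannot be converted to type uint8_t without overflow
```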
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25998

Differential Revision: D17431853

Pulled By: ifedan

fbshipit-source-id: b5effc6a5d9b32da379395b32abc628b604faaf7
Igor Fedan 2019-09-26 16:03:43 -07:00 committed by Facebook Github Bot
parent 9dd8a129de
commit f99bc714c7
19 changed files with 264 additions and 198 deletions

@@ -606,72 +606,6 @@
broadcast: other inplace fallback
- THTensor* other
]]
[[
name: _th_lt
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
variants:
- function
return: argument 0
options:
- cname: ltValue
arguments:
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_lt_byte
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
variants:
- function
return: argument 0
options:
- cname: ltValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensorByte
arguments:
- arg: THByteTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_lt_
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
return: self
variants: function
options:
- cname: ltValueT
arguments:
- THTensor* self
- THTensor* self
- real other
- cname: ltTensorT
arguments:
- THTensor* self
- arg: THTensor* self
broadcast: other inplace fallback
- arg: THTensor* other
]]
[[
name: _th_gt
cpu_bool: True

@@ -3810,13 +3810,7 @@ inline Tensor Tensor::scatter_add(Dimname dim, const Tensor & index, const Tenso
#endif
inline Tensor & Tensor::lt_(Scalar other) const {
#ifdef USE_STATIC_DISPATCH
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
case Backend::CPU:
return CPUType::lt_(const_cast<Tensor&>(*this), other);
break;
default:
AT_ERROR("lt_ not implemented for ", at::toString(type_set()));
}
return TypeDefault::lt_(const_cast<Tensor&>(*this), other);
#else
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Scalar"}).value();
return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, Scalar>(
@@ -3825,13 +3819,7 @@ inline Tensor & Tensor::lt_(Scalar other) const {
}
inline Tensor & Tensor::lt_(const Tensor & other) const {
#ifdef USE_STATIC_DISPATCH
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
case Backend::CPU:
return CPUType::lt_(const_cast<Tensor&>(*this), other);
break;
default:
AT_ERROR("lt_ not implemented for ", at::toString(type_set()));
}
return TypeDefault::lt_(const_cast<Tensor&>(*this), other);
#else
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Tensor"}).value();
return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, const Tensor &>(

@@ -9,14 +9,14 @@ namespace at { namespace detail {
template <typename T>
inline T load(const void* data, ScalarType src_type) {
return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src_type, "load", [&]() {
return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src_type, "load", [&]() {
return at::convert<T>(*(scalar_t*)data);
});
}
template <typename T>
inline void store(T value, void* dst, ScalarType dst_type) {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, dst_type, "store", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, dst_type, "store", [&]() {
*(scalar_t*)dst = at::convert<scalar_t>(value);
});
}

@@ -15,6 +15,7 @@ DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(logical_xor_stub);
DEFINE_DISPATCH(lt_stub);
static constexpr char alpha_mismatch_err[] =
"For integral input tensors, argument alpha must not be a floating point number.";
@@ -140,6 +141,18 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) {
return tensor;
}
static void check_convert(Scalar scalar, ScalarType scalarType) {
// Validate that is possible to convert scalar to tensor dtype without overflow
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
scalar.to<scalar_t>();
});
}
static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, Tensor tensor) {
check_convert(scalar, tensor.scalar_type());
return wrapped_scalar_tensor(scalar);
}
Tensor add(const Tensor& self, Scalar other, Scalar alpha) {
return native::add(self, wrapped_scalar_tensor(other), alpha);
}
@@ -203,5 +216,69 @@ Tensor& logical_xor_(Tensor& self, const Tensor& other) {
return native::logical_xor_out(self, self, other);
}
template <typename Stub>
static inline Tensor& comparison_op_impl_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
auto iter = TensorIterator::comparison_op(result, self, other,
/*check_mem_overlap=*/true);
stub(iter.device_type(), iter);
return result;
}
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
TORCH_CHECK(result.scalar_type() == kBool,
"The output tensor of lt must be a bool, but was ", result.scalar_type());
// Validate that is possible to convert zero-dim tensor's dtype to other dtype without overflow
if (self.scalar_type() != other.scalar_type()) {
if (self.dim() != 0 && other.dim() == 0) {
check_convert(other.item(), self.scalar_type());
} else if (self.dim() == 0 && other.dim() != 0) {
check_convert(self.item(), other.scalar_type());
}
}
return native::comparison_op_impl_out(result, self, other, stub);
}
template <typename Stub>
Tensor comparison_op(const Tensor& self, const Tensor& other, Stub& stub) {
Tensor result = at::empty({0}, self.options().dtype(kBool));
return native::comparison_op_out(result, self, other, stub);
}
// To avoid overflow during type promotion we will check that both dtypes of self and other are same
template <typename Stub>
Tensor& comparison_op_(Tensor& self, const Tensor& other, Stub& stub) {
TORCH_CHECK(self.dtype() == other.dtype(),
"Expected object of scalar type ", self.dtype(), " but got scalar type ",
other.dtype(), " for argument 'other'");
return native::comparison_op_impl_out(self, self, other, stub);
}
// validates that is possible to convert Scalar other to self's dtype without overflow.
// This behavior is unique to comparison ops; arithmetic operations don't do this.
// In the future, we should reconsider this inconsistency and decide if we want to add the same check to arithmetic ops.
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, Stub& stub) {
return native::comparison_op_out(result, self, wrapped_scalar_tensor_and_check_convert(other, self), stub);
}
template <typename Stub>
Tensor comparison_op(const Tensor& self, Scalar other, Stub& stub) {
Tensor result = at::empty({0}, self.options().dtype(kBool));
return native::comparison_op_out(result, self, other, stub);
}
template <typename Stub>
Tensor& comparison_op_(Tensor& self, Scalar other, Stub& stub) {
return native::comparison_op_impl_out(self, self, wrapped_scalar_tensor_and_check_convert(other, self), stub);
}
Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, lt_stub); }
Tensor lt(const Tensor& self, const Tensor& other) { return comparison_op(self, other, lt_stub); }
Tensor& lt_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, lt_stub); }
Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, lt_stub); }
Tensor lt(const Tensor& self, Scalar other) { return comparison_op(self, other, lt_stub); }
Tensor& lt_(Tensor& self, Scalar other) { return comparison_op_(self, other, lt_stub); }
}
} // namespace at

@@ -16,5 +16,6 @@ DECLARE_DISPATCH(binary_fn, mul_stub);
DECLARE_DISPATCH(binary_fn, div_stub);
DECLARE_DISPATCH(binary_fn, atan2_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
DECLARE_DISPATCH(binary_fn, lt_stub);
}} // namespace at::native

@@ -94,26 +94,6 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
return legacy::cpu::_th_gather(self, dim, index);
}
Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, other);
} else {
return legacy::cpu::_th_lt_out(result, self, other);
}
}
Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, value);
} else {
return legacy::cpu::_th_lt_out(result, self, value);
}
}
Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \

@@ -145,7 +145,6 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)
void TensorIterator::compute_types() {
bool missing_dtypes = false;
bool missing_output_dtypes = false;
bool has_read_write_op = false;
ScalarType common_dtype = dtype();
for (auto& op : operands_) {
if (!op.tensor.defined() && !op.is_type_defined()) {
@@ -154,14 +153,10 @@ void TensorIterator::compute_types() {
missing_output_dtypes = true;
}
}
if (op.is_read_write) {
has_read_write_op = true;
}
}
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) {
TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs");
TORCH_CHECK(!has_read_write_op, "unable to compute and promote common dtype based only on inputs if input is same as output");
}
bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE);
@@ -611,6 +606,19 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
return iter;
}
TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
const Tensor& b, bool check_mem_overlap) {
auto iter = TensorIterator();
iter.set_check_mem_overlap(check_mem_overlap);
iter.add_output(out);
iter.add_input(a);
iter.add_input(b);
iter.allow_cpu_scalars_ = true;
iter.compute_common_dtype_only_for_inputs();
iter.build();
return iter;
}
TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
bool check_mem_overlap) {
auto iter = TensorIterator();

@@ -156,6 +156,8 @@ struct CAFFE2_API TensorIterator {
static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b,
bool check_mem_overlap = false);
static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b,
bool check_mem_overlap = false);
static TensorIterator unary_op(Tensor& out, const Tensor& a,
bool check_mem_overlap = false);
static TensorIterator nullary_op(Tensor& out);

@@ -103,6 +103,24 @@ void logical_xor_kernel(TensorIterator& iter) {
});
}
void lt_kernel(TensorIterator& iter) {
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, iter.input_dtype(), "lt_cpu", [&]() {
cpu_kernel(iter,
[=](scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "lt_cpu", [&]() {
cpu_kernel(iter,
[=](scalar_t a, scalar_t b) -> scalar_t {
return a < b;
});
});
}
}
} // anonymous namespace
@@ -112,5 +130,6 @@ REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_stub, &div_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
REGISTER_DISPATCH(lt_stub, &lt_kernel);
}} // namespace at::native

@@ -76,11 +76,28 @@ void logical_xor_kernel_cuda(TensorIterator& iter) {
});
}
void lt_kernel_cuda(TensorIterator& iter) {
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cpu", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "lt_cpu", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a < b;
});
});
}
}
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda);
REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);
}} // namespace at::native

@@ -90,26 +90,6 @@ Tensor gather_cuda(const Tensor & self, int64_t dim, const Tensor & index, bool
return legacy::cuda::_th_gather(self, dim, index);
}
Tensor & lt_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cuda::_th_lt_byte_out(result, self, other);
} else {
return legacy::cuda::_th_lt_out(result, self, other);
}
}
Tensor & lt_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cuda::_th_lt_byte_out(result, self, value);
} else {
return legacy::cuda::_th_lt_out(result, self, value);
}
}
Tensor & le_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \

@@ -4008,16 +4008,10 @@
- func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
use_c10_dispatcher: unboxed_only
variants: method
dispatch:
CPU: legacy::cpu::_th_lt_
CUDA: legacy::cuda::_th_lt_
- func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
use_c10_dispatcher: unboxed_only
variants: method
dispatch:
CPU: legacy::cpu::_th_lt_
CUDA: legacy::cuda::_th_lt_
- func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
use_c10_dispatcher: unboxed_only
@@ -4614,30 +4608,30 @@
- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: lt_scalar_out_cpu
CUDA: lt_scalar_out_cuda
CPU: lt_out
CUDA: lt_out
QuantizedCPU: lt_out_quantized_cpu
- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU: legacy::cpu::_th_lt
CUDA: legacy::cuda::_th_lt
CPU: lt
CUDA: lt
QuantizedCPU: lt_quantized_cpu
- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: lt_out_cpu
CUDA: lt_out_cuda
CPU: lt_out
CUDA: lt_out
QuantizedCPU: lt_out_quantized_cpu
- func: lt.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU: legacy::cpu::_th_lt
CUDA: legacy::cuda::_th_lt
CPU: lt
CUDA: lt
QuantizedCPU: lt_quantized_cpu
- func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)

@@ -45,9 +45,9 @@ TEST(TensorIteratorTest, MixedDevices) {
Tensor random_tensor_for_type(at::ScalarType scalar_type) {
if (at::isFloatingType(scalar_type)) {
return at::randn({5, 5}, kCPU);
return at::randn({5, 5}, at::device(kCPU).dtype(scalar_type));
} else {
return at::randint(1, 10, {5, 5}, kCPU);
return at::randint(1, 10, {5, 5}, at::device(kCPU).dtype(scalar_type));
}
}
@@ -89,9 +89,30 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) {
ASSERT_ANY_THROW(out.equal(expected)); \
}
// The alternative way to calculate a < b is (b - a).clamp(0).toBool()
// To prevent an overflow in subtraction (b - a) for unsigned types(unit, bool)
// we will convert in to int first
#define COMPARISON_TEST_ITER_FOR_TYPE(ctype,name) \
TEST(TensorIteratorTest, ComparisonLoopBinary_##name) { \
auto in1 = random_tensor_for_type(k##name); \
auto in2 = random_tensor_for_type(k##name); \
Tensor out = at::empty({0}, in1.options().dtype(kBool)); \
Tensor diff; \
if (k##name == kByte || k##name == kBool) { \
diff = in2.to(kInt).sub(in1.to(kInt)); \
} else { \
diff = in2.sub(in1); \
} \
auto expected = diff.clamp_min(0).to(kBool); \
auto iter = TensorIterator::comparison_op(out, in1, in2, true); \
at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> bool { return a < b; }); \
EXPECT_TRUE(out.equal(expected)); \
}
AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE)
AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE)
TEST(TensorIteratorTest, SerialLoopSingleThread) {
std::thread::id thread_id = std::this_thread::get_id();
@@ -142,16 +163,6 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeInputOnly) {
EXPECT_TRUE(iter.dtype(2) == at::kDouble);
}
TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfInputSameAsOutput) {
Tensor inout = at::ones({1, 1}, at::dtype(at::kFloat));
auto iter = at::TensorIterator();
iter.add_output(inout);
iter.add_input(inout);
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.compute_common_dtype_only_for_inputs();
ASSERT_ANY_THROW(iter.build());
}
TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
Tensor out;
auto iter = at::TensorIterator();

@@ -36,7 +36,6 @@
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
}
TENSOR_IMPLEMENT_LOGICAL(lt,<)
TENSOR_IMPLEMENT_LOGICAL(gt,>)
TENSOR_IMPLEMENT_LOGICAL(le,<=)
TENSOR_IMPLEMENT_LOGICAL(ge,>=)
@@ -57,7 +56,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
} \
TENSOR_IMPLEMENT_LOGICAL_BYTE(lt,<)
TENSOR_IMPLEMENT_LOGICAL_BYTE(gt,>)
TENSOR_IMPLEMENT_LOGICAL_BYTE(le,<=)
TENSOR_IMPLEMENT_LOGICAL_BYTE(ge,>=)

@@ -8,13 +8,6 @@
#include <THC/THCNumerics.cuh>
#include <THC/THCReduce.cuh>
template <typename T, typename TOut>
struct TensorLTOp {
__device__ inline void operator()(TOut* out, T* a, T* b) {
*out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::lt(*a, *b));
}
};
template <typename T, typename TOut>
struct TensorGTOp {
__device__ inline void operator()(TOut* out, T* a, T* b) {

@@ -2,14 +2,6 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.cu"
#else
void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
TensorLTOp<scalar_t,
bool>());
}
void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
@@ -50,13 +42,6 @@ void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *s
bool>());
}
void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
THC_logicalTensor<scalar_t, scalar_t>(state, self_, src1, src2,
TensorLTOp<scalar_t,
scalar_t>());
}
void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
{
@@ -98,13 +83,6 @@ void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, T
scalar_t>());
}
void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
THC_logicalTensor<unsigned char, scalar_t>(state, self_, src1, src2,
TensorLTOp<scalar_t,
unsigned char>());
}
void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
{

@@ -2,21 +2,18 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.h"
#else
THC_API void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(leTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(geTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(eqTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(leTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(leTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(geTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);

@@ -6205,6 +6205,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
for idx in iter_indices(x):
self.assertEqual(x[idx] >= y[idx], ge[idx] == 1)
def test_comparison_ops_must_take_bool_output(self):
with self.assertRaisesRegex(RuntimeError, 'The output tensor of lt must be a bool'):
torch.lt(torch.tensor([True]), torch.tensor([False]), out=torch.empty(1, dtype=torch.uint8))
def test_inplace_comparison_ops_require_inputs_have_same_dtype(self):
with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
torch.tensor([1], dtype=torch.int).lt_(torch.tensor([2], dtype=torch.long))
def test_comparison_ops_check_for_scalar_overflow(self):
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
torch.tensor([1 << 5], dtype=torch.uint8) < (1 << 20)
def test_comparison_ops_check_for_zerodim_tensor_overflow(self):
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
torch.tensor([1 << 5], dtype=torch.uint8) < torch.tensor(1 << 20, dtype=torch.int32)
torch.tensor(1 << 40, dtype=torch.int64) < torch.tensor([1 << 30], dtype=torch.int32)
def test_bitwise_ops(self):
x = torch.randn(5, 5).gt(0)
y = torch.randn(5, 5).gt(0)
@@ -6740,11 +6757,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
e1.fill_diagonal_(v, wrap=True)
self.assertEqual(e1, e2)
def test_function_unwrap_message(self):
self.assertRaisesRegex(RuntimeError, ' call to _th_lt',
lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double))
# Functions to test negative dimension wrapping
METHOD = 1
INPLACE_METHOD = 2
@@ -11145,10 +11157,6 @@ class TestTorchDeviceType(TestCase):
byteRes = torch.empty_like(x, device=device).byte()
boolRes = torch.empty_like(x, device=device).bool()
torch.lt(x, b, out=byteRes)
torch.lt(x, b, out=boolRes)
self.assertEqual(byteRes.bool(), boolRes)
torch.le(x, b, out=byteRes)
torch.le(x, b, out=boolRes)
self.assertEqual(byteRes.bool(), boolRes)
@@ -11169,7 +11177,7 @@ class TestTorchDeviceType(TestCase):
torch.ne(x, b, out=boolRes)
self.assertEqual(byteRes.bool(), boolRes)
self.assertEquals(len(warningsCount), 6)
self.assertEquals(len(warningsCount), 5)
# Bool Tensor
x = torch.tensor([True, False, True, False], device=device)

@@ -227,6 +227,87 @@ class TestTypePromotion(TestCase):
self.assertEqual(torch.result_type(torch.tensor([1., 1.], dtype=torch.float), 1.), torch.float)
self.assertEqual(torch.result_type(torch.tensor(1., dtype=torch.float), torch.tensor(1, dtype=torch.double)), torch.double)
def test_comparison_ops_with_type_promotion(self):
value_for_type = {
torch.uint8: (1 << 5),
torch.int8: (1 << 5),
torch.int16: (1 << 10),
torch.int32: (1 << 20),
torch.int64: (1 << 35),
torch.float16: (1 << 10),
torch.float32: (1 << 20),
torch.float64: (1 << 35)
}
comparison_ops = [
dict(
name="lt",
out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(1, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.lt(x, y),
compare_op=lambda x, y: x < y,
),
]
device = self.device
for op in comparison_ops:
for dt1 in torch.testing.get_all_math_dtypes(device):
for dt2 in torch.testing.get_all_math_dtypes(device):
val1 = value_for_type[dt1]
val2 = value_for_type[dt2]
t1 = torch.tensor([val1], dtype=dt1, device=device)
t2 = torch.tensor([val2], dtype=dt2, device=device)
expected = torch.tensor([op["compare_op"](val1, val2)], dtype=torch.bool)
out_res = op["out_op"](t1, t2, device)
self.assertEqual(out_res, expected)
self.assertTrue(out_res.dtype == torch.bool)
self.assertTrue(t1.dtype == dt1)
self.assertTrue(t2.dtype == dt2)
out_res = op["ret_op"](t1, t2)
self.assertEqual(out_res, expected)
self.assertTrue(out_res.dtype == torch.bool)
self.assertTrue(t1.dtype == dt1)
self.assertTrue(t2.dtype == dt2)
# test that comparing a zero dim tensor with another zero dim tensor has type promotion behavior
t1 = torch.tensor(val1, dtype=dt1, device=device)
t2 = torch.tensor(val2, dtype=dt2, device=device)
expected = torch.tensor(op["compare_op"](val1, val2), dtype=torch.bool)
out_res = op["out_op"](t1, t2, device)
self.assertEqual(out_res, expected)
self.assertTrue(out_res.dtype == torch.bool)
self.assertTrue(t1.dtype == dt1)
self.assertTrue(t2.dtype == dt2)
out_res = op["ret_op"](t1, t2)
self.assertEqual(out_res, expected)
self.assertTrue(out_res.dtype == torch.bool)
self.assertTrue(t1.dtype == dt1)
self.assertTrue(t2.dtype == dt2)
def test_lt_with_type_promotion(self):
for dt in torch.testing.get_all_math_dtypes(self.device):
x = torch.tensor([0], dtype=dt, device=self.device)
expected = torch.tensor([True], dtype=torch.bool, device=self.device)
actual = x < 0.5
self.assertTrue(actual, expected)
self.assertTrue(actual.dtype == torch.bool)
actual = x < torch.tensor(0.5)
self.assertTrue(actual, expected)
self.assertTrue(actual.dtype == torch.bool)
x = torch.tensor(0, dtype=dt, device=self.device)
expected = torch.tensor(True, dtype=torch.bool, device=self.device)
actual = x < 0.5
self.assertTrue(actual, expected)
self.assertTrue(actual.dtype == torch.bool)
actual = x < torch.tensor(0.5)
self.assertTrue(actual, expected)
self.assertTrue(actual.dtype == torch.bool)
@unittest.skipIf(not torch.cuda.is_available(), "no cuda")
class TestTypePromotionCuda(TestTypePromotion):
def setUp(self):