mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Migrate lt and lt_ from the TH to Aten (#25998)
Summary: https://github.com/pytorch/pytorch/issues/24593 https://github.com/pytorch/pytorch/issues/24727 **torch.lt(Tensor a, Tensor b)** will compute common dtype (highest) based on inputs and then compare values. The result will be Bool tensor ``` >>> x = torch.tensor([0], dtype=torch.int) >>> y = torch.tensor([0.5], dtype=torch.double) >>> x < y tensor([True]) ``` Previously it was impossible to make comparison of two tensors with different dtype. **torch.lt(Tensor a, Tensor b, out=c)** will compute common dtype (highest) based on inputs and then compare values. The result can be populated only to Bool tensor ``` >>> x = torch.tensor([0], dtype=torch.int) >>> y = torch.tensor([0.5], dtype=torch.double) >>> z = torch.empty([1], dtype=torch.bool) >>> torch.lt(x, y, out=z) tensor([True]) ``` Previously it was impossible to make comparison of two tensors with different dtype. Also previously the result dtype could be Bool and Byte(deprecated). Currently it will accept only Bool result. **a.lt_(Tensor b)** Expects that a and b has same dtype, otherwise it's possible to get an overflow(Example: 'a' is uint8, 'b' is float32. 'a' will be promoted to float32 and the result will be also float32. Then it will be casted back to uint8 so potential for overflow). Will not compute common dtype. Result will have type of a. ``` >>> x = torch.tensor([0], dtype=torch.double) >>> y = torch.tensor([0.5], dtype=torch.double) >>> x < y tensor([True]) ``` Works similar to previous implementation. **torch.lt(Tensor a, Scalar b)** will check if there is no overflow when converting b to the same type as a. Then will compute common dtype and compare. ``` >>> x = torch.tensor([0], dtype=torch.double) >>> x < 0.5 tensor([True]) >>> x = torch.tensor([0], dtype=torch.int) >>> x < 0.5 tensor([True]) ``` Fix https://github.com/pytorch/pytorch/issues/22301. **torch.lt(Tensor a, Scalar b, out=c)** will check if there is no overflow when converting b to the same type as a. Then will compute common dtype and compare. The result can be populated only to Bool tensor ``` >>> x = torch.tensor([0], dtype=torch.double) >>> torch.lt(x, 0.5, out=z) tensor([True]) ``` Previously the result dtype could be Bool and Byte(deprecated). Currently it will accept only Bool result. The rest works similar to previous implementation. **torch.lt_(Tensor a, Scalar b)** will check if there is no overflow when converting b to the same type as a. Then will compute common dtype and compare. Result will have type of a. ``` >>> x = torch.tensor([0], dtype=torch.int) >>> x.lt_(1) tensor([1], dtype=torch.int32) >>> x = torch.tensor([0], dtype=torch.int) >>> x.lt_(1.0) tensor([1], dtype=torch.int32) ``` Works similar to previous implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/25998 Differential Revision: D17431853 Pulled By: ifedan fbshipit-source-id: b5effc6a5d9b32da379395b32abc628b604faaf7
This commit is contained in:
parent
9dd8a129de
commit
f99bc714c7
|
|
@ -606,72 +606,6 @@
|
|||
broadcast: other inplace fallback
|
||||
- THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_lt
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
options:
|
||||
- cname: ltValue
|
||||
arguments:
|
||||
- arg: THBoolTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- real other
|
||||
- cname: ltTensor
|
||||
arguments:
|
||||
- arg: THBoolTensor* result
|
||||
output: True
|
||||
- arg: THTensor* self
|
||||
broadcast: other fallback
|
||||
- THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_lt_byte
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
options:
|
||||
- cname: ltValueByte
|
||||
arguments:
|
||||
- arg: THByteTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- real other
|
||||
- cname: ltTensorByte
|
||||
arguments:
|
||||
- arg: THByteTensor* result
|
||||
output: True
|
||||
- arg: THTensor* self
|
||||
broadcast: other fallback
|
||||
- THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_lt_
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
return: self
|
||||
variants: function
|
||||
options:
|
||||
- cname: ltValueT
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* self
|
||||
- real other
|
||||
- cname: ltTensorT
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- arg: THTensor* self
|
||||
broadcast: other inplace fallback
|
||||
- arg: THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_gt
|
||||
cpu_bool: True
|
||||
|
|
|
|||
|
|
@ -3810,13 +3810,7 @@ inline Tensor Tensor::scatter_add(Dimname dim, const Tensor & index, const Tenso
|
|||
#endif
|
||||
inline Tensor & Tensor::lt_(Scalar other) const {
|
||||
#ifdef USE_STATIC_DISPATCH
|
||||
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
|
||||
case Backend::CPU:
|
||||
return CPUType::lt_(const_cast<Tensor&>(*this), other);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("lt_ not implemented for ", at::toString(type_set()));
|
||||
}
|
||||
return TypeDefault::lt_(const_cast<Tensor&>(*this), other);
|
||||
#else
|
||||
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Scalar"}).value();
|
||||
return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, Scalar>(
|
||||
|
|
@ -3825,13 +3819,7 @@ inline Tensor & Tensor::lt_(Scalar other) const {
|
|||
}
|
||||
inline Tensor & Tensor::lt_(const Tensor & other) const {
|
||||
#ifdef USE_STATIC_DISPATCH
|
||||
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
|
||||
case Backend::CPU:
|
||||
return CPUType::lt_(const_cast<Tensor&>(*this), other);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("lt_ not implemented for ", at::toString(type_set()));
|
||||
}
|
||||
return TypeDefault::lt_(const_cast<Tensor&>(*this), other);
|
||||
#else
|
||||
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Tensor"}).value();
|
||||
return c10::Dispatcher::singleton().callUnboxedOnly<Tensor &, Tensor &, const Tensor &>(
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ namespace at { namespace detail {
|
|||
|
||||
template <typename T>
|
||||
inline T load(const void* data, ScalarType src_type) {
|
||||
return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src_type, "load", [&]() {
|
||||
return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src_type, "load", [&]() {
|
||||
return at::convert<T>(*(scalar_t*)data);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void store(T value, void* dst, ScalarType dst_type) {
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, dst_type, "store", [&]() {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, dst_type, "store", [&]() {
|
||||
*(scalar_t*)dst = at::convert<scalar_t>(value);
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ DEFINE_DISPATCH(mul_stub);
|
|||
DEFINE_DISPATCH(div_stub);
|
||||
DEFINE_DISPATCH(atan2_stub);
|
||||
DEFINE_DISPATCH(logical_xor_stub);
|
||||
DEFINE_DISPATCH(lt_stub);
|
||||
|
||||
static constexpr char alpha_mismatch_err[] =
|
||||
"For integral input tensors, argument alpha must not be a floating point number.";
|
||||
|
|
@ -140,6 +141,18 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) {
|
|||
return tensor;
|
||||
}
|
||||
|
||||
static void check_convert(Scalar scalar, ScalarType scalarType) {
|
||||
// Validate that is possible to convert scalar to tensor dtype without overflow
|
||||
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
|
||||
scalar.to<scalar_t>();
|
||||
});
|
||||
}
|
||||
|
||||
static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, Tensor tensor) {
|
||||
check_convert(scalar, tensor.scalar_type());
|
||||
return wrapped_scalar_tensor(scalar);
|
||||
}
|
||||
|
||||
Tensor add(const Tensor& self, Scalar other, Scalar alpha) {
|
||||
return native::add(self, wrapped_scalar_tensor(other), alpha);
|
||||
}
|
||||
|
|
@ -203,5 +216,69 @@ Tensor& logical_xor_(Tensor& self, const Tensor& other) {
|
|||
return native::logical_xor_out(self, self, other);
|
||||
}
|
||||
|
||||
template <typename Stub>
|
||||
static inline Tensor& comparison_op_impl_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
|
||||
auto iter = TensorIterator::comparison_op(result, self, other,
|
||||
/*check_mem_overlap=*/true);
|
||||
stub(iter.device_type(), iter);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Stub>
|
||||
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
|
||||
TORCH_CHECK(result.scalar_type() == kBool,
|
||||
"The output tensor of lt must be a bool, but was ", result.scalar_type());
|
||||
// Validate that is possible to convert zero-dim tensor's dtype to other dtype without overflow
|
||||
if (self.scalar_type() != other.scalar_type()) {
|
||||
if (self.dim() != 0 && other.dim() == 0) {
|
||||
check_convert(other.item(), self.scalar_type());
|
||||
} else if (self.dim() == 0 && other.dim() != 0) {
|
||||
check_convert(self.item(), other.scalar_type());
|
||||
}
|
||||
}
|
||||
return native::comparison_op_impl_out(result, self, other, stub);
|
||||
}
|
||||
|
||||
template <typename Stub>
|
||||
Tensor comparison_op(const Tensor& self, const Tensor& other, Stub& stub) {
|
||||
Tensor result = at::empty({0}, self.options().dtype(kBool));
|
||||
return native::comparison_op_out(result, self, other, stub);
|
||||
}
|
||||
|
||||
// To avoid overflow during type promotion we will check that both dtypes of self and other are same
|
||||
template <typename Stub>
|
||||
Tensor& comparison_op_(Tensor& self, const Tensor& other, Stub& stub) {
|
||||
TORCH_CHECK(self.dtype() == other.dtype(),
|
||||
"Expected object of scalar type ", self.dtype(), " but got scalar type ",
|
||||
other.dtype(), " for argument 'other'");
|
||||
return native::comparison_op_impl_out(self, self, other, stub);
|
||||
}
|
||||
|
||||
// validates that is possible to convert Scalar other to self's dtype without overflow.
|
||||
// This behavior is unique to comparison ops; arithmetic operations don't do this.
|
||||
// In the future, we should reconsider this inconsistency and decide if we want to add the same check to arithmetic ops.
|
||||
template <typename Stub>
|
||||
Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, Stub& stub) {
|
||||
return native::comparison_op_out(result, self, wrapped_scalar_tensor_and_check_convert(other, self), stub);
|
||||
}
|
||||
|
||||
template <typename Stub>
|
||||
Tensor comparison_op(const Tensor& self, Scalar other, Stub& stub) {
|
||||
Tensor result = at::empty({0}, self.options().dtype(kBool));
|
||||
return native::comparison_op_out(result, self, other, stub);
|
||||
}
|
||||
|
||||
template <typename Stub>
|
||||
Tensor& comparison_op_(Tensor& self, Scalar other, Stub& stub) {
|
||||
return native::comparison_op_impl_out(self, self, wrapped_scalar_tensor_and_check_convert(other, self), stub);
|
||||
}
|
||||
|
||||
Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, lt_stub); }
|
||||
Tensor lt(const Tensor& self, const Tensor& other) { return comparison_op(self, other, lt_stub); }
|
||||
Tensor& lt_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, lt_stub); }
|
||||
Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, lt_stub); }
|
||||
Tensor lt(const Tensor& self, Scalar other) { return comparison_op(self, other, lt_stub); }
|
||||
Tensor& lt_(Tensor& self, Scalar other) { return comparison_op_(self, other, lt_stub); }
|
||||
|
||||
}
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -16,5 +16,6 @@ DECLARE_DISPATCH(binary_fn, mul_stub);
|
|||
DECLARE_DISPATCH(binary_fn, div_stub);
|
||||
DECLARE_DISPATCH(binary_fn, atan2_stub);
|
||||
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
|
||||
DECLARE_DISPATCH(binary_fn, lt_stub);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -94,26 +94,6 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
|
|||
return legacy::cpu::_th_gather(self, dim, index);
|
||||
}
|
||||
|
||||
Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
"please use 'out' parameter with dtype torch.bool instead.");
|
||||
return legacy::cpu::_th_lt_byte_out(result, self, other);
|
||||
} else {
|
||||
return legacy::cpu::_th_lt_out(result, self, other);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
"please use 'out' parameter with dtype torch.bool instead.");
|
||||
return legacy::cpu::_th_lt_byte_out(result, self, value);
|
||||
} else {
|
||||
return legacy::cpu::_th_lt_out(result, self, value);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
|
|
|
|||
|
|
@ -145,7 +145,6 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)
|
|||
void TensorIterator::compute_types() {
|
||||
bool missing_dtypes = false;
|
||||
bool missing_output_dtypes = false;
|
||||
bool has_read_write_op = false;
|
||||
ScalarType common_dtype = dtype();
|
||||
for (auto& op : operands_) {
|
||||
if (!op.tensor.defined() && !op.is_type_defined()) {
|
||||
|
|
@ -154,14 +153,10 @@ void TensorIterator::compute_types() {
|
|||
missing_output_dtypes = true;
|
||||
}
|
||||
}
|
||||
if (op.is_read_write) {
|
||||
has_read_write_op = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) {
|
||||
TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs");
|
||||
TORCH_CHECK(!has_read_write_op, "unable to compute and promote common dtype based only on inputs if input is same as output");
|
||||
}
|
||||
|
||||
bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE);
|
||||
|
|
@ -611,6 +606,19 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
|
|||
return iter;
|
||||
}
|
||||
|
||||
TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
|
||||
const Tensor& b, bool check_mem_overlap) {
|
||||
auto iter = TensorIterator();
|
||||
iter.set_check_mem_overlap(check_mem_overlap);
|
||||
iter.add_output(out);
|
||||
iter.add_input(a);
|
||||
iter.add_input(b);
|
||||
iter.allow_cpu_scalars_ = true;
|
||||
iter.compute_common_dtype_only_for_inputs();
|
||||
iter.build();
|
||||
return iter;
|
||||
}
|
||||
|
||||
TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
|
||||
bool check_mem_overlap) {
|
||||
auto iter = TensorIterator();
|
||||
|
|
|
|||
|
|
@ -156,6 +156,8 @@ struct CAFFE2_API TensorIterator {
|
|||
|
||||
static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b,
|
||||
bool check_mem_overlap = false);
|
||||
static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b,
|
||||
bool check_mem_overlap = false);
|
||||
static TensorIterator unary_op(Tensor& out, const Tensor& a,
|
||||
bool check_mem_overlap = false);
|
||||
static TensorIterator nullary_op(Tensor& out);
|
||||
|
|
|
|||
|
|
@ -103,6 +103,24 @@ void logical_xor_kernel(TensorIterator& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
void lt_kernel(TensorIterator& iter) {
|
||||
if (iter.dtype() == ScalarType::Bool) {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, iter.input_dtype(), "lt_cpu", [&]() {
|
||||
cpu_kernel(iter,
|
||||
[=](scalar_t a, scalar_t b) -> bool {
|
||||
return a < b;
|
||||
});
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "lt_cpu", [&]() {
|
||||
cpu_kernel(iter,
|
||||
[=](scalar_t a, scalar_t b) -> scalar_t {
|
||||
return a < b;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
|
|
@ -112,5 +130,6 @@ REGISTER_DISPATCH(mul_stub, &mul_kernel);
|
|||
REGISTER_DISPATCH(div_stub, &div_kernel);
|
||||
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
|
||||
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel);
|
||||
REGISTER_DISPATCH(lt_stub, <_kernel);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -76,11 +76,28 @@ void logical_xor_kernel_cuda(TensorIterator& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
void lt_kernel_cuda(TensorIterator& iter) {
|
||||
if (iter.dtype() == ScalarType::Bool) {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cpu", [&]() {
|
||||
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
|
||||
return a < b;
|
||||
});
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "lt_cpu", [&]() {
|
||||
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
||||
return a < b;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
|
||||
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
|
||||
REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
|
||||
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);
|
||||
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
|
||||
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda);
|
||||
REGISTER_DISPATCH(lt_stub, <_kernel_cuda);
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -90,26 +90,6 @@ Tensor gather_cuda(const Tensor & self, int64_t dim, const Tensor & index, bool
|
|||
return legacy::cuda::_th_gather(self, dim, index);
|
||||
}
|
||||
|
||||
Tensor & lt_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
"please use 'out' parameter with dtype torch.bool instead.");
|
||||
return legacy::cuda::_th_lt_byte_out(result, self, other);
|
||||
} else {
|
||||
return legacy::cuda::_th_lt_out(result, self, other);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor & lt_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
"please use 'out' parameter with dtype torch.bool instead.");
|
||||
return legacy::cuda::_th_lt_byte_out(result, self, value);
|
||||
} else {
|
||||
return legacy::cuda::_th_lt_out(result, self, value);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor & le_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) {
|
||||
if (result.dtype() == at::ScalarType::Byte) {
|
||||
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
|
||||
|
|
|
|||
|
|
@ -4008,16 +4008,10 @@
|
|||
- func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
|
||||
use_c10_dispatcher: unboxed_only
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU: legacy::cpu::_th_lt_
|
||||
CUDA: legacy::cuda::_th_lt_
|
||||
|
||||
- func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
|
||||
use_c10_dispatcher: unboxed_only
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU: legacy::cpu::_th_lt_
|
||||
CUDA: legacy::cuda::_th_lt_
|
||||
|
||||
- func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
|
||||
use_c10_dispatcher: unboxed_only
|
||||
|
|
@ -4614,30 +4608,30 @@
|
|||
|
||||
- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: lt_scalar_out_cpu
|
||||
CUDA: lt_scalar_out_cuda
|
||||
CPU: lt_out
|
||||
CUDA: lt_out
|
||||
QuantizedCPU: lt_out_quantized_cpu
|
||||
|
||||
- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU: legacy::cpu::_th_lt
|
||||
CUDA: legacy::cuda::_th_lt
|
||||
CPU: lt
|
||||
CUDA: lt
|
||||
QuantizedCPU: lt_quantized_cpu
|
||||
|
||||
- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: lt_out_cpu
|
||||
CUDA: lt_out_cuda
|
||||
CPU: lt_out
|
||||
CUDA: lt_out
|
||||
QuantizedCPU: lt_out_quantized_cpu
|
||||
|
||||
- func: lt.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU: legacy::cpu::_th_lt
|
||||
CUDA: legacy::cuda::_th_lt
|
||||
CPU: lt
|
||||
CUDA: lt
|
||||
QuantizedCPU: lt_quantized_cpu
|
||||
|
||||
- func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ TEST(TensorIteratorTest, MixedDevices) {
|
|||
|
||||
Tensor random_tensor_for_type(at::ScalarType scalar_type) {
|
||||
if (at::isFloatingType(scalar_type)) {
|
||||
return at::randn({5, 5}, kCPU);
|
||||
return at::randn({5, 5}, at::device(kCPU).dtype(scalar_type));
|
||||
} else {
|
||||
return at::randint(1, 10, {5, 5}, kCPU);
|
||||
return at::randint(1, 10, {5, 5}, at::device(kCPU).dtype(scalar_type));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,9 +89,30 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) {
|
|||
ASSERT_ANY_THROW(out.equal(expected)); \
|
||||
}
|
||||
|
||||
// The alternative way to calculate a < b is (b - a).clamp(0).toBool()
|
||||
// To prevent an overflow in subtraction (b - a) for unsigned types(unit, bool)
|
||||
// we will convert in to int first
|
||||
#define COMPARISON_TEST_ITER_FOR_TYPE(ctype,name) \
|
||||
TEST(TensorIteratorTest, ComparisonLoopBinary_##name) { \
|
||||
auto in1 = random_tensor_for_type(k##name); \
|
||||
auto in2 = random_tensor_for_type(k##name); \
|
||||
Tensor out = at::empty({0}, in1.options().dtype(kBool)); \
|
||||
Tensor diff; \
|
||||
if (k##name == kByte || k##name == kBool) { \
|
||||
diff = in2.to(kInt).sub(in1.to(kInt)); \
|
||||
} else { \
|
||||
diff = in2.sub(in1); \
|
||||
} \
|
||||
auto expected = diff.clamp_min(0).to(kBool); \
|
||||
auto iter = TensorIterator::comparison_op(out, in1, in2, true); \
|
||||
at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> bool { return a < b; }); \
|
||||
EXPECT_TRUE(out.equal(expected)); \
|
||||
}
|
||||
|
||||
AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE)
|
||||
AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE)
|
||||
AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE)
|
||||
AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE)
|
||||
|
||||
TEST(TensorIteratorTest, SerialLoopSingleThread) {
|
||||
std::thread::id thread_id = std::this_thread::get_id();
|
||||
|
|
@ -142,16 +163,6 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeInputOnly) {
|
|||
EXPECT_TRUE(iter.dtype(2) == at::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfInputSameAsOutput) {
|
||||
Tensor inout = at::ones({1, 1}, at::dtype(at::kFloat));
|
||||
auto iter = at::TensorIterator();
|
||||
iter.add_output(inout);
|
||||
iter.add_input(inout);
|
||||
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
|
||||
iter.compute_common_dtype_only_for_inputs();
|
||||
ASSERT_ANY_THROW(iter.build());
|
||||
}
|
||||
|
||||
TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
|
||||
Tensor out;
|
||||
auto iter = at::TensorIterator();
|
||||
|
|
|
|||
|
|
@ -35,8 +35,7 @@
|
|||
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \
|
||||
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
|
||||
}
|
||||
|
||||
TENSOR_IMPLEMENT_LOGICAL(lt,<)
|
||||
|
||||
TENSOR_IMPLEMENT_LOGICAL(gt,>)
|
||||
TENSOR_IMPLEMENT_LOGICAL(le,<=)
|
||||
TENSOR_IMPLEMENT_LOGICAL(ge,>=)
|
||||
|
|
@ -57,7 +56,6 @@ TENSOR_IMPLEMENT_LOGICAL(ne,!=)
|
|||
*r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
|
||||
} \
|
||||
|
||||
TENSOR_IMPLEMENT_LOGICAL_BYTE(lt,<)
|
||||
TENSOR_IMPLEMENT_LOGICAL_BYTE(gt,>)
|
||||
TENSOR_IMPLEMENT_LOGICAL_BYTE(le,<=)
|
||||
TENSOR_IMPLEMENT_LOGICAL_BYTE(ge,>=)
|
||||
|
|
|
|||
|
|
@ -8,13 +8,6 @@
|
|||
#include <THC/THCNumerics.cuh>
|
||||
#include <THC/THCReduce.cuh>
|
||||
|
||||
template <typename T, typename TOut>
|
||||
struct TensorLTOp {
|
||||
__device__ inline void operator()(TOut* out, T* a, T* b) {
|
||||
*out = ScalarConvert<bool, TOut>::to(THCNumerics<T>::lt(*a, *b));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename TOut>
|
||||
struct TensorGTOp {
|
||||
__device__ inline void operator()(TOut* out, T* a, T* b) {
|
||||
|
|
|
|||
|
|
@ -2,14 +2,6 @@
|
|||
#define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.cu"
|
||||
#else
|
||||
|
||||
void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
|
||||
THC_logicalTensor<bool, scalar_t>(state, self_, src1, src2,
|
||||
TensorLTOp<scalar_t,
|
||||
bool>());
|
||||
}
|
||||
|
||||
void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
|
||||
|
|
@ -50,13 +42,6 @@ void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *s
|
|||
bool>());
|
||||
}
|
||||
|
||||
void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
|
||||
THC_logicalTensor<scalar_t, scalar_t>(state, self_, src1, src2,
|
||||
TensorLTOp<scalar_t,
|
||||
scalar_t>());
|
||||
}
|
||||
|
||||
void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
|
|
@ -98,13 +83,6 @@ void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, T
|
|||
scalar_t>());
|
||||
}
|
||||
|
||||
void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
|
||||
THC_logicalTensor<unsigned char, scalar_t>(state, self_, src1, src2,
|
||||
TensorLTOp<scalar_t,
|
||||
unsigned char>());
|
||||
}
|
||||
|
||||
void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,21 +2,18 @@
|
|||
#define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.h"
|
||||
#else
|
||||
|
||||
THC_API void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(leTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(geTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(eqTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
|
||||
THC_API void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(leTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
|
||||
THC_API void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(leTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
THC_API void THCTensor_(geTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2);
|
||||
|
|
|
|||
|
|
@ -6205,6 +6205,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
|||
for idx in iter_indices(x):
|
||||
self.assertEqual(x[idx] >= y[idx], ge[idx] == 1)
|
||||
|
||||
def test_comparison_ops_must_take_bool_output(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'The output tensor of lt must be a bool'):
|
||||
torch.lt(torch.tensor([True]), torch.tensor([False]), out=torch.empty(1, dtype=torch.uint8))
|
||||
|
||||
def test_inplace_comparison_ops_require_inputs_have_same_dtype(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'):
|
||||
torch.tensor([1], dtype=torch.int).lt_(torch.tensor([2], dtype=torch.long))
|
||||
|
||||
def test_comparison_ops_check_for_scalar_overflow(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
||||
torch.tensor([1 << 5], dtype=torch.uint8) < (1 << 20)
|
||||
|
||||
def test_comparison_ops_check_for_zerodim_tensor_overflow(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'):
|
||||
torch.tensor([1 << 5], dtype=torch.uint8) < torch.tensor(1 << 20, dtype=torch.int32)
|
||||
torch.tensor(1 << 40, dtype=torch.int64) < torch.tensor([1 << 30], dtype=torch.int32)
|
||||
|
||||
def test_bitwise_ops(self):
|
||||
x = torch.randn(5, 5).gt(0)
|
||||
y = torch.randn(5, 5).gt(0)
|
||||
|
|
@ -6740,11 +6757,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
|||
e1.fill_diagonal_(v, wrap=True)
|
||||
self.assertEqual(e1, e2)
|
||||
|
||||
def test_function_unwrap_message(self):
|
||||
self.assertRaisesRegex(RuntimeError, ' call to _th_lt',
|
||||
lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double))
|
||||
|
||||
|
||||
# Functions to test negative dimension wrapping
|
||||
METHOD = 1
|
||||
INPLACE_METHOD = 2
|
||||
|
|
@ -11145,10 +11157,6 @@ class TestTorchDeviceType(TestCase):
|
|||
byteRes = torch.empty_like(x, device=device).byte()
|
||||
boolRes = torch.empty_like(x, device=device).bool()
|
||||
|
||||
torch.lt(x, b, out=byteRes)
|
||||
torch.lt(x, b, out=boolRes)
|
||||
self.assertEqual(byteRes.bool(), boolRes)
|
||||
|
||||
torch.le(x, b, out=byteRes)
|
||||
torch.le(x, b, out=boolRes)
|
||||
self.assertEqual(byteRes.bool(), boolRes)
|
||||
|
|
@ -11169,7 +11177,7 @@ class TestTorchDeviceType(TestCase):
|
|||
torch.ne(x, b, out=boolRes)
|
||||
self.assertEqual(byteRes.bool(), boolRes)
|
||||
|
||||
self.assertEquals(len(warningsCount), 6)
|
||||
self.assertEquals(len(warningsCount), 5)
|
||||
|
||||
# Bool Tensor
|
||||
x = torch.tensor([True, False, True, False], device=device)
|
||||
|
|
|
|||
|
|
@ -227,6 +227,87 @@ class TestTypePromotion(TestCase):
|
|||
self.assertEqual(torch.result_type(torch.tensor([1., 1.], dtype=torch.float), 1.), torch.float)
|
||||
self.assertEqual(torch.result_type(torch.tensor(1., dtype=torch.float), torch.tensor(1, dtype=torch.double)), torch.double)
|
||||
|
||||
def test_comparison_ops_with_type_promotion(self):
|
||||
value_for_type = {
|
||||
torch.uint8: (1 << 5),
|
||||
torch.int8: (1 << 5),
|
||||
torch.int16: (1 << 10),
|
||||
torch.int32: (1 << 20),
|
||||
torch.int64: (1 << 35),
|
||||
torch.float16: (1 << 10),
|
||||
torch.float32: (1 << 20),
|
||||
torch.float64: (1 << 35)
|
||||
}
|
||||
comparison_ops = [
|
||||
dict(
|
||||
name="lt",
|
||||
out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(1, dtype=torch.bool, device=d)),
|
||||
ret_op=lambda x, y: torch.lt(x, y),
|
||||
compare_op=lambda x, y: x < y,
|
||||
),
|
||||
]
|
||||
device = self.device
|
||||
for op in comparison_ops:
|
||||
for dt1 in torch.testing.get_all_math_dtypes(device):
|
||||
for dt2 in torch.testing.get_all_math_dtypes(device):
|
||||
val1 = value_for_type[dt1]
|
||||
val2 = value_for_type[dt2]
|
||||
t1 = torch.tensor([val1], dtype=dt1, device=device)
|
||||
t2 = torch.tensor([val2], dtype=dt2, device=device)
|
||||
expected = torch.tensor([op["compare_op"](val1, val2)], dtype=torch.bool)
|
||||
|
||||
out_res = op["out_op"](t1, t2, device)
|
||||
self.assertEqual(out_res, expected)
|
||||
self.assertTrue(out_res.dtype == torch.bool)
|
||||
self.assertTrue(t1.dtype == dt1)
|
||||
self.assertTrue(t2.dtype == dt2)
|
||||
|
||||
out_res = op["ret_op"](t1, t2)
|
||||
self.assertEqual(out_res, expected)
|
||||
self.assertTrue(out_res.dtype == torch.bool)
|
||||
self.assertTrue(t1.dtype == dt1)
|
||||
self.assertTrue(t2.dtype == dt2)
|
||||
|
||||
# test that comparing a zero dim tensor with another zero dim tensor has type promotion behavior
|
||||
t1 = torch.tensor(val1, dtype=dt1, device=device)
|
||||
t2 = torch.tensor(val2, dtype=dt2, device=device)
|
||||
expected = torch.tensor(op["compare_op"](val1, val2), dtype=torch.bool)
|
||||
|
||||
out_res = op["out_op"](t1, t2, device)
|
||||
self.assertEqual(out_res, expected)
|
||||
self.assertTrue(out_res.dtype == torch.bool)
|
||||
self.assertTrue(t1.dtype == dt1)
|
||||
self.assertTrue(t2.dtype == dt2)
|
||||
|
||||
out_res = op["ret_op"](t1, t2)
|
||||
self.assertEqual(out_res, expected)
|
||||
self.assertTrue(out_res.dtype == torch.bool)
|
||||
self.assertTrue(t1.dtype == dt1)
|
||||
self.assertTrue(t2.dtype == dt2)
|
||||
|
||||
def test_lt_with_type_promotion(self):
|
||||
for dt in torch.testing.get_all_math_dtypes(self.device):
|
||||
x = torch.tensor([0], dtype=dt, device=self.device)
|
||||
expected = torch.tensor([True], dtype=torch.bool, device=self.device)
|
||||
|
||||
actual = x < 0.5
|
||||
self.assertTrue(actual, expected)
|
||||
self.assertTrue(actual.dtype == torch.bool)
|
||||
|
||||
actual = x < torch.tensor(0.5)
|
||||
self.assertTrue(actual, expected)
|
||||
self.assertTrue(actual.dtype == torch.bool)
|
||||
|
||||
x = torch.tensor(0, dtype=dt, device=self.device)
|
||||
expected = torch.tensor(True, dtype=torch.bool, device=self.device)
|
||||
actual = x < 0.5
|
||||
self.assertTrue(actual, expected)
|
||||
self.assertTrue(actual.dtype == torch.bool)
|
||||
|
||||
actual = x < torch.tensor(0.5)
|
||||
self.assertTrue(actual, expected)
|
||||
self.assertTrue(actual.dtype == torch.bool)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "no cuda")
|
||||
class TestTypePromotionCuda(TestTypePromotion):
|
||||
def setUp(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user